- add Loss, Optimizer - change Trainer & Tester initialization interface: two styles of definition provided - handle Optimizer construction and loss function definition in a hard manner - add argparse in task-specific scripts. (seq_labeling.py & text_classify.py) - seq_labeling.py & text_classify.py worktags/v0.1.0
| @@ -0,0 +1,27 @@ | |||||
| import torch | |||||
class Loss(object):
    """Loss function of the algorithm,
    either the wrapper of a loss function from framework, or a user-defined loss (need pytorch auto_grad support)
    """

    def __init__(self, args):
        """
        :param args: None (no loss chosen yet; e.g. the model supplies its own loss later),
            or str, the name of a loss function provided by the framework.
        :raises NotImplementedError: if args is neither None nor a supported loss name.
        """
        if args is None:
            # No loss specified; callers may later fall back to a model-defined loss.
            self._loss = None
        elif isinstance(args, str):
            self._loss = self._borrow_from_pytorch(args)
        else:
            raise NotImplementedError("Unsupported loss specification: {}".format(args))

    def get(self):
        """Return the underlying loss callable, or None if unspecified."""
        return self._loss

    @staticmethod
    def _borrow_from_pytorch(loss_name):
        """Map a loss name to an instantiated torch.nn loss module.

        :param loss_name: str, e.g. "cross_entropy"
        :return: a torch.nn loss instance
        :raises NotImplementedError: for unrecognized loss names
        """
        # Generalized from a single hard-coded option to a lookup table;
        # behavior for "cross_entropy" is unchanged.
        known_losses = {
            "cross_entropy": torch.nn.CrossEntropyLoss,
            "nll": torch.nn.NLLLoss,
            "mse": torch.nn.MSELoss,
            "l1": torch.nn.L1Loss,
        }
        if loss_name in known_losses:
            return known_losses[loss_name]()
        raise NotImplementedError("Unknown loss name: {}".format(loss_name))
| @@ -1,3 +1,54 @@ | |||||
| """ | |||||
| use optimizer from Pytorch | |||||
| """ | |||||
| import torch | |||||
class Optimizer(object):
    """Wrapper of optimizer from framework

    names: arguments (type)
    1. Adam: lr (float), weight_decay (float)
    2. AdaGrad
    3. RMSProp
    4. SGD: lr (float), momentum (float)
    """

    def __init__(self, optimizer_name, **kwargs):
        """
        :param optimizer_name: str, the name of the optimizer
        :param kwargs: the arguments
        """
        self.optim_name = optimizer_name
        self.kwargs = kwargs

    @property
    def name(self):
        """str: the optimizer name as given at construction."""
        return self.optim_name

    @property
    def params(self):
        """dict: keyword arguments that will configure the framework optimizer."""
        return self.kwargs

    def construct_from_pytorch(self, model_params):
        """Construct an optimizer from framework over given model parameters.

        :param model_params: iterable of parameters to optimize
        :return: a torch.optim.Optimizer instance
        :raises ValueError: when a required argument (e.g. "lr") is missing
        :raises NotImplementedError: for unsupported optimizer names
        """
        # Case-insensitive match generalizes the original exact-name pairs
        # ("SGD"/"sgd", "adam"/"Adam") without breaking them.
        name = self.optim_name.lower()
        if name == "sgd":
            if "lr" not in self.kwargs:
                raise ValueError("requires learning rate for SGD optimizer")
            # momentum defaults to 0, matching torch.optim.SGD's own default
            self.kwargs.setdefault("momentum", 0)
            return torch.optim.SGD(model_params, lr=self.kwargs["lr"],
                                   momentum=self.kwargs["momentum"])
        elif name == "adam":
            if "lr" not in self.kwargs:
                raise ValueError("requires learning rate for Adam optimizer")
            # weight_decay defaults to 0, matching torch.optim.Adam's own default
            self.kwargs.setdefault("weight_decay", 0)
            return torch.optim.Adam(model_params, lr=self.kwargs["lr"],
                                    weight_decay=self.kwargs["weight_decay"])
        else:
            raise NotImplementedError("Unsupported optimizer: {}".format(self.optim_name))
| @@ -1,5 +1,3 @@ | |||||
| import _pickle | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| @@ -14,43 +12,78 @@ logger = create_logger(__name__, "./train_test.log") | |||||
| class BaseTester(object): | class BaseTester(object): | ||||
| """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | ||||
| def __init__(self, test_args): | |||||
| def __init__(self, **kwargs): | |||||
| """ | """ | ||||
| :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||||
| :param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||||
| """ | """ | ||||
| super(BaseTester, self).__init__() | super(BaseTester, self).__init__() | ||||
| self.validate_in_training = test_args["validate_in_training"] | |||||
| self.save_dev_data = None | |||||
| self.save_output = test_args["save_output"] | |||||
| self.output = None | |||||
| self.save_loss = test_args["save_loss"] | |||||
| self.mean_loss = None | |||||
| self.batch_size = test_args["batch_size"] | |||||
| self.pickle_path = test_args["pickle_path"] | |||||
| self.iterator = None | |||||
| self.use_cuda = test_args["use_cuda"] | |||||
| self.model = None | |||||
| """ | |||||
| "default_args" provides default value for important settings. | |||||
| The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
| "kwargs" must have the same type as "default_args" on corresponding keys. | |||||
| Otherwise, error will raise. | |||||
| """ | |||||
| default_args = {"save_output": False, # collect outputs of validation set | |||||
| "save_loss": False, # collect losses in validation | |||||
| "save_best_dev": False, # save best model during validation | |||||
| "batch_size": 8, | |||||
| "use_cuda": True, | |||||
| "pickle_path": "./save/", | |||||
| "model_name": "dev_best_model.pkl", | |||||
| "print_every_step": 1, | |||||
| } | |||||
| """ | |||||
| "required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
| This is used to warn users of essential settings in the training. | |||||
| Obviously, "required_args" is the subset of "default_args". | |||||
| The value in "default_args" to the keys in "required_args" is simply for type check. | |||||
| """ | |||||
| # TODO: required arguments | |||||
| required_args = {} | |||||
| for req_key in required_args: | |||||
| if req_key not in kwargs: | |||||
| logger.error("Tester lacks argument {}".format(req_key)) | |||||
| raise ValueError("Tester lacks argument {}".format(req_key)) | |||||
| for key in default_args: | |||||
| if key in kwargs: | |||||
| if isinstance(kwargs[key], type(default_args[key])): | |||||
| default_args[key] = kwargs[key] | |||||
| else: | |||||
| msg = "Argument %s type mismatch: expected %s while get %s" % ( | |||||
| key, type(default_args[key]), type(kwargs[key])) | |||||
| logger.error(msg) | |||||
| raise ValueError(msg) | |||||
| else: | |||||
| # BaseTester doesn't care about extra arguments | |||||
| pass | |||||
| print(default_args) | |||||
| self.save_output = default_args["save_output"] | |||||
| self.save_best_dev = default_args["save_best_dev"] | |||||
| self.save_loss = default_args["save_loss"] | |||||
| self.batch_size = default_args["batch_size"] | |||||
| self.pickle_path = default_args["pickle_path"] | |||||
| self.use_cuda = default_args["use_cuda"] | |||||
| self.print_every_step = default_args["print_every_step"] | |||||
| self._model = None | |||||
| self.eval_history = [] | self.eval_history = [] | ||||
| self.batch_output = [] | self.batch_output = [] | ||||
| def test(self, network, dev_data): | def test(self, network, dev_data): | ||||
| if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
| self.model = network.cuda() | |||||
| self._model = network.cuda() | |||||
| else: | else: | ||||
| self.model = network | |||||
| self._model = network | |||||
| # turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
| self.mode(network, test=True) | self.mode(network, test=True) | ||||
| self.eval_history.clear() | self.eval_history.clear() | ||||
| self.batch_output.clear() | self.batch_output.clear() | ||||
| # dev_data = self.prepare_input(self.pickle_path) | |||||
| # logger.info("validation data loaded") | |||||
| iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True)) | iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True)) | ||||
| n_batches = len(dev_data) // self.batch_size | |||||
| print_every_step = 1 | |||||
| step = 0 | step = 0 | ||||
| for batch_x, batch_y in self.make_batch(iterator, dev_data): | for batch_x, batch_y in self.make_batch(iterator, dev_data): | ||||
| @@ -65,21 +98,10 @@ class BaseTester(object): | |||||
| print_output = "[test step {}] {}".format(step, eval_results) | print_output = "[test step {}] {}".format(step, eval_results) | ||||
| logger.info(print_output) | logger.info(print_output) | ||||
| if step % print_every_step == 0: | |||||
| if step % self.print_every_step == 0: | |||||
| print(print_output) | print(print_output) | ||||
| step += 1 | step += 1 | ||||
| def prepare_input(self, data_path): | |||||
| """Save the dev data once it is loaded. Can return directly next time. | |||||
| :param data_path: str, the path to the pickle data for dev | |||||
| :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s). | |||||
| """ | |||||
| if self.save_dev_data is None: | |||||
| data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) | |||||
| self.save_dev_data = data_dev | |||||
| return self.save_dev_data | |||||
| def mode(self, model, test): | def mode(self, model, test): | ||||
| """Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
| @@ -117,15 +139,14 @@ class SeqLabelTester(BaseTester): | |||||
| Tester for sequence labeling. | Tester for sequence labeling. | ||||
| """ | """ | ||||
| def __init__(self, test_args): | |||||
| def __init__(self, **test_args): | |||||
| """ | """ | ||||
| :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | ||||
| """ | """ | ||||
| super(SeqLabelTester, self).__init__(test_args) | |||||
| super(SeqLabelTester, self).__init__(**test_args) | |||||
| self.max_len = None | self.max_len = None | ||||
| self.mask = None | self.mask = None | ||||
| self.seq_len = None | self.seq_len = None | ||||
| self.batch_result = None | |||||
| def data_forward(self, network, inputs): | def data_forward(self, network, inputs): | ||||
| """This is only for sequence labeling with CRF decoder. | """This is only for sequence labeling with CRF decoder. | ||||
| @@ -159,10 +180,10 @@ class SeqLabelTester(BaseTester): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| batch_size, max_len = predict.size(0), predict.size(1) | batch_size, max_len = predict.size(0), predict.size(1) | ||||
| loss = self.model.loss(predict, truth, self.mask) / batch_size | |||||
| loss = self._model.loss(predict, truth, self.mask) / batch_size | |||||
| prediction = self.model.prediction(predict, self.mask) | |||||
| results = torch.Tensor(prediction).view(-1,) | |||||
| prediction = self._model.prediction(predict, self.mask) | |||||
| results = torch.Tensor(prediction).view(-1, ) | |||||
| # make sure "results" is in the same device as "truth" | # make sure "results" is in the same device as "truth" | ||||
| results = results.to(truth) | results = results.to(truth) | ||||
| accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | ||||
| @@ -184,21 +205,16 @@ class SeqLabelTester(BaseTester): | |||||
| def make_batch(self, iterator, data): | def make_batch(self, iterator, data): | ||||
| return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True) | return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True) | ||||
| class ClassificationTester(BaseTester): | class ClassificationTester(BaseTester): | ||||
| """Tester for classification.""" | """Tester for classification.""" | ||||
| def __init__(self, test_args): | |||||
| def __init__(self, **test_args): | |||||
| """ | """ | ||||
| :param test_args: a dict-like object that has __getitem__ method, \ | :param test_args: a dict-like object that has __getitem__ method, \ | ||||
| can be accessed by "test_args["key_str"]" | can be accessed by "test_args["key_str"]" | ||||
| """ | """ | ||||
| super(ClassificationTester, self).__init__(test_args) | |||||
| self.pickle_path = test_args["pickle_path"] | |||||
| self.save_dev_data = None | |||||
| self.output = None | |||||
| self.mean_loss = None | |||||
| self.iterator = None | |||||
| super(ClassificationTester, self).__init__(**test_args) | |||||
| def make_batch(self, iterator, data, max_len=None): | def make_batch(self, iterator, data, max_len=None): | ||||
| return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len) | return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len) | ||||
| @@ -221,4 +237,3 @@ class ClassificationTester(BaseTester): | |||||
| y_true = torch.cat(y_true, dim=0) | y_true = torch.cat(y_true, dim=0) | ||||
| acc = float(torch.sum(y_pred == y_true)) / len(y_true) | acc = float(torch.sum(y_pred == y_true)) / len(y_true) | ||||
| return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | ||||
| @@ -6,10 +6,11 @@ from datetime import timedelta | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| import torch.nn as nn | |||||
| from fastNLP.core.action import Action | from fastNLP.core.action import Action | ||||
| from fastNLP.core.action import RandomSampler, Batchifier | from fastNLP.core.action import RandomSampler, Batchifier | ||||
| from fastNLP.core.loss import Loss | |||||
| from fastNLP.core.optimizer import Optimizer | |||||
| from fastNLP.core.tester import SeqLabelTester, ClassificationTester | from fastNLP.core.tester import SeqLabelTester, ClassificationTester | ||||
| from fastNLP.modules import utils | from fastNLP.modules import utils | ||||
| from fastNLP.saver.logger import create_logger | from fastNLP.saver.logger import create_logger | ||||
| @@ -23,14 +24,13 @@ class BaseTrainer(object): | |||||
| """Operations to train a model, including data loading, SGD, and validation. | """Operations to train a model, including data loading, SGD, and validation. | ||||
| Subclasses must implement the following abstract methods: | Subclasses must implement the following abstract methods: | ||||
| - define_optimizer | |||||
| - grad_backward | - grad_backward | ||||
| - get_loss | - get_loss | ||||
| """ | """ | ||||
| def __init__(self, train_args): | |||||
| def __init__(self, **kwargs): | |||||
| """ | """ | ||||
| :param train_args: dict of (key, value), or dict-like object. key is str. | |||||
| :param kwargs: dict of (key, value), or dict-like object. key is str. | |||||
| The base trainer requires the following keys: | The base trainer requires the following keys: | ||||
| - epochs: int, the number of epochs in training | - epochs: int, the number of epochs in training | ||||
| @@ -39,19 +39,58 @@ class BaseTrainer(object): | |||||
| - pickle_path: str, the path to pickle files for pre-processing | - pickle_path: str, the path to pickle files for pre-processing | ||||
| """ | """ | ||||
| super(BaseTrainer, self).__init__() | super(BaseTrainer, self).__init__() | ||||
| self.n_epochs = train_args["epochs"] | |||||
| self.batch_size = train_args["batch_size"] | |||||
| self.pickle_path = train_args["pickle_path"] | |||||
| self.validate = train_args["validate"] | |||||
| self.save_best_dev = train_args["save_best_dev"] | |||||
| self.model_saved_path = train_args["model_saved_path"] | |||||
| self.use_cuda = train_args["use_cuda"] | |||||
| self.model = None | |||||
| self.iterator = None | |||||
| self.loss_func = None | |||||
| self.optimizer = None | |||||
| """ | |||||
| "default_args" provides default value for important settings. | |||||
| The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
| "kwargs" must have the same type as "default_args" on corresponding keys. | |||||
| Otherwise, error will raise. | |||||
| """ | |||||
| default_args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/", | |||||
| "save_best_dev": True, "model_name": "default_model_name.pkl", | |||||
| "loss": Loss(None), | |||||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0) | |||||
| } | |||||
| """ | |||||
| "required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
| This is used to warn users of essential settings in the training. | |||||
| Obviously, "required_args" is the subset of "default_args". | |||||
| The value in "default_args" to the keys in "required_args" is simply for type check. | |||||
| """ | |||||
| # TODO: required arguments | |||||
| required_args = {} | |||||
| for req_key in required_args: | |||||
| if req_key not in kwargs: | |||||
| logger.error("Trainer lacks argument {}".format(req_key)) | |||||
| raise ValueError("Trainer lacks argument {}".format(req_key)) | |||||
| for key in default_args: | |||||
| if key in kwargs: | |||||
| if isinstance(kwargs[key], type(default_args[key])): | |||||
| default_args[key] = kwargs[key] | |||||
| else: | |||||
| msg = "Argument %s type mismatch: expected %s while get %s" % ( | |||||
| key, type(default_args[key]), type(kwargs[key])) | |||||
| logger.error(msg) | |||||
| raise ValueError(msg) | |||||
| else: | |||||
| # BaseTrainer doesn't care about extra arguments | |||||
| pass | |||||
| print(default_args) | |||||
| self.n_epochs = default_args["epochs"] | |||||
| self.batch_size = default_args["batch_size"] | |||||
| self.pickle_path = default_args["pickle_path"] | |||||
| self.validate = default_args["validate"] | |||||
| self.save_best_dev = default_args["save_best_dev"] | |||||
| self.use_cuda = default_args["use_cuda"] | |||||
| self.model_name = default_args["model_name"] | |||||
| self._model = None | |||||
| self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | |||||
| self._optimizer = None | |||||
| self._optimizer_proto = default_args["optimizer"] | |||||
| def train(self, network, train_data, dev_data=None): | def train(self, network, train_data, dev_data=None): | ||||
| """General Training Steps | """General Training Steps | ||||
| @@ -72,12 +111,9 @@ class BaseTrainer(object): | |||||
| """ | """ | ||||
| # prepare model and data, transfer model to gpu if available | # prepare model and data, transfer model to gpu if available | ||||
| if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
| self.model = network.cuda() | |||||
| self._model = network.cuda() | |||||
| else: | else: | ||||
| self.model = network | |||||
| # train_data = self.load_train_data(self.pickle_path) | |||||
| # logger.info("training data loaded") | |||||
| self._model = network | |||||
| # define tester over dev data | # define tester over dev data | ||||
| if self.validate: | if self.validate: | ||||
| @@ -88,7 +124,9 @@ class BaseTrainer(object): | |||||
| logger.info("validator defined as {}".format(str(validator))) | logger.info("validator defined as {}".format(str(validator))) | ||||
| self.define_optimizer() | self.define_optimizer() | ||||
| logger.info("optimizer defined as {}".format(str(self.optimizer))) | |||||
| logger.info("optimizer defined as {}".format(str(self._optimizer))) | |||||
| self.define_loss() | |||||
| logger.info("loss function defined as {}".format(str(self._loss_func))) | |||||
| # main training epochs | # main training epochs | ||||
| n_samples = len(train_data) | n_samples = len(train_data) | ||||
| @@ -113,7 +151,7 @@ class BaseTrainer(object): | |||||
| validator.test(network, dev_data) | validator.test(network, dev_data) | ||||
| if self.save_best_dev and self.best_eval_result(validator): | if self.save_best_dev and self.best_eval_result(validator): | ||||
| self.save_model(network) | |||||
| self.save_model(network, self.model_name) | |||||
| print("saved better model selected by dev") | print("saved better model selected by dev") | ||||
| logger.info("saved better model selected by dev") | logger.info("saved better model selected by dev") | ||||
| @@ -153,6 +191,11 @@ class BaseTrainer(object): | |||||
| logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv), | logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv), | ||||
| len(dev_data_cv))) | len(dev_data_cv))) | ||||
| raise RuntimeError("the number of folds in train and dev data unequals") | raise RuntimeError("the number of folds in train and dev data unequals") | ||||
| if self.validate is False: | |||||
| logger.warn("Cross validation requires self.validate to be True. Please turn it on. ") | |||||
| print("[warning] Cross validation requires self.validate to be True. Please turn it on. ") | |||||
| self.validate = True | |||||
| n_fold = len(train_data_cv) | n_fold = len(train_data_cv) | ||||
| logger.info("perform {} folds cross validation.".format(n_fold)) | logger.info("perform {} folds cross validation.".format(n_fold)) | ||||
| for i in range(n_fold): | for i in range(n_fold): | ||||
| @@ -186,7 +229,7 @@ class BaseTrainer(object): | |||||
| """ | """ | ||||
| Define framework-specific optimizer specified by the models. | Define framework-specific optimizer specified by the models. | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) | |||||
| def update(self): | def update(self): | ||||
| """ | """ | ||||
| @@ -194,7 +237,7 @@ class BaseTrainer(object): | |||||
| For PyTorch, just call optimizer to update. | For PyTorch, just call optimizer to update. | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| self._optimizer.step() | |||||
| def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -206,7 +249,8 @@ class BaseTrainer(object): | |||||
| For PyTorch, just do "loss.backward()" | For PyTorch, just do "loss.backward()" | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| self._model.zero_grad() | |||||
| loss.backward() | |||||
| def get_loss(self, predict, truth): | def get_loss(self, predict, truth): | ||||
| """ | """ | ||||
| @@ -215,21 +259,25 @@ class BaseTrainer(object): | |||||
| :param truth: ground truth label vector | :param truth: ground truth label vector | ||||
| :return: a scalar | :return: a scalar | ||||
| """ | """ | ||||
| if self.loss_func is None: | |||||
| if hasattr(self.model, "loss"): | |||||
| self.loss_func = self.model.loss | |||||
| logger.info("The model has a loss function, use it.") | |||||
| else: | |||||
| logger.info("The model didn't define loss, use Trainer's loss.") | |||||
| self.define_loss() | |||||
| return self.loss_func(predict, truth) | |||||
| return self._loss_func(predict, truth) | |||||
| def define_loss(self): | def define_loss(self): | ||||
| """ | """ | ||||
| Assign an instance of loss function to self.loss_func | |||||
| E.g. self.loss_func = nn.CrossEntropyLoss() | |||||
| if the model defines a loss, use model's loss. | |||||
| Otherwise, Trainer must have a loss argument; use it as the loss. | |||||
| These two losses cannot be defined at the same time. | |||||
| Trainer does not handle loss definition or choose default losses. | |||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| if hasattr(self._model, "loss") and self._loss_func is not None: | |||||
| raise ValueError("Both the model and Trainer define loss. Please take out your loss.") | |||||
| if hasattr(self._model, "loss"): | |||||
| self._loss_func = self._model.loss | |||||
| logger.info("The model has a loss function, use it.") | |||||
| else: | |||||
| if self._loss_func is None: | |||||
| raise ValueError("Please specify a loss function.") | |||||
| logger.info("The model didn't define loss, use Trainer's loss.") | |||||
| def best_eval_result(self, validator): | def best_eval_result(self, validator): | ||||
| """ | """ | ||||
| @@ -238,12 +286,15 @@ class BaseTrainer(object): | |||||
| """ | """ | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def save_model(self, network): | |||||
| def save_model(self, network, model_name): | |||||
| """ | """ | ||||
| :param network: the PyTorch model | :param network: the PyTorch model | ||||
| :param model_name: str | |||||
| model_best_dev.pkl may be overwritten by a better model in future epochs. | model_best_dev.pkl may be overwritten by a better model in future epochs. | ||||
| """ | """ | ||||
| ModelSaver(self.model_saved_path + "model_best_dev.pkl").save_pytorch(network) | |||||
| if model_name[-4:] != ".pkl": | |||||
| model_name += ".pkl" | |||||
| ModelSaver(self.pickle_path + model_name).save_pytorch(network) | |||||
| def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -266,18 +317,12 @@ class ToyTrainer(BaseTrainer): | |||||
| return network(x) | return network(x) | ||||
| def grad_backward(self, loss): | def grad_backward(self, loss): | ||||
| self.model.zero_grad() | |||||
| self._model.zero_grad() | |||||
| loss.backward() | loss.backward() | ||||
| def get_loss(self, pred, truth): | def get_loss(self, pred, truth): | ||||
| return np.mean(np.square(pred - truth)) | return np.mean(np.square(pred - truth)) | ||||
| def define_optimizer(self): | |||||
| self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01) | |||||
| def update(self): | |||||
| self.optimizer.step() | |||||
| class SeqLabelTrainer(BaseTrainer): | class SeqLabelTrainer(BaseTrainer): | ||||
| """ | """ | ||||
| @@ -285,24 +330,14 @@ class SeqLabelTrainer(BaseTrainer): | |||||
| """ | """ | ||||
| def __init__(self, train_args): | |||||
| super(SeqLabelTrainer, self).__init__(train_args) | |||||
| self.vocab_size = train_args["vocab_size"] | |||||
| self.num_classes = train_args["num_classes"] | |||||
| def __init__(self, **kwargs): | |||||
| super(SeqLabelTrainer, self).__init__(**kwargs) | |||||
| # self.vocab_size = kwargs["vocab_size"] | |||||
| # self.num_classes = kwargs["num_classes"] | |||||
| self.max_len = None | self.max_len = None | ||||
| self.mask = None | self.mask = None | ||||
| self.best_accuracy = 0.0 | self.best_accuracy = 0.0 | ||||
| def define_optimizer(self): | |||||
| self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) | |||||
| def grad_backward(self, loss): | |||||
| self.model.zero_grad() | |||||
| loss.backward() | |||||
| def update(self): | |||||
| self.optimizer.step() | |||||
| def data_forward(self, network, inputs): | def data_forward(self, network, inputs): | ||||
| if not isinstance(inputs, tuple): | if not isinstance(inputs, tuple): | ||||
| raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0]))) | raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0]))) | ||||
| @@ -330,7 +365,7 @@ class SeqLabelTrainer(BaseTrainer): | |||||
| batch_size, max_len = predict.size(0), predict.size(1) | batch_size, max_len = predict.size(0), predict.size(1) | ||||
| assert truth.shape == (batch_size, max_len) | assert truth.shape == (batch_size, max_len) | ||||
| loss = self.model.loss(predict, truth, self.mask) | |||||
| loss = self._model.loss(predict, truth, self.mask) | |||||
| return loss | return loss | ||||
| def best_eval_result(self, validator): | def best_eval_result(self, validator): | ||||
| @@ -345,48 +380,25 @@ class SeqLabelTrainer(BaseTrainer): | |||||
| return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda) | return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda) | ||||
| def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
| return SeqLabelTester(valid_args) | |||||
| return SeqLabelTester(**valid_args) | |||||
| class ClassificationTrainer(BaseTrainer): | class ClassificationTrainer(BaseTrainer): | ||||
| """Trainer for classification.""" | """Trainer for classification.""" | ||||
| def __init__(self, train_args): | |||||
| super(ClassificationTrainer, self).__init__(train_args) | |||||
| self.learn_rate = train_args["learn_rate"] | |||||
| self.momentum = train_args["momentum"] | |||||
| def __init__(self, **train_args): | |||||
| super(ClassificationTrainer, self).__init__(**train_args) | |||||
| self.iterator = None | self.iterator = None | ||||
| self.loss_func = None | self.loss_func = None | ||||
| self.optimizer = None | self.optimizer = None | ||||
| self.best_accuracy = 0 | self.best_accuracy = 0 | ||||
| def define_loss(self): | |||||
| self.loss_func = nn.CrossEntropyLoss() | |||||
| def define_optimizer(self): | |||||
| """ | |||||
| Define framework-specific optimizer specified by the models. | |||||
| """ | |||||
| self.optimizer = torch.optim.SGD( | |||||
| self.model.parameters(), | |||||
| lr=self.learn_rate, | |||||
| momentum=self.momentum) | |||||
| def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
| """Forward through network.""" | """Forward through network.""" | ||||
| logits = network(x) | logits = network(x) | ||||
| return logits | return logits | ||||
| def grad_backward(self, loss): | |||||
| """Compute gradient backward.""" | |||||
| self.model.zero_grad() | |||||
| loss.backward() | |||||
| def update(self): | |||||
| """Apply gradient.""" | |||||
| self.optimizer.step() | |||||
| def make_batch(self, iterator): | def make_batch(self, iterator): | ||||
| return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda) | return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda) | ||||
| @@ -404,4 +416,4 @@ class ClassificationTrainer(BaseTrainer): | |||||
| return False | return False | ||||
| def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
| return ClassificationTester(valid_args) | |||||
| return ClassificationTester(**valid_args) | |||||
| @@ -94,6 +94,10 @@ class ConfigSection(object): | |||||
| def __contains__(self, item): | def __contains__(self, item): | ||||
| return item in self.__dict__.keys() | return item in self.__dict__.keys() | ||||
| @property | |||||
| def data(self): | |||||
| return self.__dict__ | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| config = ConfigLoader('configLoader', 'there is no data') | config = ConfigLoader('configLoader', 'there is no data') | ||||
| @@ -18,7 +18,6 @@ MLP_HIDDEN = 2000 | |||||
| CLASSES_NUM = 5 | CLASSES_NUM = 5 | ||||
| from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
| from fastNLP.core.trainer import BaseTrainer | |||||
| class MyNet(BaseModel): | class MyNet(BaseModel): | ||||
| @@ -60,18 +59,6 @@ class Net(nn.Module): | |||||
| return x, penalty | return x, penalty | ||||
| class MyTrainer(BaseTrainer): | |||||
| def __init__(self, args): | |||||
| super(MyTrainer, self).__init__(args) | |||||
| self.optimizer = None | |||||
| def define_optimizer(self): | |||||
| self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) | |||||
| def define_loss(self): | |||||
| self.loss_func = nn.CrossEntropyLoss() | |||||
| def train(model_dict=None, using_cuda=True, learning_rate=0.06,\ | def train(model_dict=None, using_cuda=True, learning_rate=0.06,\ | ||||
| momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10): | momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10): | ||||
| """ | """ | ||||
| @@ -1,65 +1,11 @@ | |||||
| [General] | |||||
| revision = "first" | |||||
| datapath = "./data/smallset/imdb/" | |||||
| embed_path = "./data/smallset/imdb/embedding.txt" | |||||
| optimizer = "adam" | |||||
| attn_mode = "rout" | |||||
| seq_encoder = "bilstm" | |||||
| out_caps_num = 5 | |||||
| rout_iter = 3 | |||||
| max_snt_num = 40 | |||||
| max_wd_num = 40 | |||||
| max_epochs = 50 | |||||
| pre_trained = true | |||||
| batch_sz = 32 | |||||
| batch_sz_min = 32 | |||||
| bucket_sz = 5000 | |||||
| partial_update_until_epoch = 2 | |||||
| embed_size = 300 | |||||
| hidden_size = 200 | |||||
| dense_hidden = [300, 10] | |||||
| lr = 0.0002 | |||||
| decay_steps = 1000 | |||||
| decay_rate = 0.9 | |||||
| dropout = 0.2 | |||||
| early_stopping = 7 | |||||
| reg = 1e-06 | |||||
| [My] | |||||
| datapath = "./data/smallset/imdb/" | |||||
| embed_path = "./data/smallset/imdb/embedding.txt" | |||||
| optimizer = "adam" | |||||
| attn_mode = "rout" | |||||
| seq_encoder = "bilstm" | |||||
| out_caps_num = 5 | |||||
| rout_iter = 3 | |||||
| max_snt_num = 40 | |||||
| max_wd_num = 40 | |||||
| max_epochs = 50 | |||||
| pre_trained = true | |||||
| batch_sz = 32 | |||||
| batch_sz_min = 32 | |||||
| bucket_sz = 5000 | |||||
| partial_update_until_epoch = 2 | |||||
| embed_size = 300 | |||||
| hidden_size = 200 | |||||
| dense_hidden = [300, 10] | |||||
| lr = 0.0002 | |||||
| decay_steps = 1000 | |||||
| decay_rate = 0.9 | |||||
| dropout = 0.2 | |||||
| early_stopping = 70 | |||||
| reg = 1e-05 | |||||
| test = 5 | |||||
| new_attr = 40 | |||||
| [POS] | |||||
| [test_seq_label_trainer] | |||||
| epochs = 1 | epochs = 1 | ||||
| batch_size = 32 | batch_size = 32 | ||||
| pickle_path = "./data_for_tests/" | |||||
| validate = true | validate = true | ||||
| save_best_dev = true | save_best_dev = true | ||||
| model_saved_path = "./" | |||||
| use_cuda = true | |||||
| [test_seq_label_model] | |||||
| rnn_hidden_units = 100 | rnn_hidden_units = 100 | ||||
| rnn_layers = 1 | rnn_layers = 1 | ||||
| rnn_bi_direction = true | rnn_bi_direction = true | ||||
| @@ -68,13 +14,12 @@ dropout = 0.5 | |||||
| use_crf = true | use_crf = true | ||||
| use_cuda = true | use_cuda = true | ||||
| [POS_test] | |||||
| [test_seq_label_tester] | |||||
| save_output = true | save_output = true | ||||
| validate_in_training = true | validate_in_training = true | ||||
| save_dev_input = false | save_dev_input = false | ||||
| save_loss = true | save_loss = true | ||||
| batch_size = 1 | batch_size = 1 | ||||
| pickle_path = "./data_for_tests/" | |||||
| rnn_hidden_units = 100 | rnn_hidden_units = 100 | ||||
| rnn_layers = 1 | rnn_layers = 1 | ||||
| rnn_bi_direction = true | rnn_bi_direction = true | ||||
| @@ -84,7 +29,6 @@ use_crf = true | |||||
| use_cuda = true | use_cuda = true | ||||
| [POS_infer] | [POS_infer] | ||||
| pickle_path = "./data_for_tests/" | |||||
| rnn_hidden_units = 100 | rnn_hidden_units = 100 | ||||
| rnn_layers = 1 | rnn_layers = 1 | ||||
| rnn_bi_direction = true | rnn_bi_direction = true | ||||
| @@ -95,14 +39,9 @@ num_classes = 27 | |||||
| [text_class] | [text_class] | ||||
| epochs = 1 | epochs = 1 | ||||
| batch_size = 10 | batch_size = 10 | ||||
| pickle_path = "./save_path/" | |||||
| validate = false | validate = false | ||||
| save_best_dev = false | save_best_dev = false | ||||
| model_saved_path = "./save_path/" | |||||
| use_cuda = true | use_cuda = true | ||||
| learn_rate = 1e-3 | learn_rate = 1e-3 | ||||
| momentum = 0.9 | momentum = 0.9 | ||||
| [text_class_model] | |||||
| vocab_size = 867 | |||||
| num_classes = 18 | |||||
| model_name = "class_model.pkl" | |||||
| @@ -20,7 +20,7 @@ class MyNERTrainer(SeqLabelTrainer): | |||||
| override | override | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) | |||||
| self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.001) | |||||
| self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=3000, gamma=0.5) | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=3000, gamma=0.5) | ||||
| def update(self): | def update(self): | ||||
| @@ -1,7 +1,7 @@ | |||||
| import os | |||||
| import sys | import sys | ||||
| sys.path.append("..") | sys.path.append("..") | ||||
| import argparse | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
| from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
| from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | ||||
| @@ -11,17 +11,29 @@ from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
| from fastNLP.core.optimizer import Optimizer | |||||
| parser = argparse.ArgumentParser() | |||||
| parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | |||||
| parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt", | |||||
| help="path to the training data") | |||||
| parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") | |||||
| parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt", | |||||
| help="data used for inference") | |||||
| data_name = "people.txt" | |||||
| data_path = "data_for_tests/people.txt" | |||||
| pickle_path = "seq_label/" | |||||
| data_infer_path = "data_for_tests/people_infer.txt" | |||||
| args = parser.parse_args() | |||||
| pickle_path = args.save | |||||
| model_name = args.model_name | |||||
| config_dir = args.config | |||||
| data_path = args.train | |||||
| data_infer_path = args.infer | |||||
| def infer(): | def infer(): | ||||
| # Load infer configuration, the same as test | # Load infer configuration, the same as test | ||||
| test_args = ConfigSection() | test_args = ConfigSection() | ||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| ConfigLoader("config.cfg", "").load_config(config_dir, {"POS_infer": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| @@ -33,11 +45,11 @@ def infer(): | |||||
| model = SeqLabeling(test_args) | model = SeqLabeling(test_args) | ||||
| # Dump trained parameters into the model | # Dump trained parameters into the model | ||||
| ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl") | |||||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
| print("model loaded!") | print("model loaded!") | ||||
| # Data Loader | # Data Loader | ||||
| raw_data_loader = BaseLoader(data_name, data_infer_path) | |||||
| raw_data_loader = BaseLoader("xxx", data_infer_path) | |||||
| infer_data = raw_data_loader.load_lines() | infer_data = raw_data_loader.load_lines() | ||||
| # Inference interface | # Inference interface | ||||
| @@ -51,49 +63,72 @@ def infer(): | |||||
| def train_and_test(): | def train_and_test(): | ||||
| # Config Loader | # Config Loader | ||||
| train_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||||
| trainer_args = ConfigSection() | |||||
| model_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config(config_dir, { | |||||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | |||||
| # Data Loader | # Data Loader | ||||
| pos_loader = POSDatasetLoader(data_name, data_path) | |||||
| pos_loader = POSDatasetLoader("xxx", data_path) | |||||
| train_data = pos_loader.load_lines() | train_data = pos_loader.load_lines() | ||||
| # Preprocessor | # Preprocessor | ||||
| p = SeqLabelPreprocess() | p = SeqLabelPreprocess() | ||||
| data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | ||||
| train_args["vocab_size"] = p.vocab_size | |||||
| train_args["num_classes"] = p.num_classes | |||||
| # Trainer | |||||
| trainer = SeqLabelTrainer(train_args) | |||||
| model_args["vocab_size"] = p.vocab_size | |||||
| model_args["num_classes"] = p.num_classes | |||||
| # Trainer: two definition styles | |||||
| # 1 | |||||
| # trainer = SeqLabelTrainer(trainer_args.data) | |||||
| # 2 | |||||
| trainer = SeqLabelTrainer( | |||||
| epochs=trainer_args["epochs"], | |||||
| batch_size=trainer_args["batch_size"], | |||||
| validate=trainer_args["validate"], | |||||
| use_cuda=trainer_args["use_cuda"], | |||||
| pickle_path=pickle_path, | |||||
| save_best_dev=trainer_args["save_best_dev"], | |||||
| model_name=model_name, | |||||
| optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), | |||||
| ) | |||||
| # Model | # Model | ||||
| model = SeqLabeling(train_args) | |||||
| model = SeqLabeling(model_args) | |||||
| # Start training | # Start training | ||||
| trainer.train(model, data_train, data_dev) | trainer.train(model, data_train, data_dev) | ||||
| print("Training finished!") | print("Training finished!") | ||||
| # Saver | # Saver | ||||
| saver = ModelSaver(pickle_path + "saved_model.pkl") | |||||
| saver = ModelSaver(os.path.join(pickle_path, model_name)) | |||||
| saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
| print("Model saved!") | print("Model saved!") | ||||
| del model, trainer, pos_loader | del model, trainer, pos_loader | ||||
| # Define the same model | # Define the same model | ||||
| model = SeqLabeling(train_args) | |||||
| model = SeqLabeling(model_args) | |||||
| # Dump trained parameters into the model | # Dump trained parameters into the model | ||||
| ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl") | |||||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
| print("model loaded!") | print("model loaded!") | ||||
| # Load test configuration | # Load test configuration | ||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| tester_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
| # Tester | # Tester | ||||
| tester = SeqLabelTester(test_args) | |||||
| tester = SeqLabelTester(save_output=False, | |||||
| save_loss=False, | |||||
| save_best_dev=False, | |||||
| batch_size=8, | |||||
| use_cuda=False, | |||||
| pickle_path=pickle_path, | |||||
| model_name="seq_label_in_test.pkl", | |||||
| print_every_step=1 | |||||
| ) | |||||
| # Start testing with validation data | # Start testing with validation data | ||||
| tester.test(model, data_dev) | tester.test(model, data_dev) | ||||
| @@ -105,4 +140,4 @@ def train_and_test(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| train_and_test() | train_and_test() | ||||
| # infer() | |||||
| infer() | |||||
| @@ -1,6 +1,7 @@ | |||||
| # Python: 3.5 | # Python: 3.5 | ||||
| # encoding: utf-8 | # encoding: utf-8 | ||||
| import argparse | |||||
| import os | import os | ||||
| import sys | import sys | ||||
| @@ -13,75 +14,105 @@ from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.core.preprocess import ClassPreprocess | from fastNLP.core.preprocess import ClassPreprocess | ||||
| from fastNLP.models.cnn_text_classification import CNNText | from fastNLP.models.cnn_text_classification import CNNText | ||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| from fastNLP.core.optimizer import Optimizer | |||||
| from fastNLP.core.loss import Loss | |||||
| save_path = "./test_classification/" | |||||
| data_dir = "./data_for_tests/" | |||||
| train_file = 'text_classify.txt' | |||||
| model_name = "model_class.pkl" | |||||
| parser = argparse.ArgumentParser() | |||||
| parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | |||||
| parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt", | |||||
| help="path to the training data") | |||||
| parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") | |||||
| args = parser.parse_args() | |||||
| save_dir = args.save | |||||
| train_data_dir = args.train | |||||
| model_name = args.model_name | |||||
| config_dir = args.config | |||||
| def infer(): | def infer(): | ||||
| # load dataset | # load dataset | ||||
| print("Loading data...") | print("Loading data...") | ||||
| ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file)) | |||||
| ds_loader = ClassDatasetLoader("train", train_data_dir) | |||||
| data = ds_loader.load() | data = ds_loader.load() | ||||
| unlabeled_data = [x[0] for x in data] | unlabeled_data = [x[0] for x in data] | ||||
| # pre-process data | # pre-process data | ||||
| pre = ClassPreprocess() | pre = ClassPreprocess() | ||||
| vocab_size, n_classes = pre.run(data, pickle_path=save_path) | |||||
| print("vocabulary size:", vocab_size) | |||||
| print("number of classes:", n_classes) | |||||
| data = pre.run(data, pickle_path=save_dir) | |||||
| print("vocabulary size:", pre.vocab_size) | |||||
| print("number of classes:", pre.num_classes) | |||||
| model_args = ConfigSection() | model_args = ConfigSection() | ||||
| ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args}) | |||||
| # TODO: load from config file | |||||
| model_args["vocab_size"] = pre.vocab_size | |||||
| model_args["num_classes"] = pre.num_classes | |||||
| # ConfigLoader.load_config(config_dir, {"text_class_model": model_args}) | |||||
| # construct model | # construct model | ||||
| print("Building model...") | print("Building model...") | ||||
| cnn = CNNText(model_args) | cnn = CNNText(model_args) | ||||
| # Dump trained parameters into the model | # Dump trained parameters into the model | ||||
| ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl") | |||||
| ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name)) | |||||
| print("model loaded!") | print("model loaded!") | ||||
| infer = ClassificationInfer(data_dir) | |||||
| infer = ClassificationInfer(pickle_path=save_dir) | |||||
| results = infer.predict(cnn, unlabeled_data) | results = infer.predict(cnn, unlabeled_data) | ||||
| print(results) | print(results) | ||||
| def train(): | def train(): | ||||
| train_args, model_args = ConfigSection(), ConfigSection() | train_args, model_args = ConfigSection(), ConfigSection() | ||||
| ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args}) | |||||
| ConfigLoader.load_config(config_dir, {"text_class": train_args}) | |||||
| # load dataset | # load dataset | ||||
| print("Loading data...") | print("Loading data...") | ||||
| ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file)) | |||||
| ds_loader = ClassDatasetLoader("train", train_data_dir) | |||||
| data = ds_loader.load() | data = ds_loader.load() | ||||
| print(data[0]) | print(data[0]) | ||||
| # pre-process data | # pre-process data | ||||
| pre = ClassPreprocess() | pre = ClassPreprocess() | ||||
| data_train = pre.run(data, pickle_path=save_path) | |||||
| data_train = pre.run(data, pickle_path=save_dir) | |||||
| print("vocabulary size:", pre.vocab_size) | print("vocabulary size:", pre.vocab_size) | ||||
| print("number of classes:", pre.num_classes) | print("number of classes:", pre.num_classes) | ||||
| model_args["num_classes"] = pre.num_classes | |||||
| model_args["vocab_size"] = pre.vocab_size | |||||
| # construct model | # construct model | ||||
| print("Building model...") | print("Building model...") | ||||
| model = CNNText(model_args) | model = CNNText(model_args) | ||||
| # ConfigSaver().save_config(config_dir, {"text_class_model": model_args}) | |||||
| # train | # train | ||||
| print("Training...") | print("Training...") | ||||
| trainer = ClassificationTrainer(train_args) | |||||
| # 1 | |||||
| # trainer = ClassificationTrainer(train_args) | |||||
| # 2 | |||||
| trainer = ClassificationTrainer(epochs=train_args["epochs"], | |||||
| batch_size=train_args["batch_size"], | |||||
| validate=train_args["validate"], | |||||
| use_cuda=train_args["use_cuda"], | |||||
| pickle_path=save_dir, | |||||
| save_best_dev=train_args["save_best_dev"], | |||||
| model_name=model_name, | |||||
| loss=Loss("cross_entropy"), | |||||
| optimizer=Optimizer("SGD", lr=0.001, momentum=0.9)) | |||||
| trainer.train(model, data_train) | trainer.train(model, data_train) | ||||
| print("Training finished!") | print("Training finished!") | ||||
| saver = ModelSaver("./data_for_tests/saved_model.pkl") | |||||
| saver = ModelSaver(os.path.join(save_dir, model_name)) | |||||
| saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
| print("Model saved!") | print("Model saved!") | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| train() | train() | ||||
| # infer() | |||||
| infer() | |||||