* add more code comments * fix tester * refresh code styles tags/v0.2.0
| @@ -10,7 +10,7 @@ class Batch(object): | |||
| """ | |||
| def __init__(self, dataset, batch_size, sampler, as_numpy=False,): | |||
| def __init__(self, dataset, batch_size, sampler, as_numpy=False): | |||
| """ | |||
| :param dataset: a DataSet object | |||
| @@ -1,6 +1,7 @@ | |||
| import numpy as np | |||
| from fastNLP.core.fieldarray import FieldArray | |||
| from fastNLP.core.instance import Instance | |||
| _READERS = {} | |||
| @@ -27,10 +28,10 @@ class DataSet(object): | |||
| """ | |||
| class Instance(object): | |||
| def __init__(self, dataset, idx=-1): | |||
| def __init__(self, dataset, idx=-1, **fields): | |||
| self.dataset = dataset | |||
| self.idx = idx | |||
| self.fields = None | |||
| self.fields = fields | |||
| def __next__(self): | |||
| self.idx += 1 | |||
| @@ -38,6 +39,14 @@ class DataSet(object): | |||
| raise StopIteration | |||
| return self | |||
| def add_field(self, field_name, field): | |||
| """Add a new field to the instance. | |||
| :param field_name: str, the name of the field. | |||
| :param field: | |||
| """ | |||
| self.fields[field_name] = field | |||
| def __getitem__(self, name): | |||
| return self.dataset[name][self.idx] | |||
| @@ -47,13 +56,6 @@ class DataSet(object): | |||
| self.dataset.add_field(name, new_fields) | |||
| self.dataset[name][self.idx] = val | |||
| def __getattr__(self, item): | |||
| if item == 'fields': | |||
| self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()} | |||
| return self.fields | |||
| else: | |||
| raise AttributeError('{} does not exist.'.format(item)) | |||
| def __repr__(self): | |||
| return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name | |||
| in self.dataset.get_fields().keys()]) | |||
| @@ -112,14 +114,13 @@ class DataSet(object): | |||
| self.field_arrays[name].append(field) | |||
| def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False): | |||
| """ | |||
| """Add a new field to the DataSet. | |||
| :param str name: | |||
| :param fields: | |||
| :param int padding_val: | |||
| :param bool is_input: | |||
| :param bool is_target: | |||
| :return: | |||
| :param str name: the name of the field. | |||
| :param fields: a list of int, float, or other objects. | |||
| :param int padding_val: integer for padding. | |||
| :param bool is_input: whether this field is model input. | |||
| :param bool is_target: whether this field is label or target. | |||
| """ | |||
| if len(self.field_arrays) != 0: | |||
| assert len(self) == len(fields) | |||
| @@ -127,28 +128,43 @@ class DataSet(object): | |||
| is_input=is_input) | |||
| def delete_field(self, name): | |||
| """Delete a field based on the field name. | |||
| :param str name: the name of the field to be deleted. | |||
| """ | |||
| self.field_arrays.pop(name) | |||
| def get_fields(self): | |||
| """Return all the fields with their names. | |||
| :return dict field_arrays: the internal data structure of DataSet. | |||
| """ | |||
| return self.field_arrays | |||
| def __getitem__(self, name): | |||
| if isinstance(name, int): | |||
| return self.Instance(self, idx=name) | |||
| elif isinstance(name, slice): | |||
| ds = DataSet() | |||
| def __getitem__(self, idx): | |||
| """ | |||
| :param idx: can be int, slice, or str. | |||
| :return: If `idx` is int, return an Instance object. | |||
| If `idx` is slice, return a DataSet object. | |||
| If `idx` is str, it must be a field name, return the field. | |||
| """ | |||
| if isinstance(idx, int): | |||
| return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays}) | |||
| elif isinstance(idx, slice): | |||
| data_set = DataSet() | |||
| for field in self.field_arrays.values(): | |||
| ds.add_field(name=field.name, | |||
| fields=field.content[name], | |||
| padding_val=field.padding_val, | |||
| need_tensor=field.need_tensor, | |||
| is_target=field.is_target) | |||
| return ds | |||
| elif isinstance(name, str): | |||
| return self.field_arrays[name] | |||
| data_set.add_field(name=field.name, | |||
| fields=field.content[idx], | |||
| padding_val=field.padding_val, | |||
| is_input=field.is_input, | |||
| is_target=field.is_target) | |||
| return data_set | |||
| elif isinstance(idx, str): | |||
| return self.field_arrays[idx] | |||
| else: | |||
| raise KeyError | |||
| raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
| def __len__(self): | |||
| if len(self.field_arrays) == 0: | |||
| @@ -208,6 +224,7 @@ class DataSet(object): | |||
| pass | |||
| try: | |||
| reader_cls = _READERS[item] | |||
| # add read_*data() support | |||
| def _read(*args, **kwargs): | |||
| data = reader_cls().load(*args, **kwargs) | |||
| @@ -231,6 +248,12 @@ class DataSet(object): | |||
| return wrapper | |||
| def apply(self, func, new_field_name=None): | |||
| """Apply a function to every instance of the DataSet. | |||
| :param func: a function that takes an instance as input. | |||
| :param str new_field_name: If not None, results of the function will be stored as a new field. | |||
| :return results: returned values of the function over all instances. | |||
| """ | |||
| results = [] | |||
| for ins in self: | |||
| results.append(func(ins)) | |||
| @@ -247,28 +270,24 @@ class DataSet(object): | |||
| else: | |||
| return results | |||
| def split(self, test_ratio): | |||
| assert isinstance(test_ratio, float) | |||
| def split(self, dev_ratio): | |||
| """Split the dataset into training and development(validation) set. | |||
| :param float dev_ratio: the ratio of the development set in all data. | |||
| :return DataSet train_set: the training set | |||
| DataSet dev_set: the development set | |||
| """ | |||
| assert isinstance(dev_ratio, float) | |||
| assert 0 < dev_ratio < 1 | |||
| all_indices = [_ for _ in range(len(self))] | |||
| np.random.shuffle(all_indices) | |||
| test_indices = all_indices[:int(test_ratio)] | |||
| train_indices = all_indices[int(test_ratio):] | |||
| test_set = DataSet() | |||
| split = int(dev_ratio * len(self)) | |||
| dev_indices = all_indices[:split] | |||
| train_indices = all_indices[split:] | |||
| dev_set = DataSet() | |||
| train_set = DataSet() | |||
| for idx in test_indices: | |||
| test_set.append(self[idx]) | |||
| for idx in dev_indices: | |||
| dev_set.append(self[idx]) | |||
| for idx in train_indices: | |||
| train_set.append(self[idx]) | |||
| return train_set, test_set | |||
| if __name__ == '__main__': | |||
| from fastNLP.core.instance import Instance | |||
| d = DataSet({'a': list('abc')}) | |||
| _ = d.a | |||
| d.apply(lambda x: x['a']) | |||
| print(d[1]) | |||
| import copy | |||
| dd = copy.deepcopy(d) | |||
| print(dd.a) | |||
| return train_set, dev_set | |||
| @@ -3,61 +3,19 @@ from collections import defaultdict | |||
| import torch | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.metrics import Evaluator | |||
| from fastNLP.core.sampler import RandomSampler | |||
| # logger = create_logger(__name__, "./train_test.log") | |||
| class Tester(object): | |||
| """A collection of model inference and performance evaluation routines, used on the validation/dev set and the test set. """ | |||
| def __init__(self, **kwargs): | |||
| """ | |||
| :param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||
| """ | |||
| def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs): | |||
| super(Tester, self).__init__() | |||
| """ | |||
| "default_args" provides default value for important settings. | |||
| The initialization arguments "kwargs" with the same key (name) will override the default value. | |||
| "kwargs" must have the same type as "default_args" on corresponding keys. | |||
| Otherwise, error will raise. | |||
| """ | |||
| default_args = {"batch_size": 8, | |||
| "use_cuda": False, | |||
| "pickle_path": "./save/", | |||
| "model_name": "dev_best_model.pkl", | |||
| "evaluator": Evaluator() | |||
| } | |||
| """ | |||
| "required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||
| This is used to warn users of essential settings in the training. | |||
| Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||
| """ | |||
| required_args = {} | |||
| for req_key in required_args: | |||
| if req_key not in kwargs: | |||
| raise ValueError("Tester lacks argument {}".format(req_key)) | |||
| for key in default_args: | |||
| if key in kwargs: | |||
| if isinstance(kwargs[key], type(default_args[key])): | |||
| default_args[key] = kwargs[key] | |||
| else: | |||
| msg = "Argument %s type mismatch: expected %s while get %s" % ( | |||
| key, type(default_args[key]), type(kwargs[key])) | |||
| raise ValueError(msg) | |||
| else: | |||
| # Tester doesn't care about extra arguments | |||
| pass | |||
| # print(default_args) | |||
| self.batch_size = default_args["batch_size"] | |||
| self.pickle_path = default_args["pickle_path"] | |||
| self.use_cuda = default_args["use_cuda"] | |||
| self._evaluator = default_args["evaluator"] | |||
| self.batch_size = batch_size | |||
| self.pickle_path = save_path | |||
| self.use_cuda = use_cuda | |||
| self._evaluator = evaluator | |||
| self._model = None | |||
| self.eval_history = [] # evaluation results of all batches | |||
| @@ -72,7 +30,7 @@ class Tester(object): | |||
| self.mode(network, is_test=True) | |||
| self.eval_history.clear() | |||
| output, truths = defaultdict(list), defaultdict(list) | |||
| data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | |||
| data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||
| with torch.no_grad(): | |||
| for batch_x, batch_y in data_iterator: | |||
| @@ -15,6 +15,8 @@ from fastNLP.core.optimizer import Optimizer | |||
| from fastNLP.core.sampler import RandomSampler | |||
| from fastNLP.core.sampler import SequentialSampler | |||
| from fastNLP.core.tester import Tester | |||
| from fastNLP.core.utils import _build_args | |||
| from fastNLP.core.utils import _check_arg_dict_list | |||
| from fastNLP.core.utils import _check_arg_dict_list | |||
| from fastNLP.core.utils import _build_args | |||
| @@ -78,7 +80,7 @@ class Trainer(object): | |||
| epoch = 1 | |||
| while epoch <= self.n_epochs: | |||
| data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler()) | |||
| data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||
| self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) | |||
| @@ -207,9 +209,9 @@ def best_eval_result(self, metrics): | |||
| DEFAULT_CHECK_BATCH_SIZE = 2 | |||
| DEFAULT_CHECK_NUM_BATCH = 2 | |||
| IGNORE_CHECK_LEVEL=0 | |||
| WARNING_CHECK_LEVEL=1 | |||
| STRICT_CHECK_LEVEL=2 | |||
| IGNORE_CHECK_LEVEL = 0 | |||
| WARNING_CHECK_LEVEL = 1 | |||
| STRICT_CHECK_LEVEL = 2 | |||
| def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1): | |||
| # check the get_loss method | |||
| @@ -220,11 +222,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No | |||
| batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size) | |||
| batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
| for batch_count, (batch_x, batch_y) in enumerate(batch): | |||
| _syn_model_data(model, batch_x, batch_y) | |||
| # forward check | |||
| if batch_count==0: | |||
| _check_forward_error(model=model, model_func=model.forward, check_level=check_level, | |||
| batch_x=batch_x) | |||
| if batch_count == 0: | |||
| check_res = _check_arg_dict_list(model.forward, batch_x) | |||
| _info_str = '' | |||
| if len(check_res.missing) > 0: | |||
| if check_level == WARNING_CHECK_LEVEL: | |||
| for field_name in check_res.missing: | |||
| if hasattr(dataset, field_name): | |||
| _info_str += "{} " | |||
| _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n" | |||
| _info_str += "" | |||
| print("") | |||
| if len(check_res.unused) > 0: | |||
| if check_level == WARNING_CHECK_LEVEL: | |||
| _info_str += "" | |||
| refined_batch_x = _build_args(model.forward, **batch_x) | |||
| output = model(**refined_batch_x) | |||
| @@ -233,10 +244,14 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No | |||
| # loss check | |||
| if batch_count == 0: | |||
| _check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level, | |||
| output=output, batch_y=batch_y) | |||
| loss_input = _build_args(model.get_loss, **output, **batch_y) | |||
| loss = model.get_loss(**loss_input) | |||
| _dict = _check_arg_dict_list(model.loss, [output, batch_y]) | |||
| if len(_dict) != 0: | |||
| pass | |||
| loss_input = _build_args(model.loss, **output, **batch_y) | |||
| loss = model.loss(**loss_input) | |||
| if batch_count == 0: | |||
| if isinstance(loss, torch.Tensor): | |||
| pass | |||
| # check loss output | |||
| if batch_count == 0: | |||
| @@ -248,8 +263,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No | |||
| model_name, loss.size() | |||
| )) | |||
| loss.backward() | |||
| model.zero_grad() | |||
| if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: | |||
| if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE: | |||
| break | |||
| if check_level > IGNORE_CHECK_LEVEL: | |||
| print('Finish checking training process.', flush=True) | |||
| @@ -407,14 +421,7 @@ if __name__ == '__main__': | |||
| # trainer = Trainer(dataset, model) | |||
| _check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2) | |||
| # _check_forward_error(model=model, model_func=model.forward, check_level=1, | |||
| # batch_x=fake_data_dict) | |||
| # import inspect | |||
| # print(inspect.getfullargspec(model.forward)) | |||
| if len(_dict) != 0: | |||
| pass | |||
| refined_batch_x = _build_args(model.forward, **batch_x) | |||
| output = model(**refined_batch_x) | |||
| @@ -1,8 +1,8 @@ | |||
| import _pickle | |||
| import os | |||
| import inspect | |||
| from collections import namedtuple | |||
| import os | |||
| from collections import Counter | |||
| from collections import namedtuple | |||
| CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) | |||