* add code comments * merge *_saver.py & *_loader.py in io/ * (ancient codes) rename Loss into LossFromTorchtags/v0.2.0^2
| @@ -1,5 +1,3 @@ | |||||
| import torch | |||||
| import hashlib | import hashlib | ||||
| import os | import os | ||||
| import re | import re | ||||
| @@ -7,6 +5,8 @@ import shutil | |||||
| import sys | import sys | ||||
| import tempfile | import tempfile | ||||
| import torch | |||||
| try: | try: | ||||
| from requests.utils import urlparse | from requests.utils import urlparse | ||||
| from requests import get as urlopen | from requests import get as urlopen | ||||
| @@ -132,7 +132,3 @@ if tqdm is None: | |||||
| sys.stderr.write('\n') | sys.stderr.write('\n') | ||||
| if __name__ == '__main__': | |||||
| pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.') | |||||
| print(type(pipeline)) | |||||
| @@ -1,14 +1,15 @@ | |||||
| import torch | |||||
| from collections import defaultdict | |||||
| import re | import re | ||||
| from collections import defaultdict | |||||
| import torch | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| class Processor: | |||||
| class Processor(object): | |||||
| def __init__(self, field_name, new_added_field_name): | def __init__(self, field_name, new_added_field_name): | ||||
| self.field_name = field_name | self.field_name = field_name | ||||
| if new_added_field_name is None: | if new_added_field_name is None: | ||||
| @@ -17,7 +18,7 @@ class Processor: | |||||
| self.new_added_field_name = new_added_field_name | self.new_added_field_name = new_added_field_name | ||||
| def process(self, *args, **kwargs): | def process(self, *args, **kwargs): | ||||
| pass | |||||
| raise NotImplementedError | |||||
| def __call__(self, *args, **kwargs): | def __call__(self, *args, **kwargs): | ||||
| return self.process(*args, **kwargs) | return self.process(*args, **kwargs) | ||||
| @@ -132,13 +133,14 @@ class Num2TagProcessor(Processor): | |||||
| class IndexerProcessor(Processor): | class IndexerProcessor(Processor): | ||||
| def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False): | |||||
| def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | |||||
| assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | ||||
| super(IndexerProcessor, self).__init__(field_name, new_added_field_name) | super(IndexerProcessor, self).__init__(field_name, new_added_field_name) | ||||
| self.vocab = vocab | self.vocab = vocab | ||||
| self.delete_old_field = delete_old_field | self.delete_old_field = delete_old_field | ||||
| self.is_input = is_input | |||||
| def set_vocab(self, vocab): | def set_vocab(self, vocab): | ||||
| assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | ||||
| @@ -146,13 +148,14 @@ class IndexerProcessor(Processor): | |||||
| self.vocab = vocab | self.vocab = vocab | ||||
| def process(self, dataset): | def process(self, dataset): | ||||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
| assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
| for ins in dataset: | for ins in dataset: | ||||
| tokens = ins[self.field_name] | tokens = ins[self.field_name] | ||||
| index = [self.vocab.to_index(token) for token in tokens] | index = [self.vocab.to_index(token) for token in tokens] | ||||
| ins[self.new_added_field_name] = index | ins[self.new_added_field_name] = index | ||||
| dataset._set_need_tensor(**{self.new_added_field_name: True}) | |||||
| if self.is_input: | |||||
| dataset.set_input(self.new_added_field_name) | |||||
| if self.delete_old_field: | if self.delete_old_field: | ||||
| dataset.delete_field(self.field_name) | dataset.delete_field(self.field_name) | ||||
| @@ -161,6 +164,9 @@ class IndexerProcessor(Processor): | |||||
| class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
| """Build vocabulary with a field in the data set. | |||||
| """ | |||||
| def __init__(self, field_name): | def __init__(self, field_name): | ||||
| super(VocabProcessor, self).__init__(field_name, None) | super(VocabProcessor, self).__init__(field_name, None) | ||||
| self.vocab = Vocabulary() | self.vocab = Vocabulary() | ||||
| @@ -178,17 +184,20 @@ class VocabProcessor(Processor): | |||||
| class SeqLenProcessor(Processor): | class SeqLenProcessor(Processor): | ||||
| def __init__(self, field_name, new_added_field_name='seq_lens'): | |||||
| def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | |||||
| super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | ||||
| self.is_input = is_input | |||||
| def process(self, dataset): | def process(self, dataset): | ||||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
| for ins in dataset: | for ins in dataset: | ||||
| length = len(ins[self.field_name]) | length = len(ins[self.field_name]) | ||||
| ins[self.new_added_field_name] = length | ins[self.new_added_field_name] = length | ||||
| dataset._set_need_tensor(**{self.new_added_field_name: True}) | |||||
| if self.is_input: | |||||
| dataset.set_input(self.new_added_field_name) | |||||
| return dataset | return dataset | ||||
| class ModelProcessor(Processor): | class ModelProcessor(Processor): | ||||
| def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | ||||
| """ | """ | ||||
| @@ -238,6 +247,7 @@ class ModelProcessor(Processor): | |||||
| device = torch.device(device) | device = torch.device(device) | ||||
| self.model.to(device) | self.model.to(device) | ||||
| class Index2WordProcessor(Processor): | class Index2WordProcessor(Processor): | ||||
| def __init__(self, vocab, field_name, new_added_field_name): | def __init__(self, vocab, field_name, new_added_field_name): | ||||
| super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | ||||
| @@ -251,6 +261,7 @@ class Index2WordProcessor(Processor): | |||||
| class SetTensorProcessor(Processor): | class SetTensorProcessor(Processor): | ||||
| # TODO: remove it. It is strange. | |||||
| def __init__(self, field_dict, default=False): | def __init__(self, field_dict, default=False): | ||||
| super(SetTensorProcessor, self).__init__(None, None) | super(SetTensorProcessor, self).__init__(None, None) | ||||
| self.field_dict = field_dict | self.field_dict = field_dict | ||||
| @@ -264,6 +275,7 @@ class SetTensorProcessor(Processor): | |||||
| class SetIsTargetProcessor(Processor): | class SetIsTargetProcessor(Processor): | ||||
| # TODO; remove it. | |||||
| def __init__(self, field_dict, default=False): | def __init__(self, field_dict, default=False): | ||||
| super(SetIsTargetProcessor, self).__init__(None, None) | super(SetIsTargetProcessor, self).__init__(None, None) | ||||
| self.field_dict = field_dict | self.field_dict = field_dict | ||||
| @@ -2,7 +2,7 @@ from .batch import Batch | |||||
| from .dataset import DataSet | from .dataset import DataSet | ||||
| from .fieldarray import FieldArray | from .fieldarray import FieldArray | ||||
| from .instance import Instance | from .instance import Instance | ||||
| from .losses import Loss | |||||
| from .losses import LossFromTorch | |||||
| from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
| from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | ||||
| from .tester import Tester | from .tester import Tester | ||||
| @@ -9,32 +9,20 @@ from fastNLP.core.utils import get_func_signature | |||||
| _READERS = {} | _READERS = {} | ||||
| def construct_dataset(sentences): | |||||
| """Construct a data set from a list of sentences. | |||||
| :param sentences: list of list of str | |||||
| :return dataset: a DataSet object | |||||
| """ | |||||
| dataset = DataSet() | |||||
| for sentence in sentences: | |||||
| instance = Instance() | |||||
| instance['raw_sentence'] = sentence | |||||
| dataset.append(instance) | |||||
| return dataset | |||||
| class DataSet(object): | class DataSet(object): | ||||
| """DataSet is the collection of examples. | """DataSet is the collection of examples. | ||||
| DataSet provides instance-level interface. You can append and access an instance of the DataSet. | DataSet provides instance-level interface. You can append and access an instance of the DataSet. | ||||
| However, it stores data in a different way: Field-first, Instance-second. | However, it stores data in a different way: Field-first, Instance-second. | ||||
| """ | """ | ||||
| def __init__(self, data=None): | def __init__(self, data=None): | ||||
| """ | """ | ||||
| :param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field. | |||||
| All values must be of the same length. | |||||
| If it is a list, it must be a list of Instance objects. | |||||
| :param data: a dict or a list. | |||||
| If `data` is a dict, the key is the name of a FieldArray and the value is the FieldArray. All values | |||||
| must be of the same length. | |||||
| If `data` is a list, it must be a list of Instance objects. | |||||
| """ | """ | ||||
| self.field_arrays = {} | self.field_arrays = {} | ||||
| if data is not None: | if data is not None: | ||||
| @@ -60,6 +48,7 @@ class DataSet(object): | |||||
| def iter_func(): | def iter_func(): | ||||
| for idx in range(len(self)): | for idx in range(len(self)): | ||||
| yield self[idx] | yield self[idx] | ||||
| return iter_func() | return iter_func() | ||||
| def _inner_iter(self): | def _inner_iter(self): | ||||
| @@ -69,7 +58,8 @@ class DataSet(object): | |||||
| self.idx = idx | self.idx = idx | ||||
| def __getitem__(self, item): | def __getitem__(self, item): | ||||
| assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[self.idx]) | |||||
| assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | |||||
| self.idx]) | |||||
| assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | ||||
| return self.dataset.field_arrays[item][self.idx] | return self.dataset.field_arrays[item][self.idx] | ||||
| @@ -79,6 +69,7 @@ class DataSet(object): | |||||
| def inner_iter_func(): | def inner_iter_func(): | ||||
| for idx in range(len(self)): | for idx in range(len(self)): | ||||
| yield Iter_ptr(self, idx) | yield Iter_ptr(self, idx) | ||||
| return inner_iter_func() | return inner_iter_func() | ||||
| def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
| @@ -217,9 +208,17 @@ class DataSet(object): | |||||
| raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
| def get_input_name(self): | def get_input_name(self): | ||||
| """Get all field names with `is_input` as True. | |||||
| :return list field_names: a list of str | |||||
| """ | |||||
| return [name for name, field in self.field_arrays.items() if field.is_input] | return [name for name, field in self.field_arrays.items() if field.is_input] | ||||
| def get_target_name(self): | def get_target_name(self): | ||||
| """Get all field names with `is_target` as True. | |||||
| :return list field_names: a list of str | |||||
| """ | |||||
| return [name for name, field in self.field_arrays.items() if field.is_target] | return [name for name, field in self.field_arrays.items() if field.is_target] | ||||
| @classmethod | @classmethod | ||||
| @@ -243,7 +242,7 @@ class DataSet(object): | |||||
| :return results: if new_field_name is not passed, returned values of the function over all instances. | :return results: if new_field_name is not passed, returned values of the function over all instances. | ||||
| """ | """ | ||||
| results = [func(ins) for ins in self._inner_iter()] | results = [func(ins) for ins in self._inner_iter()] | ||||
| if len(list(filter(lambda x: x is not None, results)))==0: # all None | |||||
| if len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||||
| raise ValueError("{} always return None.".format(get_func_signature(func=func))) | raise ValueError("{} always return None.".format(get_func_signature(func=func))) | ||||
| extra_param = {} | extra_param = {} | ||||
| @@ -269,6 +268,12 @@ class DataSet(object): | |||||
| return results | return results | ||||
| def drop(self, func): | def drop(self, func): | ||||
| """Drop instances if a condition holds. | |||||
| :param func: a function that takes an Instance object as input, and returns bool. | |||||
| The instance will be dropped if the function returns True. | |||||
| """ | |||||
| results = [ins for ins in self._inner_iter() if not func(ins)] | results = [ins for ins in self._inner_iter() if not func(ins)] | ||||
| for name, old_field in self.field_arrays.items(): | for name, old_field in self.field_arrays.items(): | ||||
| self.field_arrays[name].content = [ins[name] for ins in results] | self.field_arrays[name].content = [ins[name] for ins in results] | ||||
| @@ -338,10 +343,33 @@ class DataSet(object): | |||||
| return cls(_dict) | return cls(_dict) | ||||
| def save(self, path): | def save(self, path): | ||||
| """Save the DataSet object as pickle. | |||||
| :param str path: the path to the pickle | |||||
| """ | |||||
| with open(path, 'wb') as f: | with open(path, 'wb') as f: | ||||
| pickle.dump(self, f) | pickle.dump(self, f) | ||||
| @staticmethod | @staticmethod | ||||
| def load(path): | def load(path): | ||||
| """Load a DataSet object from pickle. | |||||
| :param str path: the path to the pickle | |||||
| :return DataSet data_set: | |||||
| """ | |||||
| with open(path, 'rb') as f: | with open(path, 'rb') as f: | ||||
| return pickle.load(f) | return pickle.load(f) | ||||
| def construct_dataset(sentences): | |||||
| """Construct a data set from a list of sentences. | |||||
| :param sentences: list of list of str | |||||
| :return dataset: a DataSet object | |||||
| """ | |||||
| dataset = DataSet() | |||||
| for sentence in sentences: | |||||
| instance = Instance() | |||||
| instance['raw_sentence'] = sentence | |||||
| dataset.append(instance) | |||||
| return dataset | |||||
| @@ -7,14 +7,13 @@ import torch.nn.functional as F | |||||
| from fastNLP.core.utils import CheckError | from fastNLP.core.utils import CheckError | ||||
| from fastNLP.core.utils import CheckRes | from fastNLP.core.utils import CheckRes | ||||
| from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
| from fastNLP.core.utils import _check_function_or_method | |||||
| from fastNLP.core.utils import _check_arg_dict_list | from fastNLP.core.utils import _check_arg_dict_list | ||||
| from fastNLP.core.utils import _check_function_or_method | |||||
| from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
| class LossBase(object): | class LossBase(object): | ||||
| def __init__(self): | def __init__(self): | ||||
| # key: name in target function; value: name in output function | |||||
| self.param_map = {} | self.param_map = {} | ||||
| self._checked = False | self._checked = False | ||||
| @@ -159,8 +158,18 @@ class LossBase(object): | |||||
| return loss | return loss | ||||
| class LossFunc(LossBase): | class LossFunc(LossBase): | ||||
| """A wrapper of user-provided loss function. | |||||
| """ | |||||
| def __init__(self, func, key_map=None, **kwargs): | def __init__(self, func, key_map=None, **kwargs): | ||||
| """ | |||||
| :param func: a callable object, such as a function. | |||||
| :param dict key_map: | |||||
| :param kwargs: | |||||
| """ | |||||
| super(LossFunc, self).__init__() | super(LossFunc, self).__init__() | ||||
| _check_function_or_method(func) | _check_function_or_method(func) | ||||
| if key_map is not None: | if key_map is not None: | ||||
| @@ -254,19 +263,19 @@ def _prepare_losser(losser): | |||||
| def squash(predict, truth, **kwargs): | def squash(predict, truth, **kwargs): | ||||
| '''To reshape tensors in order to fit Loss functions in pytorch | |||||
| """To reshape tensors in order to fit loss functions in pytorch | |||||
| :param predict : Tensor, model output | :param predict : Tensor, model output | ||||
| :param truth : Tensor, truth from dataset | :param truth : Tensor, truth from dataset | ||||
| :param **kwargs : extra arguments | :param **kwargs : extra arguments | ||||
| :return predict , truth: predict & truth after processing | :return predict , truth: predict & truth after processing | ||||
| ''' | |||||
| """ | |||||
| return predict.view(-1, predict.size()[-1]), truth.view(-1, ) | return predict.view(-1, predict.size()[-1]), truth.view(-1, ) | ||||
| def unpad(predict, truth, **kwargs): | def unpad(predict, truth, **kwargs): | ||||
| '''To process padded sequence output to get true loss | |||||
| """To process padded sequence output to get true loss | |||||
| Using pack_padded_sequence() method | Using pack_padded_sequence() method | ||||
| This method contains squash() | This method contains squash() | ||||
| @@ -277,7 +286,7 @@ def unpad(predict, truth, **kwargs): | |||||
| the i-th element is true lengths of i-th sequence | the i-th element is true lengths of i-th sequence | ||||
| :return predict , truth: predict & truth after processing | :return predict , truth: predict & truth after processing | ||||
| ''' | |||||
| """ | |||||
| if kwargs.get("lens") is None: | if kwargs.get("lens") is None: | ||||
| return predict, truth | return predict, truth | ||||
| lens = torch.LongTensor(kwargs["lens"]) | lens = torch.LongTensor(kwargs["lens"]) | ||||
| @@ -288,7 +297,7 @@ def unpad(predict, truth, **kwargs): | |||||
| def unpad_mask(predict, truth, **kwargs): | def unpad_mask(predict, truth, **kwargs): | ||||
| '''To process padded sequence output to get true loss | |||||
| """To process padded sequence output to get true loss | |||||
| Using mask() method | Using mask() method | ||||
| This method contains squash() | This method contains squash() | ||||
| @@ -299,7 +308,7 @@ def unpad_mask(predict, truth, **kwargs): | |||||
| the i-th element is true lengths of i-th sequence | the i-th element is true lengths of i-th sequence | ||||
| :return predict , truth: predict & truth after processing | :return predict , truth: predict & truth after processing | ||||
| ''' | |||||
| """ | |||||
| if kwargs.get("lens") is None: | if kwargs.get("lens") is None: | ||||
| return predict, truth | return predict, truth | ||||
| mas = make_mask(kwargs["lens"], truth.size()[1]) | mas = make_mask(kwargs["lens"], truth.size()[1]) | ||||
| @@ -307,7 +316,7 @@ def unpad_mask(predict, truth, **kwargs): | |||||
| def mask(predict, truth, **kwargs): | def mask(predict, truth, **kwargs): | ||||
| '''To select specific elements from Tensor | |||||
| """To select specific elements from Tensor | |||||
| This method contains squash() | This method contains squash() | ||||
| :param predict : Tensor, [batch_size , max_len , tag_size] | :param predict : Tensor, [batch_size , max_len , tag_size] | ||||
| @@ -317,7 +326,7 @@ def mask(predict, truth, **kwargs): | |||||
| the mask Tensor , the position that is 1 will be selected | the mask Tensor , the position that is 1 will be selected | ||||
| :return predict , truth: predict & truth after processing | :return predict , truth: predict & truth after processing | ||||
| ''' | |||||
| """ | |||||
| if kwargs.get("mask") is None: | if kwargs.get("mask") is None: | ||||
| return predict, truth | return predict, truth | ||||
| mask = kwargs["mask"] | mask = kwargs["mask"] | ||||
| @@ -332,14 +341,14 @@ def mask(predict, truth, **kwargs): | |||||
| def make_mask(lens, tar_len): | def make_mask(lens, tar_len): | ||||
| '''to generate a mask that select [:lens[i]] for i-th element | |||||
| """to generate a mask that select [:lens[i]] for i-th element | |||||
| embezzle from fastNLP.models.sequence_modeling.seq_mask | embezzle from fastNLP.models.sequence_modeling.seq_mask | ||||
| :param lens : list or LongTensor, [batch_size] | :param lens : list or LongTensor, [batch_size] | ||||
| :param tar_len : int | :param tar_len : int | ||||
| :return mask : ByteTensor | :return mask : ByteTensor | ||||
| ''' | |||||
| """ | |||||
| lens = torch.LongTensor(lens) | lens = torch.LongTensor(lens) | ||||
| mask = [torch.ge(lens, i + 1) for i in range(tar_len)] | mask = [torch.ge(lens, i + 1) for i in range(tar_len)] | ||||
| mask = torch.stack(mask, 1) | mask = torch.stack(mask, 1) | ||||
| @@ -376,9 +385,11 @@ loss_function_name = { | |||||
| } | } | ||||
| class Loss(object): | |||||
| """a Loss object is a callable object represents loss functions | |||||
| class LossFromTorch(object): | |||||
| """a LossFromTorch object is a callable object represents loss functions | |||||
| This class only helps you with loss functions from PyTorch. | |||||
| It has nothing to do with Trainer. | |||||
| """ | """ | ||||
| def __init__(self, loss_name, pre_pro=[squash], **kwargs): | def __init__(self, loss_name, pre_pro=[squash], **kwargs): | ||||
| @@ -408,11 +419,11 @@ class Loss(object): | |||||
| self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro] | self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro] | ||||
| def add_pre_pro(self, func): | def add_pre_pro(self, func): | ||||
| '''add a pre_pro function | |||||
| """add a pre_pro function | |||||
| :param func: a function or str, methods to reform parameters before calculating loss | :param func: a function or str, methods to reform parameters before calculating loss | ||||
| the strings will be auto translated to pre-defined functions | the strings will be auto translated to pre-defined functions | ||||
| ''' | |||||
| """ | |||||
| if not callable(func): | if not callable(func): | ||||
| func = method_dict.get(func) | func = method_dict.get(func) | ||||
| if func is None: | if func is None: | ||||
| @@ -421,12 +432,12 @@ class Loss(object): | |||||
| @staticmethod | @staticmethod | ||||
| def _get_loss(loss_name, **kwargs): | def _get_loss(loss_name, **kwargs): | ||||
| '''Get loss function from torch | |||||
| """Get loss function from torch | |||||
| :param loss_name: str, the name of loss function | :param loss_name: str, the name of loss function | ||||
| :param **kwargs: kwargs for torch loss function | :param **kwargs: kwargs for torch loss function | ||||
| :return: A callable loss function object | :return: A callable loss function object | ||||
| ''' | |||||
| """ | |||||
| loss_name = loss_name.strip().lower() | loss_name = loss_name.strip().lower() | ||||
| loss_name = "".join(loss_name.split("_")) | loss_name = "".join(loss_name.split("_")) | ||||
| @@ -435,19 +446,19 @@ class Loss(object): | |||||
| return loss_function_name[loss_name](**kwargs) | return loss_function_name[loss_name](**kwargs) | ||||
| def get(self): | def get(self): | ||||
| '''This method exists just for make some existing codes run error-freely | |||||
| ''' | |||||
| """This method exists just for make some existing codes run error-freely | |||||
| """ | |||||
| return self | return self | ||||
| def __call__(self, predict, truth, **kwargs): | def __call__(self, predict, truth, **kwargs): | ||||
| '''call a loss function | |||||
| """Call a loss function | |||||
| predict and truth will be processed by pre_pro methods in order of addition | predict and truth will be processed by pre_pro methods in order of addition | ||||
| :param predict : Tensor, model output | :param predict : Tensor, model output | ||||
| :param truth : Tensor, truth from dataset | :param truth : Tensor, truth from dataset | ||||
| :param **kwargs : extra arguments, pass to pre_pro functions | :param **kwargs : extra arguments, pass to pre_pro functions | ||||
| for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens | for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens | ||||
| ''' | |||||
| """ | |||||
| for f in self.pre_pro: | for f in self.pre_pro: | ||||
| if f is None: | if f is None: | ||||
| continue | continue | ||||
| @@ -308,6 +308,13 @@ def _prepare_metrics(metrics): | |||||
| return _metrics | return _metrics | ||||
| """ | |||||
| Attention: Codes below are not used in current FastNLP. | |||||
| However, it is useful. | |||||
| """ | |||||
| def _conver_numpy(x): | def _conver_numpy(x): | ||||
| """convert input data to numpy array | """convert input data to numpy array | ||||
| @@ -11,6 +11,12 @@ class Optimizer(object): | |||||
| class SGD(Optimizer): | class SGD(Optimizer): | ||||
| def __init__(self, model_params=None, lr=0.01, momentum=0): | def __init__(self, model_params=None, lr=0.01, momentum=0): | ||||
| """ | |||||
| :param model_params: a generator. E.g. model.parameters() for PyTorch models. | |||||
| :param float lr: learning rate. Default: 0.01 | |||||
| :param float momentum: momentum. Default: 0 | |||||
| """ | |||||
| super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | ||||
| def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
| @@ -23,6 +29,12 @@ class SGD(Optimizer): | |||||
| class Adam(Optimizer): | class Adam(Optimizer): | ||||
| def __init__(self, model_params=None, lr=0.01, weight_decay=0): | def __init__(self, model_params=None, lr=0.01, weight_decay=0): | ||||
| """ | |||||
| :param model_params: a generator. E.g. model.parameters() for PyTorch models. | |||||
| :param float lr: learning rate | |||||
| :param float weight_decay: | |||||
| """ | |||||
| super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) | super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) | ||||
| def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
| @@ -140,7 +140,6 @@ class Trainer(object): | |||||
| def train(self): | def train(self): | ||||
| """Start Training. | """Start Training. | ||||
| :return: | |||||
| """ | """ | ||||
| try: | try: | ||||
| if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
| @@ -216,14 +215,6 @@ class Trainer(object): | |||||
| pbar.close() | pbar.close() | ||||
| def _print_train(self): | def _print_train(self): | ||||
| """ | |||||
| :param data_iterator: | |||||
| :param model: | |||||
| :param epoch: | |||||
| :param start: | |||||
| :return: | |||||
| """ | |||||
| epoch = 1 | epoch = 1 | ||||
| start = time.time() | start = time.time() | ||||
| while epoch <= self.n_epochs: | while epoch <= self.n_epochs: | ||||
| @@ -29,19 +29,3 @@ class BaseLoader(object): | |||||
| with open(cache_path, 'wb') as f: | with open(cache_path, 'wb') as f: | ||||
| pickle.dump(obj, f) | pickle.dump(obj, f) | ||||
| return obj | return obj | ||||
| class ToyLoader0(BaseLoader): | |||||
| """ | |||||
| For CharLM | |||||
| """ | |||||
| def __init__(self, data_path): | |||||
| super(ToyLoader0, self).__init__(data_path) | |||||
| def load(self): | |||||
| with open(self.data_path, 'r') as f: | |||||
| corpus = f.read().lower() | |||||
| import re | |||||
| corpus = re.sub(r"<unk>", "unk", corpus) | |||||
| return corpus.split() | |||||
| @@ -1,6 +1,152 @@ | |||||
| import configparser | |||||
| import json | |||||
| import os | import os | ||||
| from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
| from fastNLP.io.base_loader import BaseLoader | |||||
| class ConfigLoader(BaseLoader): | |||||
| """loader for configuration files""" | |||||
| def __init__(self, data_path=None): | |||||
| super(ConfigLoader, self).__init__() | |||||
| if data_path is not None: | |||||
| self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
| @staticmethod | |||||
| def parse(string): | |||||
| raise NotImplementedError | |||||
| @staticmethod | |||||
| def load_config(file_path, sections): | |||||
| """ | |||||
| :param file_path: the path of config file | |||||
| :param sections: the dict of {section_name(string): Section instance} | |||||
| Example: | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| :return: return nothing, but the value of attributes are saved in sessions | |||||
| """ | |||||
| assert isinstance(sections, dict) | |||||
| cfg = configparser.ConfigParser() | |||||
| if not os.path.exists(file_path): | |||||
| raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||||
| cfg.read(file_path) | |||||
| for s in sections: | |||||
| attr_list = [i for i in sections[s].__dict__.keys() if | |||||
| not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||||
| if s not in cfg: | |||||
| print('section %s not found in config file' % (s)) | |||||
| continue | |||||
| gen_sec = cfg[s] | |||||
| for attr in gen_sec.keys(): | |||||
| try: | |||||
| val = json.loads(gen_sec[attr]) | |||||
| # print(s, attr, val, type(val)) | |||||
| if attr in attr_list: | |||||
| assert type(val) == type(getattr(sections[s], attr)), \ | |||||
| 'type not match, except %s but got %s' % \ | |||||
| (type(getattr(sections[s], attr)), type(val)) | |||||
| """ | |||||
| if attr in attr_list then check its type and | |||||
| update its value. | |||||
| else add a new attr in sections[s] | |||||
| """ | |||||
| setattr(sections[s], attr, val) | |||||
| except Exception as e: | |||||
| print("cannot load attribute %s in section %s" | |||||
| % (attr, s)) | |||||
| pass | |||||
| class ConfigSection(object): | |||||
| def __init__(self): | |||||
| pass | |||||
| def __getitem__(self, key): | |||||
| """ | |||||
| :param key: str, the name of the attribute | |||||
| :return attr: the value of this attribute | |||||
| if key not in self.__dict__.keys(): | |||||
| return self[key] | |||||
| else: | |||||
| raise AttributeError | |||||
| """ | |||||
| if key in self.__dict__.keys(): | |||||
| return getattr(self, key) | |||||
| raise AttributeError("do NOT have attribute %s" % key) | |||||
| def __setitem__(self, key, value): | |||||
| """ | |||||
| :param key: str, the name of the attribute | |||||
| :param value: the value of this attribute | |||||
| if key not in self.__dict__.keys(): | |||||
| self[key] will be added | |||||
| else: | |||||
| self[key] will be updated | |||||
| """ | |||||
| if key in self.__dict__.keys(): | |||||
| if not isinstance(value, type(getattr(self, key))): | |||||
| raise AttributeError("attr %s except %s but got %s" % | |||||
| (key, str(type(getattr(self, key))), str(type(value)))) | |||||
| setattr(self, key, value) | |||||
| def __contains__(self, item): | |||||
| """ | |||||
| :param item: The key of item. | |||||
| :return: True if the key in self.__dict__.keys() else False. | |||||
| """ | |||||
| return item in self.__dict__.keys() | |||||
| def __eq__(self, other): | |||||
| """Overwrite the == operator | |||||
| :param other: Another ConfigSection() object which to be compared. | |||||
| :return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||||
| """ | |||||
| for k in self.__dict__.keys(): | |||||
| if k not in other.__dict__.keys(): | |||||
| return False | |||||
| if getattr(self, k) != getattr(self, k): | |||||
| return False | |||||
| for k in other.__dict__.keys(): | |||||
| if k not in self.__dict__.keys(): | |||||
| return False | |||||
| if getattr(self, k) != getattr(self, k): | |||||
| return False | |||||
| return True | |||||
| def __ne__(self, other): | |||||
| """Overwrite the != operator | |||||
| :param other: | |||||
| :return: | |||||
| """ | |||||
| return not self.__eq__(other) | |||||
| @property | |||||
| def data(self): | |||||
| return self.__dict__ | |||||
| if __name__ == "__main__": | |||||
| config = ConfigLoader('there is no data') | |||||
| section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||||
| """ | |||||
| General and My can be found in config file, so the attr and | |||||
| value will be updated | |||||
| A cannot be found in config file, so nothing will be done | |||||
| """ | |||||
| config.load_config("../../test/data_for_tests/config", section) | |||||
| for s in section: | |||||
| print(s) | |||||
| for attr in section[s].__dict__.keys(): | |||||
| print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | |||||
| class ConfigSaver(object): | class ConfigSaver(object): | ||||
| @@ -125,7 +271,7 @@ class ConfigSaver(object): | |||||
| # logger = create_logger(__name__, "./config_loader.log") | # logger = create_logger(__name__, "./config_loader.log") | ||||
| # logger.warning("section [%s] in config file [%s] has been changed" % ( | # logger.warning("section [%s] in config file [%s] has been changed" % ( | ||||
| # section_name, self.file_path | # section_name, self.file_path | ||||
| #)) | |||||
| # )) | |||||
| change_file = True | change_file = True | ||||
| break | break | ||||
| if not change_file: | if not change_file: | ||||
| @@ -1,149 +0,0 @@ | |||||
| import configparser | |||||
| import json | |||||
| import os | |||||
| from fastNLP.io.base_loader import BaseLoader | |||||
| class ConfigLoader(BaseLoader): | |||||
| """loader for configuration files""" | |||||
| def __init__(self, data_path=None): | |||||
| super(ConfigLoader, self).__init__() | |||||
| if data_path is not None: | |||||
| self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
| @staticmethod | |||||
| def parse(string): | |||||
| raise NotImplementedError | |||||
| @staticmethod | |||||
| def load_config(file_path, sections): | |||||
| """ | |||||
| :param file_path: the path of config file | |||||
| :param sections: the dict of {section_name(string): Section instance} | |||||
| Example: | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| :return: return nothing, but the value of attributes are saved in sessions | |||||
| """ | |||||
| assert isinstance(sections, dict) | |||||
| cfg = configparser.ConfigParser() | |||||
| if not os.path.exists(file_path): | |||||
| raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||||
| cfg.read(file_path) | |||||
| for s in sections: | |||||
| attr_list = [i for i in sections[s].__dict__.keys() if | |||||
| not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||||
| if s not in cfg: | |||||
| print('section %s not found in config file' % (s)) | |||||
| continue | |||||
| gen_sec = cfg[s] | |||||
| for attr in gen_sec.keys(): | |||||
| try: | |||||
| val = json.loads(gen_sec[attr]) | |||||
| # print(s, attr, val, type(val)) | |||||
| if attr in attr_list: | |||||
| assert type(val) == type(getattr(sections[s], attr)), \ | |||||
| 'type not match, except %s but got %s' % \ | |||||
| (type(getattr(sections[s], attr)), type(val)) | |||||
| """ | |||||
| if attr in attr_list then check its type and | |||||
| update its value. | |||||
| else add a new attr in sections[s] | |||||
| """ | |||||
| setattr(sections[s], attr, val) | |||||
| except Exception as e: | |||||
| print("cannot load attribute %s in section %s" | |||||
| % (attr, s)) | |||||
| pass | |||||
| class ConfigSection(object): | |||||
| def __init__(self): | |||||
| pass | |||||
| def __getitem__(self, key): | |||||
| """ | |||||
| :param key: str, the name of the attribute | |||||
| :return attr: the value of this attribute | |||||
| if key not in self.__dict__.keys(): | |||||
| return self[key] | |||||
| else: | |||||
| raise AttributeError | |||||
| """ | |||||
| if key in self.__dict__.keys(): | |||||
| return getattr(self, key) | |||||
| raise AttributeError("do NOT have attribute %s" % key) | |||||
| def __setitem__(self, key, value): | |||||
| """ | |||||
| :param key: str, the name of the attribute | |||||
| :param value: the value of this attribute | |||||
| if key not in self.__dict__.keys(): | |||||
| self[key] will be added | |||||
| else: | |||||
| self[key] will be updated | |||||
| """ | |||||
| if key in self.__dict__.keys(): | |||||
| if not isinstance(value, type(getattr(self, key))): | |||||
| raise AttributeError("attr %s except %s but got %s" % | |||||
| (key, str(type(getattr(self, key))), str(type(value)))) | |||||
| setattr(self, key, value) | |||||
| def __contains__(self, item): | |||||
| """ | |||||
| :param item: The key of item. | |||||
| :return: True if the key in self.__dict__.keys() else False. | |||||
| """ | |||||
| return item in self.__dict__.keys() | |||||
| def __eq__(self, other): | |||||
| """Overwrite the == operator | |||||
| :param other: Another ConfigSection() object which to be compared. | |||||
| :return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||||
| """ | |||||
| for k in self.__dict__.keys(): | |||||
| if k not in other.__dict__.keys(): | |||||
| return False | |||||
| if getattr(self, k) != getattr(self, k): | |||||
| return False | |||||
| for k in other.__dict__.keys(): | |||||
| if k not in self.__dict__.keys(): | |||||
| return False | |||||
| if getattr(self, k) != getattr(self, k): | |||||
| return False | |||||
| return True | |||||
| def __ne__(self, other): | |||||
| """Overwrite the != operator | |||||
| :param other: | |||||
| :return: | |||||
| """ | |||||
| return not self.__eq__(other) | |||||
| @property | |||||
| def data(self): | |||||
| return self.__dict__ | |||||
| if __name__ == "__main__": | |||||
| config = ConfigLoader('there is no data') | |||||
| section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||||
| """ | |||||
| General and My can be found in config file, so the attr and | |||||
| value will be updated | |||||
| A cannot be found in config file, so nothing will be done | |||||
| """ | |||||
| config.load_config("../../test/data_for_tests/config", section) | |||||
| for s in section: | |||||
| print(s) | |||||
| for attr in section[s].__dict__.keys(): | |||||
| print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | |||||
| @@ -1,4 +1,3 @@ | |||||
| #TODO: need fix for current DataSet | |||||
| import os | import os | ||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| @@ -20,8 +19,7 @@ def convert_seq_dataset(data): | |||||
| """ | """ | ||||
| dataset = DataSet() | dataset = DataSet() | ||||
| for word_seq in data: | for word_seq in data: | ||||
| x = TextField(word_seq, is_target=False) | |||||
| dataset.append(Instance(word_seq=x)) | |||||
| dataset.append(Instance(word_seq=word_seq)) | |||||
| return dataset | return dataset | ||||
| @@ -40,11 +38,7 @@ def convert_seq2tag_dataset(data): | |||||
| """ | """ | ||||
| dataset = DataSet() | dataset = DataSet() | ||||
| for sample in data: | for sample in data: | ||||
| word_seq, label = sample[0], sample[1] | |||||
| ins = Instance() | |||||
| ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ | |||||
| .add_field("label", LabelField(label, is_target=True)) | |||||
| dataset.append(ins) | |||||
| dataset.append(Instance(word_seq=sample[0], label=sample[1])) | |||||
| return dataset | return dataset | ||||
| @@ -63,11 +57,7 @@ def convert_seq2seq_dataset(data): | |||||
| """ | """ | ||||
| dataset = DataSet() | dataset = DataSet() | ||||
| for sample in data: | for sample in data: | ||||
| word_seq, label_seq = sample[0], sample[1] | |||||
| ins = Instance() | |||||
| ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ | |||||
| .add_field("label_seq", TextField(label_seq, is_target=True)) | |||||
| dataset.append(ins) | |||||
| dataset.append(Instance(word_seq=sample[0], label_seq=sample[1])) | |||||
| return dataset | return dataset | ||||
| @@ -273,85 +263,6 @@ class ClassDataSetLoader(DataSetLoader): | |||||
| return convert_seq2tag_dataset(data) | return convert_seq2tag_dataset(data) | ||||
| @DataSet.set_reader('read_conll') | |||||
| class ConllLoader(DataSetLoader): | |||||
| """loader for conll format files""" | |||||
| def __init__(self): | |||||
| """ | |||||
| :param str data_path: the path to the conll data set | |||||
| """ | |||||
| super(ConllLoader, self).__init__() | |||||
| def load(self, data_path): | |||||
| """ | |||||
| :return: list lines: all lines in a conll file | |||||
| """ | |||||
| with open(data_path, "r", encoding="utf-8") as f: | |||||
| lines = f.readlines() | |||||
| data = self.parse(lines) | |||||
| return self.convert(data) | |||||
| @staticmethod | |||||
| def parse(lines): | |||||
| """ | |||||
| :param list lines:a list containing all lines in a conll file. | |||||
| :return: a 3D list | |||||
| """ | |||||
| sentences = list() | |||||
| tokens = list() | |||||
| for line in lines: | |||||
| if line[0] == "#": | |||||
| # skip the comments | |||||
| continue | |||||
| if line == "\n": | |||||
| sentences.append(tokens) | |||||
| tokens = [] | |||||
| continue | |||||
| tokens.append(line.split()) | |||||
| return sentences | |||||
| def convert(self, data): | |||||
| pass | |||||
| @DataSet.set_reader('read_lm') | |||||
| class LMDataSetLoader(DataSetLoader): | |||||
| """Language Model Dataset Loader | |||||
| This loader produces data for language model training in a supervised way. | |||||
| That means it has X and Y. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LMDataSetLoader, self).__init__() | |||||
| def load(self, data_path): | |||||
| if not os.path.exists(data_path): | |||||
| raise FileNotFoundError("file {} not found.".format(data_path)) | |||||
| with open(data_path, "r", encoding="utf=8") as f: | |||||
| text = " ".join(f.readlines()) | |||||
| tokens = text.strip().split() | |||||
| data = self.sentence_cut(tokens) | |||||
| return self.convert(data) | |||||
| def sentence_cut(self, tokens, sentence_length=15): | |||||
| start_idx = 0 | |||||
| data_set = [] | |||||
| for idx in range(len(tokens) // sentence_length): | |||||
| x = tokens[start_idx * idx: start_idx * idx + sentence_length] | |||||
| y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1] | |||||
| if start_idx * idx + sentence_length + 1 >= len(tokens): | |||||
| # ad hoc | |||||
| y.extend(["<unk>"]) | |||||
| data_set.append([x, y]) | |||||
| return data_set | |||||
| def convert(self, data): | |||||
| pass | |||||
| @DataSet.set_reader('read_people_daily') | @DataSet.set_reader('read_people_daily') | ||||
| class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
| """ | """ | ||||
| @@ -403,10 +314,19 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
| pos_tag_examples.append([sent_words, sent_pos_tag]) | pos_tag_examples.append([sent_words, sent_pos_tag]) | ||||
| ner_examples.append([sent_words, sent_ner]) | ner_examples.append([sent_words, sent_ner]) | ||||
| # List[List[List[str], List[str]]] | # List[List[List[str], List[str]]] | ||||
| return pos_tag_examples, ner_examples | |||||
| # ner_examples not used | |||||
| return self.convert(pos_tag_examples) | |||||
| def convert(self, data): | def convert(self, data): | ||||
| pass | |||||
| data_set = DataSet() | |||||
| for item in data: | |||||
| sent_words, sent_pos_tag = item[0], item[1] | |||||
| data_set.append(Instance(words=sent_words, tags=sent_pos_tag)) | |||||
| data_set.apply(lambda ins: len(ins), new_field_name="seq_len") | |||||
| data_set.set_target("tags") | |||||
| data_set.set_input("sent_words") | |||||
| data_set.set_input("seq_len") | |||||
| return data_set | |||||
| class SNLIDataSetLoader(DataSetLoader): | class SNLIDataSetLoader(DataSetLoader): | ||||
| @@ -462,17 +382,13 @@ class SNLIDataSetLoader(DataSetLoader): | |||||
| for example in data: | for example in data: | ||||
| p, h, l = example | p, h, l = example | ||||
| # list, list, str | # list, list, str | ||||
| x1 = TextField(p, is_target=False) | |||||
| x2 = TextField(h, is_target=False) | |||||
| x1_len = TextField([1] * len(p), is_target=False) | |||||
| x2_len = TextField([1] * len(h), is_target=False) | |||||
| y = LabelField(l, is_target=True) | |||||
| instance = Instance() | instance = Instance() | ||||
| instance.add_field("premise", x1) | |||||
| instance.add_field("hypothesis", x2) | |||||
| instance.add_field("premise_len", x1_len) | |||||
| instance.add_field("hypothesis_len", x2_len) | |||||
| instance.add_field("truth", y) | |||||
| instance.add_field("premise", p) | |||||
| instance.add_field("hypothesis", h) | |||||
| instance.add_field("truth", l) | |||||
| data_set.append(instance) | data_set.append(instance) | ||||
| data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len") | |||||
| data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len") | |||||
| data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | |||||
| data_set.set_target("truth") | |||||
| return data_set | return data_set | ||||
| @@ -1,5 +1,32 @@ | |||||
| import torch | import torch | ||||
| from fastNLP.io.base_loader import BaseLoader | |||||
| class ModelLoader(BaseLoader): | |||||
| """ | |||||
| Loader for models. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ModelLoader, self).__init__() | |||||
| @staticmethod | |||||
| def load_pytorch(empty_model, model_path): | |||||
| """ | |||||
| Load model parameters from .pkl files into the empty PyTorch model. | |||||
| :param empty_model: a PyTorch model with initialized parameters. | |||||
| :param model_path: str, the path to the saved model. | |||||
| """ | |||||
| empty_model.load_state_dict(torch.load(model_path)) | |||||
| @staticmethod | |||||
| def load_pytorch_model(model_path): | |||||
| """Load the entire model. | |||||
| """ | |||||
| return torch.load(model_path) | |||||
| class ModelSaver(object): | class ModelSaver(object): | ||||
| """Save a model | """Save a model | ||||
| @@ -8,6 +35,7 @@ class ModelSaver(object): | |||||
| saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
| """ | """ | ||||
| def __init__(self, save_path): | def __init__(self, save_path): | ||||
| """ | """ | ||||
| @@ -1,28 +0,0 @@ | |||||
| import torch | |||||
| from fastNLP.io.base_loader import BaseLoader | |||||
| class ModelLoader(BaseLoader): | |||||
| """ | |||||
| Loader for models. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ModelLoader, self).__init__() | |||||
| @staticmethod | |||||
| def load_pytorch(empty_model, model_path): | |||||
| """ | |||||
| Load model parameters from .pkl files into the empty PyTorch model. | |||||
| :param empty_model: a PyTorch model with initialized parameters. | |||||
| :param model_path: str, the path to the saved model. | |||||
| """ | |||||
| empty_model.load_state_dict(torch.load(model_path)) | |||||
| @staticmethod | |||||
| def load_pytorch_model(model_path): | |||||
| """Load the entire model. | |||||
| """ | |||||
| return torch.load(model_path) | |||||
| @@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||||
| from fastNLP.api.processor import * | from fastNLP.api.processor import * | ||||
| from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
| from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
| from fastNLP.io.config_io import ConfigSection, ConfigLoader | |||||
| import _pickle as pickle | import _pickle as pickle | ||||
| import torch | import torch | ||||
| @@ -13,11 +13,10 @@ from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| from fastNLP.core.field import TextField, SeqLabelField | from fastNLP.core.field import TextField, SeqLabelField | ||||
| from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.io.model_loader import ModelLoader | |||||
| from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
| from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
| from fastNLP.io.embed_loader import EmbedLoader | from fastNLP.io.embed_loader import EmbedLoader | ||||
| from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
| from fastNLP.io.model_saver import ModelSaver | |||||
| BOS = '<BOS>' | BOS = '<BOS>' | ||||
| EOS = '<EOS>' | EOS = '<EOS>' | ||||
| @@ -2,8 +2,8 @@ import torch.nn.functional as F | |||||
| from fastNLP.core.trainer import ClassificationTrainer | from fastNLP.core.trainer import ClassificationTrainer | ||||
| from fastNLP.core.utils import ClassPreprocess as Preprocess | from fastNLP.core.utils import ClassPreprocess as Preprocess | ||||
| from fastNLP.io.config_loader import ConfigLoader | |||||
| from fastNLP.io.config_loader import ConfigSection | |||||
| from fastNLP.io.config_io import ConfigLoader | |||||
| from fastNLP.io.config_io import ConfigSection | |||||
| from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader | from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader | ||||
| from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
| from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
| @@ -3,12 +3,11 @@ import sys | |||||
| sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
| from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
| from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader | from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader | ||||
| from fastNLP.core.utils import load_pickle | from fastNLP.core.utils import load_pickle | ||||
| from fastNLP.io.model_saver import ModelSaver | |||||
| from fastNLP.io.model_loader import ModelLoader | |||||
| from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
| from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
| @@ -12,12 +12,12 @@ with open('requirements.txt', encoding='utf-8') as f: | |||||
| reqs = f.read() | reqs = f.read() | ||||
| setup( | setup( | ||||
| name='fastNLP', | |||||
| name='FastNLP', | |||||
| version='0.1.1', | version='0.1.1', | ||||
| description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
| long_description=readme, | long_description=readme, | ||||
| license=license, | license=license, | ||||
| author='fudanNLP', | |||||
| author='FudanNLP', | |||||
| python_requires='>=3.5', | python_requires='>=3.5', | ||||
| packages=find_packages(), | packages=find_packages(), | ||||
| install_requires=reqs.strip().split('\n'), | install_requires=reqs.strip().split('\n'), | ||||
| @@ -0,0 +1,12 @@ | |||||
| import unittest | |||||
| from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor | |||||
| from fastNLP.core.dataset import DataSet | |||||
| class TestProcessor(unittest.TestCase): | |||||
| def test_FullSpaceToHalfSpaceProcessor(self): | |||||
| ds = DataSet({"word": ["00, u1, u), (u2, u2"]}) | |||||
| proc = FullSpaceToHalfSpaceProcessor("word") | |||||
| ds = proc(ds) | |||||
| self.assertTrue(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"]) | |||||
| @@ -45,7 +45,7 @@ class TestLoss(unittest.TestCase): | |||||
| # 验证squash()的正确性 | # 验证squash()的正确性 | ||||
| log = math.log | log = math.log | ||||
| loss_func = loss.Loss("nll") | |||||
| loss_func = loss.LossFromTorch("nll") | |||||
| y = tc.Tensor( | y = tc.Tensor( | ||||
| [ | [ | ||||
| @@ -129,7 +129,7 @@ class TestLoss(unittest.TestCase): | |||||
| lens = [4, 2, 1] | lens = [4, 2, 1] | ||||
| y = tc.log(y) | y = tc.log(y) | ||||
| loss_func = loss.Loss("nll", pre_pro=["unpad"]) | |||||
| loss_func = loss.LossFromTorch("nll", pre_pro=["unpad"]) | |||||
| los = loss_func(y, gy, lens=lens) | los = loss_func(y, gy, lens=lens) | ||||
| r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) | r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) | ||||
| @@ -169,7 +169,7 @@ class TestLoss(unittest.TestCase): | |||||
| lens = [2, 4, 2] | lens = [2, 4, 2] | ||||
| loss_func = loss.Loss("nll", pre_pro=["mask"]) | |||||
| loss_func = loss.LossFromTorch("nll", pre_pro=["mask"]) | |||||
| los = loss_func(y, gy, mask=mask) | los = loss_func(y, gy, mask=mask) | ||||
| los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1])) | los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1])) | ||||
| @@ -205,7 +205,7 @@ class TestLoss(unittest.TestCase): | |||||
| y = tc.log(y) | y = tc.log(y) | ||||
| loss_func = loss.Loss("nll", pre_pro=["unpad_mask"]) | |||||
| loss_func = loss.LossFromTorch("nll", pre_pro=["unpad_mask"]) | |||||
| los = loss_func(y, gy, lens=lens) | los = loss_func(y, gy, lens=lens) | ||||
| r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) | r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) | ||||
| @@ -235,7 +235,7 @@ class TestLoss(unittest.TestCase): | |||||
| lens = [4, 2, 1] | lens = [4, 2, 1] | ||||
| y = tc.log(y) | y = tc.log(y) | ||||
| loss_func = loss.Loss("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0])) | |||||
| loss_func = loss.LossFromTorch("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0])) | |||||
| loss_func.add_pre_pro("unpad_mask") | loss_func.add_pre_pro("unpad_mask") | ||||
| los = loss_func(y, gy, lens=lens) | los = loss_func(y, gy, lens=lens) | ||||
| @@ -1,8 +1,7 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
| from fastNLP.io.config_saver import ConfigSaver | |||||
| from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver | |||||
| class TestConfigSaver(unittest.TestCase): | class TestConfigSaver(unittest.TestCase): | ||||