- DataSet's __init__ takes a function as its argument, rather than a class object - Preprocessor is about to be removed. Don't use it anymore. - Remove cross_validate in Trainer, because it is rarely used and weird - Loader.load is expected to be a static method - Delete unused modules (Transpose, WordDropout) in other_modules.py - Add more tests - Delete extra sample data  tags/v0.1.0^2
| @@ -70,18 +70,18 @@ class DataSet(list): | |||
| """ | |||
| def __init__(self, name="", instances=None, loader=None): | |||
| def __init__(self, name="", instances=None, load_func=None): | |||
| """ | |||
| :param name: str, the name of the dataset. (default: "") | |||
| :param instances: list of Instance objects. (default: None) | |||
| :param load_func: a function that takes the dataset path (string) as input and returns multi-level lists. | |||
| """ | |||
| list.__init__([]) | |||
| self.name = name | |||
| if instances is not None: | |||
| self.extend(instances) | |||
| self.dataset_loader = loader | |||
| self.data_set_load_func = load_func | |||
| def index_all(self, vocab): | |||
| for ins in self: | |||
| @@ -117,15 +117,15 @@ class DataSet(list): | |||
| return lengths | |||
| def convert(self, data): | |||
| """Convert lists of strings into Instances with Fields""" | |||
| """Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training.""" | |||
| raise NotImplementedError | |||
| def convert_with_vocabs(self, data, vocabs): | |||
| """Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting.""" | |||
| """Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing.""" | |||
| raise NotImplementedError | |||
| def convert_for_infer(self, data, vocabs): | |||
| """Convert lists of strings into Instances with Fields.""" | |||
| """Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting.""" | |||
| def load(self, data_path, vocabs=None, infer=False): | |||
| """Load data from the given files. | |||
| @@ -135,7 +135,7 @@ class DataSet(list): | |||
| :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed. | |||
| """ | |||
| raw_data = self.dataset_loader.load(data_path) | |||
| raw_data = self.data_set_load_func(data_path) | |||
| if infer is True: | |||
| self.convert_for_infer(raw_data, vocabs) | |||
| else: | |||
| @@ -145,7 +145,7 @@ class DataSet(list): | |||
| self.convert(raw_data) | |||
| def load_raw(self, raw_data, vocabs): | |||
| """ | |||
| """Load raw data without loader. Used in FastNLP class. | |||
| :param raw_data: | |||
| :param vocabs: | |||
| @@ -174,8 +174,8 @@ class DataSet(list): | |||
| class SeqLabelDataSet(DataSet): | |||
| def __init__(self, instances=None, loader=POSDataSetLoader()): | |||
| super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader) | |||
| def __init__(self, instances=None, load_func=POSDataSetLoader().load): | |||
| super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func) | |||
| self.word_vocab = Vocabulary() | |||
| self.label_vocab = Vocabulary() | |||
| @@ -231,8 +231,8 @@ class SeqLabelDataSet(DataSet): | |||
| class TextClassifyDataSet(DataSet): | |||
| def __init__(self, instances=None, loader=ClassDataSetLoader()): | |||
| super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader) | |||
| def __init__(self, instances=None, load_func=ClassDataSetLoader().load): | |||
| super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func) | |||
| self.word_vocab = Vocabulary() | |||
| self.label_vocab = Vocabulary(need_default=False) | |||
| @@ -285,10 +285,3 @@ def change_field_is_target(data_set, field_name, new_target): | |||
| for inst in data_set: | |||
| inst.fields[field_name].is_target = new_target | |||
| if __name__ == "__main__": | |||
| data_set = SeqLabelDataSet() | |||
| data_set.load("../../test/data_for_tests/people.txt") | |||
| a, b = data_set.split(0.3) | |||
| print(type(data_set), type(a), type(b)) | |||
| print(len(data_set), len(a), len(b)) | |||
| @@ -78,6 +78,7 @@ class Preprocessor(object): | |||
| is only available when label_is_seq is True. Default: False. | |||
| :param add_char_field: bool, whether to add character representations to all TextFields. Default: False. | |||
| """ | |||
| print("Preprocessor is about to deprecate. Please use DataSet class.") | |||
| self.data_vocab = Vocabulary() | |||
| if label_is_seq is True: | |||
| if share_vocab is True: | |||
| @@ -307,11 +308,3 @@ class ClassPreprocess(Preprocessor): | |||
| print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.") | |||
| super(ClassPreprocess, self).__init__() | |||
| if __name__ == "__main__": | |||
| p = Preprocessor() | |||
| train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"], | |||
| [["You", "are", "pretty", "."], "1"] | |||
| ] | |||
| training_set = p.run(train_dev_data) | |||
| print(training_set) | |||
| @@ -1,4 +1,3 @@ | |||
| import copy | |||
| import os | |||
| import time | |||
| from datetime import timedelta | |||
| @@ -178,31 +177,6 @@ class Trainer(object): | |||
| logger.info(print_output) | |||
| step += 1 | |||
| def cross_validate(self, network, train_data_cv, dev_data_cv): | |||
| """Training with cross validation. | |||
| :param network: the model | |||
| :param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?] | |||
| :param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?] | |||
| """ | |||
| if len(train_data_cv) != len(dev_data_cv): | |||
| logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv), | |||
| len(dev_data_cv))) | |||
| raise RuntimeError("the number of folds in train and dev data unequals") | |||
| if self.validate is False: | |||
| logger.warn("Cross validation requires self.validate to be True. Please turn it on. ") | |||
| print("[warning] Cross validation requires self.validate to be True. Please turn it on. ") | |||
| self.validate = True | |||
| n_fold = len(train_data_cv) | |||
| logger.info("perform {} folds cross validation.".format(n_fold)) | |||
| for i in range(n_fold): | |||
| print("CV:", i) | |||
| logger.info("running the {} of {} folds cross validation".format(i + 1, n_fold)) | |||
| network_copy = copy.deepcopy(network) | |||
| self.train(network_copy, train_data_cv[i], dev_data_cv[i]) | |||
| def mode(self, model, is_test=False): | |||
| """Train mode or Test mode. This is for PyTorch currently. | |||
| @@ -1,11 +1,10 @@ | |||
| import os | |||
| from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet | |||
| from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | |||
| from fastNLP.core.preprocess import load_pickle | |||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.loader.model_loader import ModelLoader | |||
| from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet | |||
| """ | |||
| mapping from model name to [URL, file_name.class_name, model_pickle_name] | |||
| @@ -73,7 +72,7 @@ class FastNLP(object): | |||
| :param model_dir: this directory should contain the following files: | |||
| 1. a trained model | |||
| 2. a config file, which is a fastNLP's configuration. | |||
| 3. a Vocab file, which is a pickle object of a Vocab instance. | |||
| 3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs. | |||
| """ | |||
| self.model_dir = model_dir | |||
| self.model = None | |||
| @@ -192,7 +191,7 @@ class FastNLP(object): | |||
| def _load(self, model_dir, model_name): | |||
| # To do | |||
| return 0 | |||
| def _download(self, model_name, url): | |||
| @@ -202,7 +201,7 @@ class FastNLP(object): | |||
| :param url: | |||
| """ | |||
| print("Downloading {} from {}".format(model_name, url)) | |||
| # To do | |||
| # TODO: download model via url | |||
| def model_exist(self, model_dir): | |||
| """ | |||
| @@ -3,12 +3,14 @@ class BaseLoader(object): | |||
| def __init__(self): | |||
| super(BaseLoader, self).__init__() | |||
| def load_lines(self, data_path): | |||
| @staticmethod | |||
| def load_lines(data_path): | |||
| with open(data_path, "r", encoding="utf=8") as f: | |||
| text = f.readlines() | |||
| return [line.strip() for line in text] | |||
| def load(self, data_path): | |||
| @staticmethod | |||
| def load(data_path): | |||
| with open(data_path, "r", encoding="utf-8") as f: | |||
| text = f.readlines() | |||
| return [[word for word in sent.strip()] for sent in text] | |||
| @@ -84,7 +84,8 @@ class TokenizeDataSetLoader(DataSetLoader): | |||
| def __init__(self): | |||
| super(TokenizeDataSetLoader, self).__init__() | |||
| def load(self, data_path, max_seq_len=32): | |||
| @staticmethod | |||
| def load(data_path, max_seq_len=32): | |||
| """ | |||
| load pku dataset for Chinese word segmentation | |||
| CWS (Chinese Word Segmentation) pku training dataset format: | |||
| @@ -196,30 +196,3 @@ class BiAffine(nn.Module): | |||
| output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2) | |||
| return output | |||
| class Transpose(nn.Module): | |||
| def __init__(self, x, y): | |||
| super(Transpose, self).__init__() | |||
| self.x = x | |||
| self.y = y | |||
| def forward(self, x): | |||
| return x.transpose(self.x, self.y) | |||
| class WordDropout(nn.Module): | |||
| def __init__(self, dropout_rate, drop_to_token): | |||
| super(WordDropout, self).__init__() | |||
| self.dropout_rate = dropout_rate | |||
| self.drop_to_token = drop_to_token | |||
| def forward(self, word_idx): | |||
| if not self.training: | |||
| return word_idx | |||
| drop_mask = torch.rand(word_idx.shape) < self.dropout_rate | |||
| if word_idx.device.type == 'cuda': | |||
| drop_mask = drop_mask.cuda() | |||
| drop_mask = drop_mask.long() | |||
| output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx | |||
| return output | |||
| @@ -104,7 +104,8 @@ class ConfigSaver(object): | |||
| :return: | |||
| """ | |||
| section_file = self._get_section(section_name) | |||
| if len(section_file.__dict__.keys()) == 0:#the section not in file before | |||
| if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||
| # append this section to config file | |||
| with open(self.file_path, 'a') as f: | |||
| f.write('[' + section_name + ']\n') | |||
| for k in section.__dict__.keys(): | |||
| @@ -114,9 +115,11 @@ class ConfigSaver(object): | |||
| else: | |||
| f.write(str(section[k]) + '\n\n') | |||
| else: | |||
| # the section exists | |||
| change_file = False | |||
| for k in section.__dict__.keys(): | |||
| if k not in section_file: | |||
| # find a new key in this section | |||
| change_file = True | |||
| break | |||
| if section_file[k] != section[k]: | |||
| @@ -0,0 +1,243 @@ | |||
| import unittest | |||
| from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet | |||
| from fastNLP.core.dataset import create_dataset_from_lists | |||
| class TestDataSet(unittest.TestCase): | |||
| labeled_data_list = [ | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| ] | |||
| unlabeled_data_list = [ | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"] | |||
| ] | |||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||
| label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||
| def test_case_1(self): | |||
| data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True, | |||
| label_vocab=self.label_vocab) | |||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||
| self.assertTrue(len(data_set) > 0) | |||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||
| self.assertTrue("word_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||
| self.assertTrue("label_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1]) | |||
| self.assertEqual(data_set[0].fields["label_seq"]._index, | |||
| [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | |||
| def test_case_2(self): | |||
| data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False) | |||
| self.assertEqual(len(data_set), len(self.unlabeled_data_list)) | |||
| self.assertTrue(len(data_set) > 0) | |||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||
| self.assertTrue("word_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0]) | |||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||
| [self.word_vocab[c] for c in self.unlabeled_data_list[0]]) | |||
| class TestDataSetConvertion(unittest.TestCase): | |||
| labeled_data_list = [ | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| ] | |||
| unlabeled_data_list = [ | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"] | |||
| ] | |||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||
| label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||
| def test_case_1(self): | |||
| def loader(path): | |||
| labeled_data_list = [ | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| ] | |||
| return labeled_data_list | |||
| data_set = SeqLabelDataSet(load_func=loader) | |||
| data_set.load("any_path") | |||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||
| self.assertTrue(len(data_set) > 0) | |||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||
| self.assertTrue("word_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||
| self.assertTrue("truth" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "_index")) | |||
| self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1]) | |||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||
| def test_case_2(self): | |||
| def loader(path): | |||
| unlabeled_data_list = [ | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"] | |||
| ] | |||
| return unlabeled_data_list | |||
| data_set = SeqLabelDataSet(load_func=loader) | |||
| data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True) | |||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||
| self.assertTrue(len(data_set) > 0) | |||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||
| self.assertTrue("word_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||
| def test_case_3(self): | |||
| def loader(path): | |||
| labeled_data_list = [ | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
| ] | |||
| return labeled_data_list | |||
| data_set = SeqLabelDataSet(load_func=loader) | |||
| data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab}) | |||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||
| self.assertTrue(len(data_set) > 0) | |||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||
| self.assertTrue("word_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||
| self.assertTrue("truth" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "_index")) | |||
| self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1]) | |||
| self.assertEqual(data_set[0].fields["truth"]._index, | |||
| [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | |||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||
| class TestDataSetConvertionHHH(unittest.TestCase): | |||
| labeled_data_list = [ | |||
| [["a", "b", "e", "d"], "A"], | |||
| [["a", "b", "e", "d"], "C"], | |||
| [["a", "b", "e", "d"], "B"], | |||
| ] | |||
| unlabeled_data_list = [ | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"] | |||
| ] | |||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||
| label_vocab = {"A": 1, "B": 2, "C": 3} | |||
| def test_case_1(self): | |||
| def loader(path): | |||
| labeled_data_list = [ | |||
| [["a", "b", "e", "d"], "A"], | |||
| [["a", "b", "e", "d"], "C"], | |||
| [["a", "b", "e", "d"], "B"], | |||
| ] | |||
| return labeled_data_list | |||
| data_set = TextClassifyDataSet(load_func=loader) | |||
| data_set.load("xxx") | |||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||
| self.assertTrue(len(data_set) > 0) | |||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||
| self.assertTrue("word_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||
| self.assertTrue("label" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["label"], "label")) | |||
| self.assertTrue(hasattr(data_set[0].fields["label"], "_index")) | |||
| self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1]) | |||
| def test_case_2(self): | |||
| def loader(path): | |||
| labeled_data_list = [ | |||
| [["a", "b", "e", "d"], "A"], | |||
| [["a", "b", "e", "d"], "C"], | |||
| [["a", "b", "e", "d"], "B"], | |||
| ] | |||
| return labeled_data_list | |||
| data_set = TextClassifyDataSet(load_func=loader) | |||
| data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab}) | |||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||
| self.assertTrue(len(data_set) > 0) | |||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||
| self.assertTrue("word_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||
| self.assertTrue("label" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["label"], "label")) | |||
| self.assertTrue(hasattr(data_set[0].fields["label"], "_index")) | |||
| self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1]) | |||
| self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]]) | |||
| def test_case_3(self): | |||
| def loader(path): | |||
| unlabeled_data_list = [ | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"], | |||
| ["a", "b", "e", "d"] | |||
| ] | |||
| return unlabeled_data_list | |||
| data_set = TextClassifyDataSet(load_func=loader) | |||
| data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True) | |||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||
| self.assertTrue(len(data_set) > 0) | |||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||
| self.assertTrue("word_seq" in data_set[0].fields) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||
| @@ -1,13 +1,13 @@ | |||
| import os | |||
| import unittest | |||
| from fastNLP.core.predictor import Predictor | |||
| from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet | |||
| from fastNLP.core.predictor import Predictor | |||
| from fastNLP.core.preprocess import save_pickle | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.loader.base_loader import BaseLoader | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| from fastNLP.models.cnn_text_classification import CNNText | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| class TestPredictor(unittest.TestCase): | |||
| @@ -42,7 +42,7 @@ class TestPredictor(unittest.TestCase): | |||
| predictor = Predictor("./save/", pre.text_classify_post_processor) | |||
| # Load infer data | |||
| infer_data_set = TextClassifyDataSet(loader=BaseLoader()) | |||
| infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load) | |||
| infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) | |||
| results = predictor.predict(network=model, data=infer_data_set) | |||
| @@ -59,7 +59,7 @@ class TestPredictor(unittest.TestCase): | |||
| model = SeqLabeling(model_args) | |||
| predictor = Predictor("./save/", pre.seq_label_post_processor) | |||
| infer_data_set = SeqLabelDataSet(loader=BaseLoader()) | |||
| infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load) | |||
| infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) | |||
| results = predictor.predict(network=model, data=infer_data_set) | |||
| @@ -53,7 +53,7 @@ def infer(): | |||
| print("model loaded!") | |||
| # Data Loader | |||
| infer_data = SeqLabelDataSet(loader=BaseLoader()) | |||
| infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||
| infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True) | |||
| print("data set prepared") | |||
| @@ -37,7 +37,7 @@ def infer(): | |||
| print("model loaded!") | |||
| # Load infer data | |||
| infer_data = SeqLabelDataSet(loader=BaseLoader()) | |||
| infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||
| infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) | |||
| # inference | |||
| @@ -52,7 +52,7 @@ def train_test(): | |||
| ConfigLoader().load_config(config_path, {"POS_infer": train_args}) | |||
| # define dataset | |||
| data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader()) | |||
| data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load) | |||
| data_train.load(cws_data_path) | |||
| train_args["vocab_size"] = len(data_train.word_vocab) | |||
| train_args["num_classes"] = len(data_train.label_vocab) | |||
| @@ -40,7 +40,7 @@ def infer(): | |||
| print("vocabulary size:", len(word_vocab)) | |||
| print("number of classes:", len(label_vocab)) | |||
| infer_data = TextClassifyDataSet(loader=ClassDataSetLoader()) | |||
| infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||
| infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}) | |||
| model_args = ConfigSection() | |||
| @@ -67,7 +67,7 @@ def train(): | |||
| # load dataset | |||
| print("Loading data...") | |||
| data = TextClassifyDataSet(loader=ClassDataSetLoader()) | |||
| data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||
| data.load(train_data_dir) | |||
| print("vocabulary size:", len(data.word_vocab)) | |||
| @@ -2,7 +2,7 @@ import unittest | |||
| import torch | |||
| from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear | |||
| from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine | |||
| class TestGroupNorm(unittest.TestCase): | |||
| @@ -27,3 +27,25 @@ class TestBiLinear(unittest.TestCase): | |||
| y = bl(x_left, x_right) | |||
| print(bl) | |||
| bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True) | |||
| class TestBiAffine(unittest.TestCase): | |||
| def test_case_1(self): | |||
| batch_size = 16 | |||
| encoder_length = 21 | |||
| decoder_length = 32 | |||
| layer = BiAffine(10, 10, 25, biaffine=True) | |||
| decoder_input = torch.randn((batch_size, encoder_length, 10)) | |||
| encoder_input = torch.randn((batch_size, decoder_length, 10)) | |||
| y = layer(decoder_input, encoder_input) | |||
| self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length)) | |||
| def test_case_2(self): | |||
| batch_size = 16 | |||
| encoder_length = 21 | |||
| decoder_length = 32 | |||
| layer = BiAffine(10, 10, 25, biaffine=False) | |||
| decoder_input = torch.randn((batch_size, encoder_length, 10)) | |||
| encoder_input = torch.randn((batch_size, decoder_length, 10)) | |||
| y = layer(decoder_input, encoder_input) | |||
| self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1)) | |||
| @@ -1,8 +1,5 @@ | |||
| import os | |||
| import unittest | |||
| import configparser | |||
| import json | |||
| from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||
| from fastNLP.saver.config_saver import ConfigSaver | |||
| @@ -10,7 +7,7 @@ from fastNLP.saver.config_saver import ConfigSaver | |||
| class TestConfigSaver(unittest.TestCase): | |||
| def test_case_1(self): | |||
| config_file_dir = "./test/loader/" | |||
| config_file_dir = "test/loader/" | |||
| config_file_name = "config" | |||
| config_file_path = os.path.join(config_file_dir, config_file_name) | |||
| @@ -80,3 +77,37 @@ class TestConfigSaver(unittest.TestCase): | |||
| tmp_config_saver = ConfigSaver("file-NOT-exist") | |||
| except Exception as e: | |||
| pass | |||
| def test_case_2(self): | |||
| config = "[section_A]\n[section_B]\n" | |||
| with open("./test.cfg", "w", encoding="utf-8") as f: | |||
| f.write(config) | |||
| saver = ConfigSaver("./test.cfg") | |||
| section = ConfigSection() | |||
| section["doubles"] = 0.8 | |||
| section["tt"] = [1, 2, 3] | |||
| section["test"] = 105 | |||
| section["str"] = "this is a str" | |||
| saver.save_config_file("section_A", section) | |||
| os.system("rm ./test.cfg") | |||
| def test_case_3(self): | |||
| config = "[section_A]\ndoubles = 0.9\ntt = [1, 2, 3]\n[section_B]\n" | |||
| with open("./test.cfg", "w", encoding="utf-8") as f: | |||
| f.write(config) | |||
| saver = ConfigSaver("./test.cfg") | |||
| section = ConfigSection() | |||
| section["doubles"] = 0.8 | |||
| section["tt"] = [1, 2, 3] | |||
| section["test"] = 105 | |||
| section["str"] = "this is a str" | |||
| saver.save_config_file("section_A", section) | |||
| os.system("rm ./test.cfg") | |||