* refine code style * set up unit tests for Batch, DataSet, FieldArray * remove a lot of out-of-date unit tests, to get testing passedtags/v0.2.0
| @@ -64,6 +64,7 @@ class DataSet(object): | |||||
| """ | """ | ||||
| :param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field. | :param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field. | ||||
| All values must be of the same length. | |||||
| If it is a list, it must be a list of Instance objects. | If it is a list, it must be a list of Instance objects. | ||||
| """ | """ | ||||
| self.field_arrays = {} | self.field_arrays = {} | ||||
| @@ -23,8 +23,7 @@ class FieldArray(object): | |||||
| self.dtype = None | self.dtype = None | ||||
| def __repr__(self): | def __repr__(self): | ||||
| # TODO | |||||
| return '{}: {}'.format(self.name, self.content.__repr__()) | |||||
| return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||||
| def append(self, val): | def append(self, val): | ||||
| self.content.append(val) | self.content.append(val) | ||||
| @@ -11,7 +11,7 @@ class Instance(object): | |||||
| def __init__(self, **fields): | def __init__(self, **fields): | ||||
| """ | """ | ||||
| :param fields: a dict of (field name: field) | |||||
| :param fields: a dict of (str: list). | |||||
| """ | """ | ||||
| self.fields = fields | self.fields = fields | ||||
| @@ -1,5 +1,6 @@ | |||||
| import os | |||||
| import _pickle as pickle | import _pickle as pickle | ||||
| import os | |||||
| class BaseLoader(object): | class BaseLoader(object): | ||||
| @@ -1,7 +1,6 @@ | |||||
| import os | import os | ||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| from fastNLP.core.field import * | |||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| from fastNLP.io.base_loader import BaseLoader | from fastNLP.io.base_loader import BaseLoader | ||||
| @@ -87,6 +86,7 @@ class DataSetLoader(BaseLoader): | |||||
| """ | """ | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @DataSet.set_reader('read_raw') | @DataSet.set_reader('read_raw') | ||||
| class RawDataSetLoader(DataSetLoader): | class RawDataSetLoader(DataSetLoader): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -102,6 +102,7 @@ class RawDataSetLoader(DataSetLoader): | |||||
| def convert(self, data): | def convert(self, data): | ||||
| return convert_seq_dataset(data) | return convert_seq_dataset(data) | ||||
| @DataSet.set_reader('read_pos') | @DataSet.set_reader('read_pos') | ||||
| class POSDataSetLoader(DataSetLoader): | class POSDataSetLoader(DataSetLoader): | ||||
| """Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
| @@ -171,6 +172,7 @@ class POSDataSetLoader(DataSetLoader): | |||||
| """ | """ | ||||
| return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
| @DataSet.set_reader('read_tokenize') | @DataSet.set_reader('read_tokenize') | ||||
| class TokenizeDataSetLoader(DataSetLoader): | class TokenizeDataSetLoader(DataSetLoader): | ||||
| """ | """ | ||||
| @@ -230,6 +232,7 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
| def convert(self, data): | def convert(self, data): | ||||
| return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
| @DataSet.set_reader('read_class') | @DataSet.set_reader('read_class') | ||||
| class ClassDataSetLoader(DataSetLoader): | class ClassDataSetLoader(DataSetLoader): | ||||
| """Loader for classification data sets""" | """Loader for classification data sets""" | ||||
| @@ -268,6 +271,7 @@ class ClassDataSetLoader(DataSetLoader): | |||||
| def convert(self, data): | def convert(self, data): | ||||
| return convert_seq2tag_dataset(data) | return convert_seq2tag_dataset(data) | ||||
| @DataSet.set_reader('read_conll') | @DataSet.set_reader('read_conll') | ||||
| class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
| """loader for conll format files""" | """loader for conll format files""" | ||||
| @@ -309,6 +313,7 @@ class ConllLoader(DataSetLoader): | |||||
| def convert(self, data): | def convert(self, data): | ||||
| pass | pass | ||||
| @DataSet.set_reader('read_lm') | @DataSet.set_reader('read_lm') | ||||
| class LMDataSetLoader(DataSetLoader): | class LMDataSetLoader(DataSetLoader): | ||||
| """Language Model Dataset Loader | """Language Model Dataset Loader | ||||
| @@ -345,6 +350,7 @@ class LMDataSetLoader(DataSetLoader): | |||||
| def convert(self, data): | def convert(self, data): | ||||
| pass | pass | ||||
| @DataSet.set_reader('read_people_daily') | @DataSet.set_reader('read_people_daily') | ||||
| class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
| """ | """ | ||||
| @@ -1,6 +1,9 @@ | |||||
| import unittest | import unittest | ||||
| import numpy as np | |||||
| from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.dataset import construct_dataset | from fastNLP.core.dataset import construct_dataset | ||||
| from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
| @@ -10,9 +13,21 @@ class TestCase1(unittest.TestCase): | |||||
| dataset = construct_dataset( | dataset = construct_dataset( | ||||
| [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | ||||
| dataset.set_target() | dataset.set_target() | ||||
| batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False) | |||||
| batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
| cnt = 0 | cnt = 0 | ||||
| for _, _ in batch: | for _, _ in batch: | ||||
| cnt += 1 | cnt += 1 | ||||
| self.assertEqual(cnt, 10) | self.assertEqual(cnt, 10) | ||||
| def test_dataset_batching(self): | |||||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
| ds.set_input(x=True) | |||||
| ds.set_target(y=True) | |||||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
| for x, y in iter: | |||||
| self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | |||||
| self.assertEqual(len(x["x"]), 4) | |||||
| self.assertEqual(len(y["y"]), 4) | |||||
| self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | |||||
| self.assertListEqual(list(y["y"][-1]), [5, 6]) | |||||
| @@ -1,20 +1,75 @@ | |||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| from fastNLP.core.instance import Instance | |||||
| class TestDataSet(unittest.TestCase): | class TestDataSet(unittest.TestCase): | ||||
| def test_case_1(self): | |||||
| ds = DataSet() | |||||
| ds.add_field(name="xx", fields=["a", "b", "e", "d"]) | |||||
| def test_init_v1(self): | |||||
| ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | |||||
| self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||||
| self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||||
| self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||||
| self.assertTrue("xx" in ds.field_arrays) | |||||
| self.assertEqual(len(ds.field_arrays["xx"]), 4) | |||||
| self.assertEqual(ds.get_length(), 4) | |||||
| self.assertEqual(ds.get_fields(), ds.field_arrays) | |||||
| def test_init_v2(self): | |||||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
| self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||||
| self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||||
| self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||||
| try: | |||||
| ds.add_field(name="yy", fields=["x", "y", "z", "w", "f"]) | |||||
| except BaseException as e: | |||||
| self.assertTrue(isinstance(e, AssertionError)) | |||||
| def test_init_assert(self): | |||||
| with self.assertRaises(AssertionError): | |||||
| _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) | |||||
| with self.assertRaises(AssertionError): | |||||
| _ = DataSet([[1, 2, 3, 4]] * 10) | |||||
| with self.assertRaises(ValueError): | |||||
| _ = DataSet(0.00001) | |||||
| def test_append(self): | |||||
| dd = DataSet() | |||||
| for _ in range(3): | |||||
| dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) | |||||
| self.assertEqual(len(dd), 3) | |||||
| self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | |||||
| self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | |||||
| def test_add_append(self): | |||||
| dd = DataSet() | |||||
| dd.add_field("x", [[1, 2, 3]] * 10) | |||||
| dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||||
| dd.add_field("z", [[5, 6]] * 10) | |||||
| self.assertEqual(len(dd), 10) | |||||
| self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) | |||||
| self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | |||||
| self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | |||||
| def test_delete_field(self): | |||||
| dd = DataSet() | |||||
| dd.add_field("x", [[1, 2, 3]] * 10) | |||||
| dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||||
| dd.delete_field("x") | |||||
| self.assertFalse("x" in dd.field_arrays) | |||||
| self.assertTrue("y" in dd.field_arrays) | |||||
| def test_getitem(self): | |||||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
| ins_1, ins_0 = ds[0], ds[1] | |||||
| self.assertTrue(isinstance(ins_1, DataSet.Instance) and isinstance(ins_0, DataSet.Instance)) | |||||
| self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | |||||
| self.assertEqual(ins_1["y"], [5, 6]) | |||||
| self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | |||||
| self.assertEqual(ins_0["y"], [5, 6]) | |||||
| sub_ds = ds[:10] | |||||
| self.assertTrue(isinstance(sub_ds, DataSet)) | |||||
| self.assertEqual(len(sub_ds), 10) | |||||
| field = ds["x"] | |||||
| self.assertEqual(field, ds.field_arrays["x"]) | |||||
| def test_apply(self): | |||||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
| ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | |||||
| self.assertTrue("rx" in ds.field_arrays) | |||||
| self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | |||||
| @@ -1,6 +1,22 @@ | |||||
| import unittest | import unittest | ||||
| import numpy as np | |||||
| from fastNLP.core.fieldarray import FieldArray | |||||
| class TestFieldArray(unittest.TestCase): | class TestFieldArray(unittest.TestCase): | ||||
| def test(self): | def test(self): | ||||
| pass | |||||
| fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | |||||
| self.assertEqual(len(fa), 5) | |||||
| fa.append(6) | |||||
| self.assertEqual(len(fa), 6) | |||||
| self.assertEqual(fa[-1], 6) | |||||
| self.assertEqual(fa[0], 1) | |||||
| fa[-1] = 60 | |||||
| self.assertEqual(fa[-1], 60) | |||||
| self.assertEqual(fa.get(0), 1) | |||||
| self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | |||||
| self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | |||||
| @@ -1,100 +0,0 @@ | |||||
| import os | |||||
| import sys | |||||
| sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | |||||
| from fastNLP.core import metrics | |||||
| # from sklearn import metrics as skmetrics | |||||
| import unittest | |||||
| from numpy import random | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| import torch | |||||
| def generate_fake_label(low, high, size): | |||||
| return random.randint(low, high, size), random.randint(low, high, size) | |||||
| class TestEvaluator(unittest.TestCase): | |||||
| def test_a(self): | |||||
| evaluator = SeqLabelEvaluator() | |||||
| pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]] | |||||
| truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}] | |||||
| ans = evaluator(pred, truth) | |||||
| print(ans) | |||||
| def test_b(self): | |||||
| evaluator = SeqLabelEvaluator() | |||||
| pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]] | |||||
| truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}] | |||||
| ans = evaluator(pred, truth) | |||||
| print(ans) | |||||
| class TestMetrics(unittest.TestCase): | |||||
| delta = 1e-5 | |||||
| # test for binary, multiclass, multilabel | |||||
| data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | |||||
| fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | |||||
| def test_accuracy_score(self): | |||||
| for y_true, y_pred in self.fake_data: | |||||
| for normalize in [True, False]: | |||||
| for sample_weight in [None, random.rand(y_true.shape[0])]: | |||||
| test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||||
| # ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||||
| # self.assertAlmostEqual(test, ans, delta=self.delta) | |||||
| def test_recall_score(self): | |||||
| for y_true, y_pred in self.fake_data: | |||||
| # print(y_true.shape) | |||||
| labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
| test = metrics.recall_score(y_true, y_pred, labels=labels, average=None) | |||||
| if not isinstance(test, list): | |||||
| test = list(test) | |||||
| # ans = skmetrics.recall_score(y_true, y_pred,labels=labels, average=None) | |||||
| # ans = list(ans) | |||||
| # for a, b in zip(test, ans): | |||||
| # # print('{}, {}'.format(a, b)) | |||||
| # self.assertAlmostEqual(a, b, delta=self.delta) | |||||
| # test binary | |||||
| y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
| test = metrics.recall_score(y_true, y_pred) | |||||
| # ans = skmetrics.recall_score(y_true, y_pred) | |||||
| # self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
| def test_precision_score(self): | |||||
| for y_true, y_pred in self.fake_data: | |||||
| # print(y_true.shape) | |||||
| labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
| test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) | |||||
| # ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) | |||||
| # ans, test = list(ans), list(test) | |||||
| # for a, b in zip(test, ans): | |||||
| # # print('{}, {}'.format(a, b)) | |||||
| # self.assertAlmostEqual(a, b, delta=self.delta) | |||||
| # test binary | |||||
| y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
| test = metrics.precision_score(y_true, y_pred) | |||||
| # ans = skmetrics.precision_score(y_true, y_pred) | |||||
| # self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
| def test_f1_score(self): | |||||
| for y_true, y_pred in self.fake_data: | |||||
| # print(y_true.shape) | |||||
| labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
| test = metrics.f1_score(y_true, y_pred, labels=labels, average=None) | |||||
| # ans = skmetrics.f1_score(y_true, y_pred,labels=labels, average=None) | |||||
| # ans, test = list(ans), list(test) | |||||
| # for a, b in zip(test, ans): | |||||
| # # print('{}, {}'.format(a, b)) | |||||
| # self.assertAlmostEqual(a, b, delta=self.delta) | |||||
| # test binary | |||||
| y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
| test = metrics.f1_score(y_true, y_pred) | |||||
| # ans = skmetrics.f1_score(y_true, y_pred) | |||||
| # self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -1,77 +1,6 @@ | |||||
| import os | |||||
| import unittest | import unittest | ||||
| from fastNLP.core.predictor import Predictor | |||||
| from fastNLP.core.utils import save_pickle | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.io.dataset_loader import convert_seq_dataset | |||||
| from fastNLP.models.cnn_text_classification import CNNText | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| class TestPredictor(unittest.TestCase): | class TestPredictor(unittest.TestCase): | ||||
| def test_seq_label(self): | |||||
| model_args = { | |||||
| "vocab_size": 10, | |||||
| "word_emb_dim": 100, | |||||
| "rnn_hidden_units": 100, | |||||
| "num_classes": 5 | |||||
| } | |||||
| infer_data = [ | |||||
| ['a', 'b', 'c', 'd', 'e'], | |||||
| ['a', '@', 'c', 'd', 'e'], | |||||
| ['a', 'b', '#', 'd', 'e'], | |||||
| ['a', 'b', 'c', '?', 'e'], | |||||
| ['a', 'b', 'c', 'd', '$'], | |||||
| ['!', 'b', 'c', 'd', 'e'] | |||||
| ] | |||||
| vocab = Vocabulary() | |||||
| vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
| class_vocab = Vocabulary() | |||||
| class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||||
| os.system("mkdir save") | |||||
| save_pickle(class_vocab, "./save/", "label2id.pkl") | |||||
| save_pickle(vocab, "./save/", "word2id.pkl") | |||||
| model = CNNText(model_args) | |||||
| import fastNLP.core.predictor as pre | |||||
| predictor = Predictor("./save/", pre.text_classify_post_processor) | |||||
| # Load infer data | |||||
| infer_data_set = convert_seq_dataset(infer_data) | |||||
| infer_data_set.index_field("word_seq", vocab) | |||||
| results = predictor.predict(network=model, data=infer_data_set) | |||||
| self.assertTrue(isinstance(results, list)) | |||||
| self.assertGreater(len(results), 0) | |||||
| self.assertEqual(len(results), len(infer_data)) | |||||
| for res in results: | |||||
| self.assertTrue(isinstance(res, str)) | |||||
| self.assertTrue(res in class_vocab.word2idx) | |||||
| del model, predictor | |||||
| infer_data_set.set_origin_len("word_seq") | |||||
| model = SeqLabeling(model_args) | |||||
| predictor = Predictor("./save/", pre.seq_label_post_processor) | |||||
| results = predictor.predict(network=model, data=infer_data_set) | |||||
| self.assertTrue(isinstance(results, list)) | |||||
| self.assertEqual(len(results), len(infer_data)) | |||||
| for i in range(len(infer_data)): | |||||
| res = results[i] | |||||
| self.assertTrue(isinstance(res, list)) | |||||
| self.assertEqual(len(res), len(infer_data[i])) | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| class TestPredictor2(unittest.TestCase): | |||||
| def test_text_classify(self): | |||||
| # TODO | |||||
| def test(self): | |||||
| pass | pass | ||||
| @@ -1,57 +1,9 @@ | |||||
| import os | |||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField, LabelField | |||||
| from fastNLP.core.instance import Instance | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.tester import Tester | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
| pickle_path = "data_for_tests" | pickle_path = "data_for_tests" | ||||
| class TestTester(unittest.TestCase): | class TestTester(unittest.TestCase): | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| model_args = { | |||||
| "vocab_size": 10, | |||||
| "word_emb_dim": 100, | |||||
| "rnn_hidden_units": 100, | |||||
| "num_classes": 5 | |||||
| } | |||||
| valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||||
| "save_loss": True, "batch_size": 2, "pickle_path": "./save/", | |||||
| "use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()} | |||||
| train_data = [ | |||||
| [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| ] | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
| data_set = DataSet() | |||||
| for example in train_data: | |||||
| text, label = example[0], example[1] | |||||
| x = TextField(text, False) | |||||
| x_len = LabelField(len(text), is_target=False) | |||||
| y = TextField(label, is_target=True) | |||||
| ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
| data_set.append(ins) | |||||
| data_set.index_field("word_seq", vocab) | |||||
| data_set.index_field("truth", label_vocab) | |||||
| model = SeqLabeling(model_args) | |||||
| tester = Tester(**valid_args) | |||||
| tester.test(network=model, dev_data=data_set) | |||||
| # If this can run, everything is OK. | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| pass | |||||
| @@ -1,57 +1,6 @@ | |||||
| import os | |||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField, LabelField | |||||
| from fastNLP.core.instance import Instance | |||||
| from fastNLP.core.loss import Loss | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.optimizer import Optimizer | |||||
| from fastNLP.core.trainer import Trainer | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| class TestTrainer(unittest.TestCase): | class TestTrainer(unittest.TestCase): | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||||
| "save_best_dev": True, "model_name": "default_model_name.pkl", | |||||
| "loss": Loss("cross_entropy"), | |||||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||||
| "vocab_size": 10, | |||||
| "word_emb_dim": 100, | |||||
| "rnn_hidden_units": 100, | |||||
| "num_classes": 5, | |||||
| "evaluator": SeqLabelEvaluator() | |||||
| } | |||||
| trainer = Trainer(**args) | |||||
| train_data = [ | |||||
| [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| ] | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
| data_set = DataSet() | |||||
| for example in train_data: | |||||
| text, label = example[0], example[1] | |||||
| x = TextField(text, False) | |||||
| x_len = LabelField(len(text), is_target=False) | |||||
| y = TextField(label, is_target=False) | |||||
| ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
| data_set.append(ins) | |||||
| data_set.index_field("word_seq", vocab) | |||||
| data_set.index_field("truth", label_vocab) | |||||
| model = SeqLabeling(args) | |||||
| trainer.train(network=model, train_data=data_set, dev_data=data_set) | |||||
| # If this can run, everything is OK. | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| pass | |||||
| @@ -1,53 +0,0 @@ | |||||
| import configparser | |||||
| import json | |||||
| import os | |||||
| import unittest | |||||
| from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
| class TestConfigLoader(unittest.TestCase): | |||||
| def test_case_ConfigLoader(self): | |||||
| def read_section_from_config(config_path, section_name): | |||||
| dict = {} | |||||
| if not os.path.exists(config_path): | |||||
| raise FileNotFoundError("config file {} NOT found.".format(config_path)) | |||||
| cfg = configparser.ConfigParser() | |||||
| cfg.read(config_path) | |||||
| if section_name not in cfg: | |||||
| raise AttributeError("config file {} do NOT have section {}".format( | |||||
| config_path, section_name | |||||
| )) | |||||
| gen_sec = cfg[section_name] | |||||
| for s in gen_sec.keys(): | |||||
| try: | |||||
| val = json.loads(gen_sec[s]) | |||||
| dict[s] = val | |||||
| except Exception as e: | |||||
| raise AttributeError("json can NOT load {} in section {}, config file {}".format( | |||||
| s, section_name, config_path | |||||
| )) | |||||
| return dict | |||||
| test_arg = ConfigSection() | |||||
| ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||||
| section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | |||||
| for sec in section: | |||||
| if (sec not in test_arg) or (section[sec] != test_arg[sec]): | |||||
| raise AttributeError("ERROR") | |||||
| for sec in test_arg.__dict__.keys(): | |||||
| if (sec not in section) or (section[sec] != test_arg[sec]): | |||||
| raise AttributeError("ERROR") | |||||
| try: | |||||
| not_exist = test_arg["NOT EXIST"] | |||||
| except Exception as e: | |||||
| pass | |||||
| print("pass config test!") | |||||
| @@ -7,7 +7,7 @@ from fastNLP.io.config_saver import ConfigSaver | |||||
| class TestConfigSaver(unittest.TestCase): | class TestConfigSaver(unittest.TestCase): | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| config_file_dir = "test/loader/" | |||||
| config_file_dir = "test/io/" | |||||
| config_file_name = "config" | config_file_name = "config" | ||||
| config_file_path = os.path.join(config_file_dir, config_file_name) | config_file_path = os.path.join(config_file_dir, config_file_name) | ||||
| @@ -1,53 +0,0 @@ | |||||
| import unittest | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.io.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | |||||
| PeopleDailyCorpusLoader, ConllLoader | |||||
| class TestDatasetLoader(unittest.TestCase): | |||||
| def test_case_1(self): | |||||
| data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | |||||
| lines = data.split("\n") | |||||
| answer = POSDataSetLoader.parse(lines) | |||||
| truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | |||||
| self.assertListEqual(answer, truth, "POS Dataset Loader") | |||||
| def test_case_TokenizeDatasetLoader(self): | |||||
| loader = TokenizeDataSetLoader() | |||||
| filepath = "./test/data_for_tests/cws_pku_utf_8" | |||||
| data = loader.load(filepath, max_seq_len=32) | |||||
| assert len(data) > 0 | |||||
| data1 = DataSet() | |||||
| data1.read_tokenize(filepath, max_seq_len=32) | |||||
| assert len(data1) > 0 | |||||
| print("pass TokenizeDataSetLoader test!") | |||||
| def test_case_POSDatasetLoader(self): | |||||
| loader = POSDataSetLoader() | |||||
| filepath = "./test/data_for_tests/people.txt" | |||||
| data = loader.load("./test/data_for_tests/people.txt") | |||||
| datas = loader.load_lines("./test/data_for_tests/people.txt") | |||||
| data1 = DataSet().read_pos(filepath) | |||||
| assert len(data1) > 0 | |||||
| print("pass POSDataSetLoader test!") | |||||
| def test_case_LMDatasetLoader(self): | |||||
| loader = LMDataSetLoader() | |||||
| data = loader.load("./test/data_for_tests/charlm.txt") | |||||
| datas = loader.load_lines("./test/data_for_tests/charlm.txt") | |||||
| print("pass TokenizeDataSetLoader test!") | |||||
| def test_PeopleDailyCorpusLoader(self): | |||||
| loader = PeopleDailyCorpusLoader() | |||||
| _, _ = loader.load("./test/data_for_tests/people_daily_raw.txt") | |||||
| def test_ConllLoader(self): | |||||
| loader = ConllLoader() | |||||
| _ = loader.load("./test/data_for_tests/conll_example.txt") | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -1,31 +0,0 @@ | |||||
| import os | |||||
| import unittest | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.io.embed_loader import EmbedLoader | |||||
| class TestEmbedLoader(unittest.TestCase): | |||||
| glove_path = './test/data_for_tests/glove.6B.50d_test.txt' | |||||
| pkl_path = './save' | |||||
| raw_texts = ["i am a cat", | |||||
| "this is a test of new batch", | |||||
| "ha ha", | |||||
| "I am a good boy .", | |||||
| "This is the most beautiful girl ." | |||||
| ] | |||||
| texts = [text.strip().split() for text in raw_texts] | |||||
| vocab = Vocabulary() | |||||
| vocab.update(texts) | |||||
| def test1(self): | |||||
| emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||||
| self.assertTrue(emb.shape[0] == (len(self.vocab))) | |||||
| self.assertTrue(emb.shape[1] == 50) | |||||
| os.remove(self.pkl_path) | |||||
| def test2(self): | |||||
| try: | |||||
| _ = EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||||
| self.fail(msg="load dismatch embedding") | |||||
| except ValueError: | |||||
| pass | |||||
| @@ -1,150 +0,0 @@ | |||||
| import os | |||||
| import sys | |||||
| sys.path.append("..") | |||||
| import argparse | |||||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.io.dataset_loader import BaseLoader | |||||
| from fastNLP.io.model_saver import ModelSaver | |||||
| from fastNLP.io.model_loader import ModelLoader | |||||
| from fastNLP.core.tester import SeqLabelTester | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| from fastNLP.core.predictor import SeqLabelInfer | |||||
| from fastNLP.core.optimizer import Optimizer | |||||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.utils import save_pickle, load_pickle | |||||
| parser = argparse.ArgumentParser() | |||||
| parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | |||||
| parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt", | |||||
| help="path to the training data") | |||||
| parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") | |||||
| parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt", | |||||
| help="data used for inference") | |||||
| args = parser.parse_args() | |||||
| pickle_path = args.save | |||||
| model_name = args.model_name | |||||
| config_dir = args.config | |||||
| data_path = args.train | |||||
| data_infer_path = args.infer | |||||
| def infer(): | |||||
| # Load infer configuration, the same as test | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader().load_config(config_dir, {"POS_infer": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | |||||
| word_vocab = load_pickle(pickle_path, "word2id.pkl") | |||||
| label_vocab = load_pickle(pickle_path, "label2id.pkl") | |||||
| test_args["vocab_size"] = len(word_vocab) | |||||
| test_args["num_classes"] = len(label_vocab) | |||||
| print("vocabularies loaded") | |||||
| # Define the same model | |||||
| model = SeqLabeling(test_args) | |||||
| print("model defined") | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
| print("model loaded!") | |||||
| # Data Loader | |||||
| infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||||
| infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True) | |||||
| print("data set prepared") | |||||
| # Inference interface | |||||
| infer = SeqLabelInfer(pickle_path) | |||||
| results = infer.predict(model, infer_data) | |||||
| for res in results: | |||||
| print(res) | |||||
| print("Inference finished!") | |||||
| def train_and_test(): | |||||
| # Config Loader | |||||
| trainer_args = ConfigSection() | |||||
| model_args = ConfigSection() | |||||
| ConfigLoader().load_config(config_dir, { | |||||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | |||||
| data_set = SeqLabelDataSet() | |||||
| data_set.load(data_path) | |||||
| train_set, dev_set = data_set.split(0.3, shuffle=True) | |||||
| model_args["vocab_size"] = len(data_set.word_vocab) | |||||
| model_args["num_classes"] = len(data_set.label_vocab) | |||||
| save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") | |||||
| save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") | |||||
| """ | |||||
| trainer = SeqLabelTrainer( | |||||
| epochs=trainer_args["epochs"], | |||||
| batch_size=trainer_args["batch_size"], | |||||
| validate=False, | |||||
| use_cuda=trainer_args["use_cuda"], | |||||
| pickle_path=pickle_path, | |||||
| save_best_dev=trainer_args["save_best_dev"], | |||||
| model_name=model_name, | |||||
| optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), | |||||
| ) | |||||
| """ | |||||
| # Model | |||||
| model = SeqLabeling(model_args) | |||||
| model.fit(train_set, dev_set, | |||||
| epochs=trainer_args["epochs"], | |||||
| batch_size=trainer_args["batch_size"], | |||||
| validate=False, | |||||
| use_cuda=trainer_args["use_cuda"], | |||||
| pickle_path=pickle_path, | |||||
| save_best_dev=trainer_args["save_best_dev"], | |||||
| model_name=model_name, | |||||
| optimizer=Optimizer("SGD", lr=0.01, momentum=0.9)) | |||||
| # Start training | |||||
| # trainer.train(model, train_set, dev_set) | |||||
| print("Training finished!") | |||||
| # Saver | |||||
| saver = ModelSaver(os.path.join(pickle_path, model_name)) | |||||
| saver.save_pytorch(model) | |||||
| print("Model saved!") | |||||
| del model | |||||
| change_field_is_target(dev_set, "truth", True) | |||||
| # Define the same model | |||||
| model = SeqLabeling(model_args) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
| print("model loaded!") | |||||
| # Load test configuration | |||||
| tester_args = ConfigSection() | |||||
| ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
| # Tester | |||||
| tester = SeqLabelTester(batch_size=4, | |||||
| use_cuda=False, | |||||
| pickle_path=pickle_path, | |||||
| model_name="seq_label_in_test.pkl", | |||||
| evaluator=SeqLabelEvaluator() | |||||
| ) | |||||
| # Start testing with validation data | |||||
| tester.test(model, dev_set) | |||||
| print("model tested!") | |||||
| if __name__ == "__main__": | |||||
| train_and_test() | |||||
| infer() | |||||
| @@ -1,25 +0,0 @@ | |||||
| import unittest | |||||
| import numpy as np | |||||
| import torch | |||||
| from fastNLP.models.char_language_model import CharLM | |||||
| class TestCharLM(unittest.TestCase): | |||||
| def test_case_1(self): | |||||
| char_emb_dim = 50 | |||||
| word_emb_dim = 50 | |||||
| vocab_size = 1000 | |||||
| num_char = 24 | |||||
| max_word_len = 21 | |||||
| num_seq = 64 | |||||
| seq_len = 32 | |||||
| model = CharLM(char_emb_dim, word_emb_dim, vocab_size, num_char) | |||||
| x = torch.from_numpy(np.random.randint(0, num_char, size=(num_seq, seq_len, max_word_len + 2))) | |||||
| self.assertEqual(tuple(x.shape), (num_seq, seq_len, max_word_len + 2)) | |||||
| y = model(x) | |||||
| self.assertEqual(tuple(y.shape), (num_seq * seq_len, vocab_size)) | |||||
| @@ -1,111 +0,0 @@ | |||||
| import os | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.predictor import Predictor | |||||
| from fastNLP.core.tester import Tester | |||||
| from fastNLP.core.trainer import Trainer | |||||
| from fastNLP.core.utils import save_pickle, load_pickle | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.io.dataset_loader import TokenizeDataSetLoader, RawDataSetLoader | |||||
| from fastNLP.io.model_loader import ModelLoader | |||||
| from fastNLP.io.model_saver import ModelSaver | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| data_name = "pku_training.utf8" | |||||
| cws_data_path = "./test/data_for_tests/cws_pku_utf_8" | |||||
| pickle_path = "./save/" | |||||
| data_infer_path = "./test/data_for_tests/people_infer.txt" | |||||
| config_path = "./test/data_for_tests/config" | |||||
| def infer(): | |||||
| # Load infer configuration, the same as test | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | |||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
| test_args["vocab_size"] = len(word2index) | |||||
| index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | |||||
| # Define the same model | |||||
| model = SeqLabeling(test_args) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
| print("model loaded!") | |||||
| # Load infer data | |||||
| infer_data = RawDataSetLoader().load(data_infer_path) | |||||
| infer_data.index_field("word_seq", word2index) | |||||
| infer_data.set_origin_len("word_seq") | |||||
| # inference | |||||
| infer = Predictor(pickle_path) | |||||
| results = infer.predict(model, infer_data) | |||||
| print(results) | |||||
| def train_test(): | |||||
| # Config Loader | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader().load_config(config_path, {"POS_infer": train_args}) | |||||
| # define dataset | |||||
| data_train = TokenizeDataSetLoader().load(cws_data_path) | |||||
| word_vocab = Vocabulary() | |||||
| label_vocab = Vocabulary() | |||||
| data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||||
| data_train.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) | |||||
| data_train.set_origin_len("word_seq") | |||||
| data_train.rename_field("label_seq", "truth").set_target(truth=False) | |||||
| train_args["vocab_size"] = len(word_vocab) | |||||
| train_args["num_classes"] = len(label_vocab) | |||||
| save_pickle(word_vocab, pickle_path, "word2id.pkl") | |||||
| save_pickle(label_vocab, pickle_path, "label2id.pkl") | |||||
| # Trainer | |||||
| trainer = Trainer(**train_args.data) | |||||
| # Model | |||||
| model = SeqLabeling(train_args) | |||||
| # Start training | |||||
| trainer.train(model, data_train) | |||||
| # Saver | |||||
| saver = ModelSaver("./save/saved_model.pkl") | |||||
| saver.save_pytorch(model) | |||||
| del model, trainer | |||||
| # Define the same model | |||||
| model = SeqLabeling(train_args) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
| # Load test configuration | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||||
| test_args["evaluator"] = SeqLabelEvaluator() | |||||
| # Tester | |||||
| tester = Tester(**test_args.data) | |||||
| # Start testing | |||||
| data_train.set_target(truth=True) | |||||
| tester.test(model, data_train) | |||||
| def test(): | |||||
| os.makedirs("save", exist_ok=True) | |||||
| train_test() | |||||
| infer() | |||||
| os.system("rm -rf save") | |||||
| if __name__ == "__main__": | |||||
| train_test() | |||||
| infer() | |||||
| @@ -1,90 +0,0 @@ | |||||
| import os | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.optimizer import Optimizer | |||||
| from fastNLP.core.tester import Tester | |||||
| from fastNLP.core.trainer import Trainer | |||||
| from fastNLP.core.utils import save_pickle | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.io.dataset_loader import TokenizeDataSetLoader | |||||
| from fastNLP.io.model_loader import ModelLoader | |||||
| from fastNLP.io.model_saver import ModelSaver | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| pickle_path = "./seq_label/" | |||||
| model_name = "seq_label_model.pkl" | |||||
| config_dir = "../data_for_tests/config" | |||||
| data_path = "../data_for_tests/people.txt" | |||||
| data_infer_path = "../data_for_tests/people_infer.txt" | |||||
| def test_training(): | |||||
| # Config Loader | |||||
| trainer_args = ConfigSection() | |||||
| model_args = ConfigSection() | |||||
| ConfigLoader().load_config(config_dir, { | |||||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | |||||
| data_set = TokenizeDataSetLoader().load(data_path) | |||||
| word_vocab = Vocabulary() | |||||
| label_vocab = Vocabulary() | |||||
| data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||||
| data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) | |||||
| data_set.set_origin_len("word_seq") | |||||
| data_set.rename_field("label_seq", "truth").set_target(truth=False) | |||||
| data_train, data_dev = data_set.split(0.3, shuffle=True) | |||||
| model_args["vocab_size"] = len(word_vocab) | |||||
| model_args["num_classes"] = len(label_vocab) | |||||
| save_pickle(word_vocab, pickle_path, "word2id.pkl") | |||||
| save_pickle(label_vocab, pickle_path, "label2id.pkl") | |||||
| trainer = Trainer( | |||||
| epochs=trainer_args["epochs"], | |||||
| batch_size=trainer_args["batch_size"], | |||||
| validate=False, | |||||
| use_cuda=False, | |||||
| pickle_path=pickle_path, | |||||
| save_best_dev=trainer_args["save_best_dev"], | |||||
| model_name=model_name, | |||||
| optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), | |||||
| ) | |||||
| # Model | |||||
| model = SeqLabeling(model_args) | |||||
| # Start training | |||||
| trainer.train(model, data_train, data_dev) | |||||
| # Saver | |||||
| saver = ModelSaver(os.path.join(pickle_path, model_name)) | |||||
| saver.save_pytorch(model) | |||||
| del model, trainer | |||||
| # Define the same model | |||||
| model = SeqLabeling(model_args) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
| # Load test configuration | |||||
| tester_args = ConfigSection() | |||||
| ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
| # Tester | |||||
| tester = Tester(batch_size=4, | |||||
| use_cuda=False, | |||||
| pickle_path=pickle_path, | |||||
| model_name="seq_label_in_test.pkl", | |||||
| evaluator=SeqLabelEvaluator() | |||||
| ) | |||||
| # Start testing with validation data | |||||
| data_dev.set_target(truth=True) | |||||
| tester.test(model, data_dev) | |||||
| if __name__ == "__main__": | |||||
| test_training() | |||||
| @@ -1,107 +0,0 @@ | |||||
| # Python: 3.5 | |||||
| # encoding: utf-8 | |||||
| import argparse | |||||
| import os | |||||
| import sys | |||||
| sys.path.append("..") | |||||
| from fastNLP.core.predictor import ClassificationInfer | |||||
| from fastNLP.core.trainer import ClassificationTrainer | |||||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.io.dataset_loader import ClassDataSetLoader | |||||
| from fastNLP.io.model_loader import ModelLoader | |||||
| from fastNLP.models.cnn_text_classification import CNNText | |||||
| from fastNLP.io.model_saver import ModelSaver | |||||
| from fastNLP.core.optimizer import Optimizer | |||||
| from fastNLP.core.loss import Loss | |||||
| from fastNLP.core.dataset import TextClassifyDataSet | |||||
| from fastNLP.core.utils import save_pickle, load_pickle | |||||
| parser = argparse.ArgumentParser() | |||||
| parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | |||||
| parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt", | |||||
| help="path to the training data") | |||||
| parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") | |||||
| args = parser.parse_args() | |||||
| save_dir = args.save | |||||
| train_data_dir = args.train | |||||
| model_name = args.model_name | |||||
| config_dir = args.config | |||||
| def infer(): | |||||
| # load dataset | |||||
| print("Loading data...") | |||||
| word_vocab = load_pickle(save_dir, "word2id.pkl") | |||||
| label_vocab = load_pickle(save_dir, "label2id.pkl") | |||||
| print("vocabulary size:", len(word_vocab)) | |||||
| print("number of classes:", len(label_vocab)) | |||||
| infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||||
| infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}) | |||||
| model_args = ConfigSection() | |||||
| model_args["vocab_size"] = len(word_vocab) | |||||
| model_args["num_classes"] = len(label_vocab) | |||||
| ConfigLoader.load_config(config_dir, {"text_class_model": model_args}) | |||||
| # construct model | |||||
| print("Building model...") | |||||
| cnn = CNNText(model_args) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name)) | |||||
| print("model loaded!") | |||||
| infer = ClassificationInfer(pickle_path=save_dir) | |||||
| results = infer.predict(cnn, infer_data) | |||||
| print(results) | |||||
| def train(): | |||||
| train_args, model_args = ConfigSection(), ConfigSection() | |||||
| ConfigLoader.load_config(config_dir, {"text_class": train_args}) | |||||
| # load dataset | |||||
| print("Loading data...") | |||||
| data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||||
| data.load(train_data_dir) | |||||
| print("vocabulary size:", len(data.word_vocab)) | |||||
| print("number of classes:", len(data.label_vocab)) | |||||
| save_pickle(data.word_vocab, save_dir, "word2id.pkl") | |||||
| save_pickle(data.label_vocab, save_dir, "label2id.pkl") | |||||
| model_args["num_classes"] = len(data.label_vocab) | |||||
| model_args["vocab_size"] = len(data.word_vocab) | |||||
| # construct model | |||||
| print("Building model...") | |||||
| model = CNNText(model_args) | |||||
| # train | |||||
| print("Training...") | |||||
| trainer = ClassificationTrainer(epochs=train_args["epochs"], | |||||
| batch_size=train_args["batch_size"], | |||||
| validate=train_args["validate"], | |||||
| use_cuda=train_args["use_cuda"], | |||||
| pickle_path=save_dir, | |||||
| save_best_dev=train_args["save_best_dev"], | |||||
| model_name=model_name, | |||||
| loss=Loss("cross_entropy"), | |||||
| optimizer=Optimizer("SGD", lr=0.001, momentum=0.9)) | |||||
| trainer.train(model, data) | |||||
| print("Training finished!") | |||||
| saver = ModelSaver(os.path.join(save_dir, model_name)) | |||||
| saver.save_pytorch(model) | |||||
| print("Model saved!") | |||||
| if __name__ == "__main__": | |||||
| train() | |||||
| infer() | |||||
| @@ -14,7 +14,7 @@ class TestGroupNorm(unittest.TestCase): | |||||
| class TestLayerNormalization(unittest.TestCase): | class TestLayerNormalization(unittest.TestCase): | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| ln = LayerNormalization(d_hid=5, eps=2e-3) | |||||
| ln = LayerNormalization(layer_size=5, eps=2e-3) | |||||
| x = torch.randn((20, 50, 5)) | x = torch.randn((20, 50, 5)) | ||||
| y = ln(x) | y = ln(x) | ||||