import os
import shutil
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.loss import Loss
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import Trainer
from fastNLP.models.sequence_modeling import SeqLabeling


class TestTrainer(unittest.TestCase):
    """Smoke test for ``Trainer`` driving a ``SeqLabeling`` model."""

    def _remove_pickle_dir(self):
        """Delete the ``save`` directory used as ``pickle_path`` by the test.

        ``ignore_errors=True`` mirrors ``rm -rf``: a missing directory is not
        an error. Using ``shutil`` avoids spawning a shell and works on
        platforms without ``rm``.
        """
        shutil.rmtree("save", ignore_errors=True)
        print("pickle path deleted")

    def test_case_1(self):
        """Train on a tiny in-memory data set; passes if nothing raises."""
        # Registered first so the directory is removed even when construction
        # or training fails part-way through.
        self.addCleanup(self._remove_pickle_dir)

        args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                "save_best_dev": True, "model_name": "default_model_name.pkl",
                "loss": Loss("cross_entropy"),
                "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
                "vocab_size": 10,
                "word_emb_dim": 100,
                "rnn_hidden_units": 100,
                "num_classes": 5,
                "evaluator": SeqLabelEvaluator()
                }
        trainer = Trainer(**args)

        # Each example is a pair: [input tokens, label tokens].
        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        # Token -> index maps; 10 input tokens and 5 labels, matching
        # "vocab_size" and "num_classes" in args.
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for text, label in train_data:
            x = TextField(text, False)
            x_len = LabelField(len(text), is_target=False)
            # NOTE(review): the label field is built with is_target=False,
            # as in the original test -- confirm this is what Trainer expects.
            y = TextField(label, is_target=False)
            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("truth", label_vocab)

        model = SeqLabeling(args)

        trainer.train(network=model, train_data=data_set, dev_data=data_set)
        # If this can run, everything is OK.