diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index 802661ef..c83b2069 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -27,8 +27,8 @@ class Predictor(object): self.batch_output = [] self.pickle_path = pickle_path self._task = task # one of ("seq_label", "text_classify") - self.index2label = load_pickle(self.pickle_path, "id2class.pkl") - self.word2index = load_pickle(self.pickle_path, "word2id.pkl") + self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl") + self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl") def predict(self, network, data): """Perform inference using the trained model. @@ -82,7 +82,7 @@ class Predictor(object): :return data_set: a DataSet instance. """ assert isinstance(data, list) - return create_dataset_from_lists(data, self.word2index, has_target=False) + return create_dataset_from_lists(data, self.word_vocab, has_target=False) def prepare_output(self, data): """Transform list of batch outputs into strings.""" @@ -97,14 +97,14 @@ class Predictor(object): results = [] for batch in batch_outputs: for example in np.array(batch): - results.append([self.index2label[int(x)] for x in example]) + results.append([self.label_vocab.to_word(int(x)) for x in example]) return results def _text_classify_prepare_output(self, batch_outputs): results = [] for batch_out in batch_outputs: idx = np.argmax(batch_out.detach().numpy(), axis=-1) - results.extend([self.index2label[i] for i in idx]) + results.extend([self.label_vocab.to_word(i) for i in idx]) return results diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py index c76e6681..e683950d 100644 --- a/fastNLP/fastnlp.py +++ b/fastNLP/fastnlp.py @@ -69,7 +69,7 @@ class FastNLP(object): :param model_dir: this directory should contain the following files: 1. a pre-trained model 2. a config file - 3. "id2class.pkl" + 3. "class2id.pkl" 4. "word2id.pkl" """ self.model_dir = model_dir @@ -99,10 +99,10 @@ class FastNLP(object): print("Restore model hyper-parameters {}".format(str(model_args.data))) # fetch dictionary size and number of labels from pickle files - word2index = load_pickle(self.model_dir, "word2id.pkl") - model_args["vocab_size"] = len(word2index) - index2label = load_pickle(self.model_dir, "id2class.pkl") - model_args["num_classes"] = len(index2label) + word_vocab = load_pickle(self.model_dir, "word2id.pkl") + model_args["vocab_size"] = len(word_vocab) + label_vocab = load_pickle(self.model_dir, "class2id.pkl") + model_args["num_classes"] = len(label_vocab) # Construct the model model = model_class(model_args) diff --git a/reproduction/chinese_word_segment/run.py b/reproduction/chinese_word_segment/run.py index d0a22e84..0d5ae8c1 100644 --- a/reproduction/chinese_word_segment/run.py +++ b/reproduction/chinese_word_segment/run.py @@ -32,7 +32,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) @@ -105,7 +105,7 @@ def test(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # load dev data diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py index 87a9f7e8..15164130 100644 --- a/reproduction/pos_tag_model/train_pos_tag.py +++ b/reproduction/pos_tag_model/train_pos_tag.py @@ -33,7 +33,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model @@ -105,7 +105,7 @@ def test(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # load dev data diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py index c7ad65d7..411f636e 100644 --- a/test/core/test_predictor.py +++ b/test/core/test_predictor.py @@ -4,6 +4,7 @@ import unittest from fastNLP.core.predictor import Predictor from fastNLP.core.preprocess import save_pickle from fastNLP.models.sequence_modeling import SeqLabeling +from fastNLP.core.vocabulary import Vocabulary class TestPredictor(unittest.TestCase): @@ -23,10 +24,14 @@ class TestPredictor(unittest.TestCase): ['a', 'b', 'c', 'd', '$'], ['!', 'b', 'c', 'd', 'e'] ] - vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} + + vocab = Vocabulary() + vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} + class_vocab = Vocabulary() + class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4} os.system("mkdir save") - save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl") + save_pickle(class_vocab, "./save/", "class2id.pkl") save_pickle(vocab, "./save/", "word2id.pkl") model = SeqLabeling(model_args) diff --git a/test/model/seq_labeling.py b/test/model/seq_labeling.py index d7750b17..cd011c0d 100644 --- a/test/model/seq_labeling.py +++ b/test/model/seq_labeling.py @@ -38,7 +38,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model diff --git a/test/model/test_cws.py b/test/model/test_cws.py index 802d97ba..70716c3a 100644 --- a/test/model/test_cws.py +++ b/test/model/test_cws.py @@ -27,7 +27,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model