| @@ -27,8 +27,8 @@ class Predictor(object): | |||
| self.batch_output = [] | |||
| self.pickle_path = pickle_path | |||
| self._task = task # one of ("seq_label", "text_classify") | |||
| self.index2label = load_pickle(self.pickle_path, "id2class.pkl") | |||
| self.word2index = load_pickle(self.pickle_path, "word2id.pkl") | |||
| self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl") | |||
| self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl") | |||
| def predict(self, network, data): | |||
| """Perform inference using the trained model. | |||
| @@ -82,7 +82,7 @@ class Predictor(object): | |||
| :return data_set: a DataSet instance. | |||
| """ | |||
| assert isinstance(data, list) | |||
| return create_dataset_from_lists(data, self.word2index, has_target=False) | |||
| return create_dataset_from_lists(data, self.word_vocab, has_target=False) | |||
| def prepare_output(self, data): | |||
| """Transform list of batch outputs into strings.""" | |||
| @@ -97,14 +97,14 @@ class Predictor(object): | |||
| results = [] | |||
| for batch in batch_outputs: | |||
| for example in np.array(batch): | |||
| results.append([self.index2label[int(x)] for x in example]) | |||
| results.append([self.label_vocab.to_word(int(x)) for x in example]) | |||
| return results | |||
| def _text_classify_prepare_output(self, batch_outputs): | |||
| results = [] | |||
| for batch_out in batch_outputs: | |||
| idx = np.argmax(batch_out.detach().numpy(), axis=-1) | |||
| results.extend([self.index2label[i] for i in idx]) | |||
| results.extend([self.label_vocab.to_word(i) for i in idx]) | |||
| return results | |||
| @@ -6,16 +6,7 @@ import numpy as np | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.field import TextField, LabelField | |||
| from fastNLP.core.instance import Instance | |||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||
| DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||
| '<reserved-3>', | |||
| '<reserved-4>'] # dict index = 2~4 | |||
| DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||
| DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | |||
| DEFAULT_RESERVED_LABEL[2]: 4} | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| # the first vocab in dict with the index = 5 | |||
| @@ -68,24 +59,22 @@ class BasePreprocess(object): | |||
| - "word2id.pkl", a mapping from words(tokens) to indices | |||
| - "id2word.pkl", a reversed dictionary | |||
| - "label2id.pkl", a dictionary on labels | |||
| - "id2label.pkl", a reversed dictionary on labels | |||
| These four pickle files are expected to be saved in the given pickle directory once they are constructed. | |||
| Preprocessors will check if those files are already in the directory and will reuse them in future calls. | |||
| """ | |||
| def __init__(self): | |||
| self.word2index = None | |||
| self.label2index = None | |||
| self.data_vocab = Vocabulary() | |||
| self.label_vocab = Vocabulary() | |||
| @property | |||
| def vocab_size(self): | |||
| return len(self.word2index) | |||
| return len(self.data_vocab) | |||
| @property | |||
| def num_classes(self): | |||
| return len(self.label2index) | |||
| return len(self.label_vocab) | |||
| def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): | |||
| """Main pre-processing pipeline. | |||
| @@ -102,20 +91,14 @@ class BasePreprocess(object): | |||
| """ | |||
| if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): | |||
| self.word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| self.label2index = load_pickle(pickle_path, "class2id.pkl") | |||
| self.data_vocab = load_pickle(pickle_path, "word2id.pkl") | |||
| self.label_vocab = load_pickle(pickle_path, "class2id.pkl") | |||
| else: | |||
| self.word2index, self.label2index = self.build_dict(train_dev_data) | |||
| save_pickle(self.word2index, pickle_path, "word2id.pkl") | |||
| save_pickle(self.label2index, pickle_path, "class2id.pkl") | |||
| if not pickle_exist(pickle_path, "id2word.pkl"): | |||
| index2word = self.build_reverse_dict(self.word2index) | |||
| save_pickle(index2word, pickle_path, "id2word.pkl") | |||
| self.data_vocab, self.label_vocab = self.build_dict(train_dev_data) | |||
| save_pickle(self.data_vocab, pickle_path, "word2id.pkl") | |||
| save_pickle(self.label_vocab, pickle_path, "class2id.pkl") | |||
| if not pickle_exist(pickle_path, "id2class.pkl"): | |||
| index2label = self.build_reverse_dict(self.label2index) | |||
| save_pickle(index2label, pickle_path, "id2class.pkl") | |||
| self.build_reverse_dict() | |||
| train_set = [] | |||
| dev_set = [] | |||
| @@ -125,13 +108,13 @@ class BasePreprocess(object): | |||
| split = int(len(train_dev_data) * train_dev_split) | |||
| data_dev = train_dev_data[: split] | |||
| data_train = train_dev_data[split:] | |||
| train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index) | |||
| dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index) | |||
| train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab) | |||
| dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab) | |||
| save_pickle(dev_set, pickle_path, "data_dev.pkl") | |||
| print("{} of the training data is split for validation. ".format(train_dev_split)) | |||
| else: | |||
| train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index) | |||
| train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab) | |||
| save_pickle(train_set, pickle_path, "data_train.pkl") | |||
| else: | |||
| train_set = load_pickle(pickle_path, "data_train.pkl") | |||
| @@ -143,8 +126,8 @@ class BasePreprocess(object): | |||
| # cross validation | |||
| data_cv = self.cv_split(train_dev_data, n_fold) | |||
| for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): | |||
| data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index) | |||
| data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index) | |||
| data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab) | |||
| data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab) | |||
| save_pickle( | |||
| data_train_cv, pickle_path, | |||
| "data_train_{}.pkl".format(i)) | |||
| @@ -165,7 +148,7 @@ class BasePreprocess(object): | |||
| test_set = [] | |||
| if test_data is not None: | |||
| if not pickle_exist(pickle_path, "data_test.pkl"): | |||
| test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index) | |||
| test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab) | |||
| save_pickle(test_set, pickle_path, "data_test.pkl") | |||
| # return preprocessed results | |||
| @@ -180,28 +163,15 @@ class BasePreprocess(object): | |||
| return tuple(results) | |||
| def build_dict(self, data): | |||
| label2index = DEFAULT_WORD_TO_INDEX.copy() | |||
| word2index = DEFAULT_WORD_TO_INDEX.copy() | |||
| for example in data: | |||
| for word in example[0]: | |||
| if word not in word2index: | |||
| word2index[word] = len(word2index) | |||
| label = example[1] | |||
| if isinstance(label, str): | |||
| # label is a string | |||
| if label not in label2index: | |||
| label2index[label] = len(label2index) | |||
| elif isinstance(label, list): | |||
| # label is a list of strings | |||
| for single_label in label: | |||
| if single_label not in label2index: | |||
| label2index[single_label] = len(label2index) | |||
| return word2index, label2index | |||
| def build_reverse_dict(self, word_dict): | |||
| id2word = {word_dict[w]: w for w in word_dict} | |||
| return id2word | |||
| word, label = example | |||
| self.data_vocab.update(word) | |||
| self.label_vocab.update(label) | |||
| return self.data_vocab, self.label_vocab | |||
| def build_reverse_dict(self): | |||
| self.data_vocab.build_reverse_vocab() | |||
| self.label_vocab.build_reverse_vocab() | |||
| def data_split(self, data, train_dev_split): | |||
| """Split data into train and dev set.""" | |||
| @@ -0,0 +1,124 @@ | |||
| from copy import deepcopy | |||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||
| DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||
| '<reserved-3>', | |||
| '<reserved-4>'] # dict index = 2~4 | |||
| DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||
| DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | |||
| DEFAULT_RESERVED_LABEL[2]: 4} | |||
| def isiterable(p_object): | |||
| try: | |||
| it = iter(p_object) | |||
| except TypeError: | |||
| return False | |||
| return True | |||
| class Vocabulary(object): | |||
| """Use for word and index one to one mapping | |||
| Example:: | |||
| vocab = Vocabulary() | |||
| word_list = "this is a word list".split() | |||
| vocab.update(word_list) | |||
| vocab["word"] | |||
| vocab.to_word(5) | |||
| """ | |||
| def __init__(self, need_default=True): | |||
| """ | |||
| :param bool need_default: set if the Vocabulary has default labels reserved. | |||
| """ | |||
| if need_default: | |||
| self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||
| self.padding_label = DEFAULT_PADDING_LABEL | |||
| self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||
| else: | |||
| self.word2idx = {} | |||
| self.padding_label = None | |||
| self.unknown_label = None | |||
| self.has_default = need_default | |||
| self.idx2word = None | |||
| def __len__(self): | |||
| return len(self.word2idx) | |||
| def update(self, word): | |||
| """add word or list of words into Vocabulary | |||
| :param word: a list of str or str | |||
| """ | |||
| if not isinstance(word, str) and isiterable(word): | |||
| # it's a nested list | |||
| for w in word: | |||
| self.update(w) | |||
| else: | |||
| # it's a word to be added | |||
| if word not in self.word2idx: | |||
| self.word2idx[word] = len(self) | |||
| if self.idx2word is not None: | |||
| self.idx2word = None | |||
| def __getitem__(self, w): | |||
| """To support usage like:: | |||
| vocab[w] | |||
| """ | |||
| if w in self.word2idx: | |||
| return self.word2idx[w] | |||
| else: | |||
| return self.word2idx[DEFAULT_UNKNOWN_LABEL] | |||
| def to_index(self, w): | |||
| """ like to_index(w) function, turn a word to the index | |||
| if w is not in Vocabulary, return the unknown label | |||
| :param str w: | |||
| """ | |||
| return self[w] | |||
| def unknown_idx(self): | |||
| if self.unknown_label is None: | |||
| return None | |||
| return self.word2idx[self.unknown_label] | |||
| def padding_idx(self): | |||
| if self.padding_label is None: | |||
| return None | |||
| return self.word2idx[self.padding_label] | |||
| def build_reverse_vocab(self): | |||
| """build 'index to word' dict based on 'word to index' dict | |||
| """ | |||
| self.idx2word = {self.word2idx[w] : w for w in self.word2idx} | |||
| def to_word(self, idx): | |||
| """given a word's index, return the word itself | |||
| :param int idx: | |||
| """ | |||
| if self.idx2word is None: | |||
| self.build_reverse_vocab() | |||
| return self.idx2word[idx] | |||
| def __getstate__(self): | |||
| """use to prepare data for pickle | |||
| """ | |||
| state = self.__dict__.copy() | |||
| # no need to pickle idx2word as it can be constructed from word2idx | |||
| del state['idx2word'] | |||
| return state | |||
| def __setstate__(self, state): | |||
| """use to restore state from pickle | |||
| """ | |||
| self.__dict__.update(state) | |||
| self.idx2word = None | |||
| @@ -69,7 +69,7 @@ class FastNLP(object): | |||
| :param model_dir: this directory should contain the following files: | |||
| 1. a pre-trained model | |||
| 2. a config file | |||
| 3. "id2class.pkl" | |||
| 3. "class2id.pkl" | |||
| 4. "word2id.pkl" | |||
| """ | |||
| self.model_dir = model_dir | |||
| @@ -99,10 +99,10 @@ class FastNLP(object): | |||
| print("Restore model hyper-parameters {}".format(str(model_args.data))) | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(self.model_dir, "word2id.pkl") | |||
| model_args["vocab_size"] = len(word2index) | |||
| index2label = load_pickle(self.model_dir, "id2class.pkl") | |||
| model_args["num_classes"] = len(index2label) | |||
| word_vocab = load_pickle(self.model_dir, "word2id.pkl") | |||
| model_args["vocab_size"] = len(word_vocab) | |||
| label_vocab = load_pickle(self.model_dir, "class2id.pkl") | |||
| model_args["num_classes"] = len(label_vocab) | |||
| # Construct the model | |||
| model = model_class(model_args) | |||
| @@ -32,7 +32,7 @@ def infer(): | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| test_args["vocab_size"] = len(word2index) | |||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||
| test_args["num_classes"] = len(index2label) | |||
| @@ -105,7 +105,7 @@ def test(): | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| test_args["vocab_size"] = len(word2index) | |||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||
| test_args["num_classes"] = len(index2label) | |||
| # load dev data | |||
| @@ -33,7 +33,7 @@ def infer(): | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| test_args["vocab_size"] = len(word2index) | |||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||
| test_args["num_classes"] = len(index2label) | |||
| # Define the same model | |||
| @@ -105,7 +105,7 @@ def test(): | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| test_args["vocab_size"] = len(word2index) | |||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||
| test_args["num_classes"] = len(index2label) | |||
| # load dev data | |||
| @@ -4,6 +4,7 @@ import unittest | |||
| from fastNLP.core.predictor import Predictor | |||
| from fastNLP.core.preprocess import save_pickle | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| class TestPredictor(unittest.TestCase): | |||
| @@ -23,10 +24,14 @@ class TestPredictor(unittest.TestCase): | |||
| ['a', 'b', 'c', 'd', '$'], | |||
| ['!', 'b', 'c', 'd', 'e'] | |||
| ] | |||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||
| vocab = Vocabulary() | |||
| vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||
| class_vocab = Vocabulary() | |||
| class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4} | |||
| os.system("mkdir save") | |||
| save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl") | |||
| save_pickle(class_vocab, "./save/", "class2id.pkl") | |||
| save_pickle(vocab, "./save/", "word2id.pkl") | |||
| model = SeqLabeling(model_args) | |||
| @@ -0,0 +1,31 @@ | |||
| import unittest | |||
| from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||
| class TestVocabulary(unittest.TestCase): | |||
| def test_vocab(self): | |||
| import _pickle as pickle | |||
| import os | |||
| vocab = Vocabulary() | |||
| filename = 'vocab' | |||
| vocab.update(filename) | |||
| vocab.update([filename, ['a'], [['b']], ['c']]) | |||
| idx = vocab[filename] | |||
| before_pic = (vocab.to_word(idx), vocab[filename]) | |||
| with open(filename, 'wb') as f: | |||
| pickle.dump(vocab, f) | |||
| with open(filename, 'rb') as f: | |||
| vocab = pickle.load(f) | |||
| os.remove(filename) | |||
| vocab.build_reverse_vocab() | |||
| after_pic = (vocab.to_word(idx), vocab[filename]) | |||
| TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} | |||
| TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) | |||
| TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} | |||
| self.assertEqual(before_pic, after_pic) | |||
| self.assertDictEqual(TRUE_DICT, vocab.word2idx) | |||
| self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -38,7 +38,7 @@ def infer(): | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| test_args["vocab_size"] = len(word2index) | |||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||
| test_args["num_classes"] = len(index2label) | |||
| # Define the same model | |||
| @@ -27,7 +27,7 @@ def infer(): | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| test_args["vocab_size"] = len(word2index) | |||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||
| test_args["num_classes"] = len(index2label) | |||
| # Define the same model | |||