| @@ -134,7 +134,10 @@ class BasePreprocess(object): | |||
| results.append(data_dev) | |||
| if test_data: | |||
| results.append(data_test) | |||
| return tuple(results) | |||
| if len(results) == 1: | |||
| return results[0] | |||
| else: | |||
| return tuple(results) | |||
| def build_dict(self, data): | |||
| raise NotImplementedError | |||
| @@ -282,7 +285,8 @@ class ClassPreprocess(BasePreprocess): | |||
| data_index = [] | |||
| for example in data: | |||
| word_list = [] | |||
| for word, label in zip(example[0]): | |||
| # example[0] is the word list, example[1] is the single label | |||
| for word in example[0]: | |||
| word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | |||
| label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]) | |||
| data_index.append([word_list, label_index]) | |||
| @@ -95,10 +95,10 @@ num_classes = 27 | |||
| [text_class] | |||
| epochs = 1 | |||
| batch_size = 10 | |||
| pickle_path = "./data_for_tests/" | |||
| pickle_path = "./save_path/" | |||
| validate = false | |||
| save_best_dev = false | |||
| model_saved_path = "./data_for_tests/" | |||
| model_saved_path = "./save_path/" | |||
| use_cuda = true | |||
| learn_rate = 1e-3 | |||
| momentum = 0.9 | |||
| @@ -14,7 +14,7 @@ from fastNLP.core.predictor import SeqLabelInfer | |||
| data_name = "people.txt" | |||
| data_path = "data_for_tests/people.txt" | |||
| pickle_path = "data_for_tests" | |||
| pickle_path = "seq_label/" | |||
| data_infer_path = "data_for_tests/people_infer.txt" | |||
| @@ -33,21 +33,12 @@ def infer(): | |||
| model = SeqLabeling(test_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||
| ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl") | |||
| print("model loaded!") | |||
| # Data Loader | |||
| raw_data_loader = BaseLoader(data_name, data_infer_path) | |||
| infer_data = raw_data_loader.load_lines() | |||
| """ | |||
| Transform strings into list of list of strings. | |||
| [ | |||
| [word_11, word_12, ...], | |||
| [word_21, word_22, ...], | |||
| ... | |||
| ] | |||
| In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them. | |||
| """ | |||
| # Inference interface | |||
| infer = SeqLabelInfer(pickle_path) | |||
| @@ -69,7 +60,7 @@ def train_and_test(): | |||
| # Preprocessor | |||
| p = SeqLabelPreprocess() | |||
| data_train, data_dev = p.run(train_data, pickle_path, train_dev_split=0.5) | |||
| data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | |||
| train_args["vocab_size"] = p.vocab_size | |||
| train_args["num_classes"] = p.num_classes | |||
| @@ -84,7 +75,7 @@ def train_and_test(): | |||
| print("Training finished!") | |||
| # Saver | |||
| saver = ModelSaver("./data_for_tests/saved_model.pkl") | |||
| saver = ModelSaver(pickle_path + "saved_model.pkl") | |||
| saver.save_pytorch(model) | |||
| print("Model saved!") | |||
| @@ -94,7 +85,7 @@ def train_and_test(): | |||
| model = SeqLabeling(train_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||
| ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl") | |||
| print("model loaded!") | |||
| # Load test configuration | |||
| @@ -14,6 +14,7 @@ from fastNLP.core.preprocess import ClassPreprocess | |||
| from fastNLP.models.cnn_text_classification import CNNText | |||
| from fastNLP.saver.model_saver import ModelSaver | |||
| save_path = "./test_classification/" | |||
| data_dir = "./data_for_tests/" | |||
| train_file = 'text_classify.txt' | |||
| model_name = "model_class.pkl" | |||
| @@ -27,8 +28,8 @@ def infer(): | |||
| unlabeled_data = [x[0] for x in data] | |||
| # pre-process data | |||
| pre = ClassPreprocess(data_dir) | |||
| vocab_size, n_classes = pre.process(data, "data_train.pkl") | |||
| pre = ClassPreprocess() | |||
| vocab_size, n_classes = pre.run(data, pickle_path=save_path) | |||
| print("vocabulary size:", vocab_size) | |||
| print("number of classes:", n_classes) | |||
| @@ -60,7 +61,7 @@ def train(): | |||
| # pre-process data | |||
| pre = ClassPreprocess() | |||
| data_train = pre.run(data, pickle_path=data_dir) | |||
| data_train = pre.run(data, pickle_path=save_path) | |||
| print("vocabulary size:", pre.vocab_size) | |||
| print("number of classes:", pre.num_classes) | |||