| @@ -134,7 +134,10 @@ class BasePreprocess(object): | |||||
| results.append(data_dev) | results.append(data_dev) | ||||
| if test_data: | if test_data: | ||||
| results.append(data_test) | results.append(data_test) | ||||
| return tuple(results) | |||||
| if len(results) == 1: | |||||
| return results[0] | |||||
| else: | |||||
| return tuple(results) | |||||
| def build_dict(self, data): | def build_dict(self, data): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -282,7 +285,8 @@ class ClassPreprocess(BasePreprocess): | |||||
| data_index = [] | data_index = [] | ||||
| for example in data: | for example in data: | ||||
| word_list = [] | word_list = [] | ||||
| for word, label in zip(example[0]): | |||||
| # example[0] is the word list, example[1] is the single label | |||||
| for word in example[0]: | |||||
| word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | ||||
| label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]) | label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]) | ||||
| data_index.append([word_list, label_index]) | data_index.append([word_list, label_index]) | ||||
| @@ -95,10 +95,10 @@ num_classes = 27 | |||||
| [text_class] | [text_class] | ||||
| epochs = 1 | epochs = 1 | ||||
| batch_size = 10 | batch_size = 10 | ||||
| pickle_path = "./data_for_tests/" | |||||
| pickle_path = "./save_path/" | |||||
| validate = false | validate = false | ||||
| save_best_dev = false | save_best_dev = false | ||||
| model_saved_path = "./data_for_tests/" | |||||
| model_saved_path = "./save_path/" | |||||
| use_cuda = true | use_cuda = true | ||||
| learn_rate = 1e-3 | learn_rate = 1e-3 | ||||
| momentum = 0.9 | momentum = 0.9 | ||||
| @@ -14,7 +14,7 @@ from fastNLP.core.predictor import SeqLabelInfer | |||||
| data_name = "people.txt" | data_name = "people.txt" | ||||
| data_path = "data_for_tests/people.txt" | data_path = "data_for_tests/people.txt" | ||||
| pickle_path = "data_for_tests" | |||||
| pickle_path = "seq_label/" | |||||
| data_infer_path = "data_for_tests/people_infer.txt" | data_infer_path = "data_for_tests/people_infer.txt" | ||||
| @@ -33,21 +33,12 @@ def infer(): | |||||
| model = SeqLabeling(test_args) | model = SeqLabeling(test_args) | ||||
| # Dump trained parameters into the model | # Dump trained parameters into the model | ||||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||||
| ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl") | |||||
| print("model loaded!") | print("model loaded!") | ||||
| # Data Loader | # Data Loader | ||||
| raw_data_loader = BaseLoader(data_name, data_infer_path) | raw_data_loader = BaseLoader(data_name, data_infer_path) | ||||
| infer_data = raw_data_loader.load_lines() | infer_data = raw_data_loader.load_lines() | ||||
| """ | |||||
| Transform strings into list of list of strings. | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| [word_21, word_22, ...], | |||||
| ... | |||||
| ] | |||||
| In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them. | |||||
| """ | |||||
| # Inference interface | # Inference interface | ||||
| infer = SeqLabelInfer(pickle_path) | infer = SeqLabelInfer(pickle_path) | ||||
| @@ -69,7 +60,7 @@ def train_and_test(): | |||||
| # Preprocessor | # Preprocessor | ||||
| p = SeqLabelPreprocess() | p = SeqLabelPreprocess() | ||||
| data_train, data_dev = p.run(train_data, pickle_path, train_dev_split=0.5) | |||||
| data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | |||||
| train_args["vocab_size"] = p.vocab_size | train_args["vocab_size"] = p.vocab_size | ||||
| train_args["num_classes"] = p.num_classes | train_args["num_classes"] = p.num_classes | ||||
| @@ -84,7 +75,7 @@ def train_and_test(): | |||||
| print("Training finished!") | print("Training finished!") | ||||
| # Saver | # Saver | ||||
| saver = ModelSaver("./data_for_tests/saved_model.pkl") | |||||
| saver = ModelSaver(pickle_path + "saved_model.pkl") | |||||
| saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
| print("Model saved!") | print("Model saved!") | ||||
| @@ -94,7 +85,7 @@ def train_and_test(): | |||||
| model = SeqLabeling(train_args) | model = SeqLabeling(train_args) | ||||
| # Dump trained parameters into the model | # Dump trained parameters into the model | ||||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||||
| ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl") | |||||
| print("model loaded!") | print("model loaded!") | ||||
| # Load test configuration | # Load test configuration | ||||
| @@ -14,6 +14,7 @@ from fastNLP.core.preprocess import ClassPreprocess | |||||
| from fastNLP.models.cnn_text_classification import CNNText | from fastNLP.models.cnn_text_classification import CNNText | ||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| save_path = "./test_classification/" | |||||
| data_dir = "./data_for_tests/" | data_dir = "./data_for_tests/" | ||||
| train_file = 'text_classify.txt' | train_file = 'text_classify.txt' | ||||
| model_name = "model_class.pkl" | model_name = "model_class.pkl" | ||||
| @@ -27,8 +28,8 @@ def infer(): | |||||
| unlabeled_data = [x[0] for x in data] | unlabeled_data = [x[0] for x in data] | ||||
| # pre-process data | # pre-process data | ||||
| pre = ClassPreprocess(data_dir) | |||||
| vocab_size, n_classes = pre.process(data, "data_train.pkl") | |||||
| pre = ClassPreprocess() | |||||
| vocab_size, n_classes = pre.run(data, pickle_path=save_path) | |||||
| print("vocabulary size:", vocab_size) | print("vocabulary size:", vocab_size) | ||||
| print("number of classes:", n_classes) | print("number of classes:", n_classes) | ||||
| @@ -60,7 +61,7 @@ def train(): | |||||
| # pre-process data | # pre-process data | ||||
| pre = ClassPreprocess() | pre = ClassPreprocess() | ||||
| data_train = pre.run(data, pickle_path=data_dir) | |||||
| data_train = pre.run(data, pickle_path=save_path) | |||||
| print("vocabulary size:", pre.vocab_size) | print("vocabulary size:", pre.vocab_size) | ||||
| print("number of classes:", pre.num_classes) | print("number of classes:", pre.num_classes) | ||||