* 重构dtype的检测代码,在FieldArray的初始化和append两处,达到更好的代码复用
* 类型检测的责任完全落在FieldArray,DataSet与之配合

测试:
* 整理dtype相关的测试代码
* 给所有tutorial添加测试

其他:
* 完善一个完整的Conll dataset loader
* 升级POS tag model训练脚本

tags/v0.3.1^2
| @@ -2,8 +2,8 @@ import _pickle as pickle | |||
| import numpy as np | |||
| from fastNLP.core.fieldarray import FieldArray | |||
| from fastNLP.core.fieldarray import AutoPadder | |||
| from fastNLP.core.fieldarray import FieldArray | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.utils import get_func_signature | |||
| from fastNLP.io.base_loader import DataLoaderRegister | |||
| @@ -142,7 +142,8 @@ class DataSet(object): | |||
| if len(self.field_arrays) == 0: | |||
| # DataSet has no field yet | |||
| for name, field in ins.fields.items(): | |||
| self.field_arrays[name] = FieldArray(name, [field]) | |||
| field = field.tolist() if isinstance(field, np.ndarray) else field | |||
| self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 | |||
| else: | |||
| if len(self.field_arrays) != len(ins.fields): | |||
| raise ValueError( | |||
| @@ -290,9 +291,11 @@ class DataSet(object): | |||
| extra_param['is_input'] = old_field.is_input | |||
| if 'is_target' not in extra_param: | |||
| extra_param['is_target'] = old_field.is_target | |||
| self.add_field(name=new_field_name, fields=results) | |||
| self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||
| is_target=extra_param["is_target"]) | |||
| else: | |||
| self.add_field(name=new_field_name, fields=results) | |||
| self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
| is_target=extra_param.get("is_target", None)) | |||
| else: | |||
| return results | |||
| @@ -334,13 +337,14 @@ class DataSet(object): | |||
| train_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||
| train_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||
| train_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||
| train_set.field_arrays[field_name].is_2d_list = self.field_arrays[field_name].is_2d_list | |||
| train_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||
| dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||
| dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||
| dev_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||
| dev_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||
| dev_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||
| dev_set.field_arrays[field_name].is_2d_list = self.field_arrays[field_name].is_2d_list | |||
| dev_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||
| return train_set, dev_set | |||
| @@ -100,6 +100,22 @@ class FieldArray(object): | |||
| """ | |||
| def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)): | |||
| """DataSet在初始化时会有两类方法对FieldArray操作: | |||
| 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | |||
| 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
| 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||
| 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||
| 2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; | |||
| 然后后面的样本使用FieldArray.append进行添加。 | |||
| 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||
| 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||
| 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||
| 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||
| 注意:np.array必须仅在最外层,即np.array([np.array, np.array]) 和 list of np.array不考虑 | |||
| 类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 | |||
| """ | |||
| self.name = name | |||
| if isinstance(content, list): | |||
| content = content | |||
| @@ -107,31 +123,39 @@ class FieldArray(object): | |||
| content = content.tolist() # convert np.ndarray into 2-D list | |||
| else: | |||
| raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | |||
| self.content = content | |||
| if len(content) == 0: | |||
| raise RuntimeError("Cannot initialize FieldArray with empty list.") | |||
| self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 | |||
| self.content_dim = None # 表示content是多少维的list | |||
| self.set_padder(padder) | |||
| self._is_target = None | |||
| self._is_input = None | |||
| self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array | |||
| self.BASIC_TYPES = (int, float, str, np.ndarray) | |||
| self.is_2d_list = False | |||
| self.pytype = None # int, float, str, or np.ndarray | |||
| self.dtype = None # np.int64, np.float64, np.str | |||
| self.pytype = None | |||
| self.dtype = None | |||
| self._is_input = None | |||
| self._is_target = None | |||
| if is_input is not None: | |||
| if is_input is not None or is_target is not None: | |||
| self.is_input = is_input | |||
| if is_target is not None: | |||
| self.is_target = is_target | |||
def _set_dtype(self):
    """Refresh ``self.pytype`` and ``self.dtype`` from the current content.

    The Python-level element type is detected from ``self.content`` and then
    mapped to its NumPy counterpart.
    """
    detected = self._type_detection(self.content)
    self.pytype = detected
    self.dtype = self._map_to_np_type(detected)
| @property | |||
| def is_input(self): | |||
| return self._is_input | |||
| @is_input.setter | |||
| def is_input(self, value): | |||
| """ | |||
| 当 field_array.is_input = True / False 时被调用 | |||
| """ | |||
| if value is True: | |||
| self.pytype = self._type_detection(self.content) | |||
| self.dtype = self._map_to_np_type(self.pytype) | |||
| self._set_dtype() | |||
| self._is_input = value | |||
| @property | |||
| @@ -140,46 +164,99 @@ class FieldArray(object): | |||
| @is_target.setter | |||
| def is_target(self, value): | |||
| """ | |||
| 当 field_array.is_target = True / False 时被调用 | |||
| """ | |||
| if value is True: | |||
| self.pytype = self._type_detection(self.content) | |||
| self.dtype = self._map_to_np_type(self.pytype) | |||
| self._set_dtype() | |||
| self._is_target = value | |||
| def _type_detection(self, content): | |||
| """ | |||
| :param content: a list of int, float, str or np.ndarray, or a list of list of one. | |||
| :return type: one of int, float, str, np.ndarray | |||
| """当该field被设置为is_input或者is_target时被调用 | |||
| """ | |||
| if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | |||
| # content is a 2-D list | |||
| if not all(isinstance(_, list) for _ in content): # strict check 2-D list | |||
| raise TypeError("Please provide 2-D list.") | |||
| type_set = set([self._type_detection(x) for x in content]) | |||
| if len(type_set) == 2 and int in type_set and float in type_set: | |||
| type_set = {float} | |||
| elif len(type_set) > 1: | |||
| raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||
| self.is_2d_list = True | |||
| if len(content) == 0: | |||
| raise RuntimeError("Empty list in Field {}.".format(self.name)) | |||
| type_set = set([type(item) for item in content]) | |||
| if list in type_set: | |||
| if len(type_set) > 1: | |||
| # list 跟 非list 混在一起 | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
| # >1维list | |||
| inner_type_set = set() | |||
| for l in content: | |||
| [inner_type_set.add(type(obj)) for obj in l] | |||
| if list not in inner_type_set: | |||
| # 二维list | |||
| self.content_dim = 2 | |||
| return self._basic_type_detection(inner_type_set) | |||
| else: | |||
| if len(inner_type_set) == 1: | |||
| # >2维list | |||
| inner_inner_type_set = set() | |||
| for _2d_list in content: | |||
| for _1d_list in _2d_list: | |||
| [inner_inner_type_set.add(type(obj)) for obj in _1d_list] | |||
| if list in inner_inner_type_set: | |||
| raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") | |||
| # 3维list | |||
| self.content_dim = 3 | |||
| return self._basic_type_detection(inner_inner_type_set) | |||
| else: | |||
| # list 跟 非list 混在一起 | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) | |||
| else: | |||
| # 一维list | |||
| for content_type in type_set: | |||
| if content_type not in self.BASIC_TYPES: | |||
| raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( | |||
| self.name, self.BASIC_TYPES, content_type)) | |||
| self.content_dim = 1 | |||
| return self._basic_type_detection(type_set) | |||
| def _basic_type_detection(self, type_set): | |||
| """ | |||
| :param type_set: a set of Python types | |||
| :return: one of self.BASIC_TYPES | |||
| """ | |||
| if len(type_set) == 1: | |||
| return type_set.pop() | |||
| elif isinstance(content, list): | |||
| # content is a 1-D list | |||
| if len(content) == 0: | |||
| # the old error is not informative enough. | |||
| raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.") | |||
| type_set = set([type(item) for item in content]) | |||
| if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: | |||
| return type_set.pop() | |||
| elif len(type_set) == 2 and float in type_set and int in type_set: | |||
| elif len(type_set) == 2: | |||
| # 有多个basic type; 可能需要up-cast | |||
| if float in type_set and int in type_set: | |||
| # up-cast int to float | |||
| return float | |||
| else: | |||
| raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) | |||
| # str 跟 int 或者 float 混在一起 | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
| else: | |||
| raise TypeError("Cannot create FieldArray with type {}".format(type(content))) | |||
| # str, int, float混在一起 | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
def _1d_list_check(self, val):
    """Check that ``val`` is a flat (1-D) list of basic-typed elements.

    :param val: the list that is about to be appended.
    :return: ``True`` when the check passes.
    :raises ValueError: if any element's type is not in ``self.BASIC_TYPES``.

    Incompatible combinations of basic types are rejected by
    ``self._basic_type_detection``, which raises on its own.
    """
    element_types = {type(item) for item in val}
    for element_type in element_types:
        if element_type not in self.BASIC_TYPES:
            raise ValueError("Mixed data types in Field {}: {}".format(self.name, element_types))
    # Result is discarded on purpose: the call is made only for its error path.
    self._basic_type_detection(element_types)
    return True
def _2d_list_check(self, val):
    """Check that ``val`` is a 2-D list, i.e. a list whose items are all lists.

    :param val: the nested list that is about to be appended.
    :return: ``True`` when the check passes.
    :raises ValueError: if the outer items are not exclusively ``list`` objects.

    The types of the innermost elements are collected and handed to
    ``self._basic_type_detection``, which raises on incompatible mixtures.
    """
    row_types = {type(row) for row in val}
    if row_types != {list}:
        raise ValueError("Mixed data types in Field {}: {}".format(self.name, row_types))
    cell_types = {type(cell) for row in val for cell in row}
    # Result is discarded on purpose: the call is made only for its error path.
    self._basic_type_detection(cell_types)
    return True
| @staticmethod | |||
| def _map_to_np_type(basic_type): | |||
| @@ -194,38 +271,39 @@ class FieldArray(object): | |||
| :param val: int, float, str, or a list of one. | |||
| """ | |||
| if self.is_target is True or self.is_input is True: | |||
| # only check type when used as target or input | |||
| if isinstance(val, list): | |||
| pass | |||
| elif isinstance(val, tuple): # 确保最外层是list | |||
| val = list(val) | |||
| elif isinstance(val, np.ndarray): | |||
| val = val.tolist() | |||
| elif any((isinstance(val, t) for t in self.BASIC_TYPES)): | |||
| pass | |||
| else: | |||
| raise RuntimeError( | |||
| "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||
| val_type = type(val) | |||
| if val_type == list: # shape check | |||
| if self.is_2d_list is False: | |||
| raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||
| if self.is_input is True or self.is_target is True: | |||
| if type(val) == list: | |||
| if len(val) == 0: | |||
| raise RuntimeError("Cannot append an empty list.") | |||
| val_list_type = set([type(_) for _ in val]) # type check | |||
| if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||
| # up-cast int to float | |||
| val_type = float | |||
| elif len(val_list_type) == 1: | |||
| val_type = val_list_type.pop() | |||
| raise ValueError("Cannot append an empty list.") | |||
| if self.content_dim == 2 and self._1d_list_check(val): | |||
| # 1维list检查 | |||
| pass | |||
| elif self.content_dim == 3 and self._2d_list_check(val): | |||
| # 2维list检查 | |||
| pass | |||
| else: | |||
| raise TypeError("Cannot append a list of {}".format(val_list_type)) | |||
| else: | |||
| if self.is_2d_list is True: | |||
| raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||
| if val_type == float and self.pytype == int: | |||
| # up-cast | |||
| self.pytype = float | |||
| self.dtype = self._map_to_np_type(self.pytype) | |||
| elif val_type == int and self.pytype == float: | |||
| pass | |||
| elif val_type == self.pytype: | |||
| pass | |||
| raise RuntimeError( | |||
| "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) | |||
| elif type(val) in self.BASIC_TYPES and self.content_dim == 1: | |||
| # scalar检查 | |||
| if type(val) == float and self.pytype == int: | |||
| self.pytype = float | |||
| self.dtype = self._map_to_np_type(self.pytype) | |||
| else: | |||
| raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||
| raise RuntimeError( | |||
| "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||
| self.content.append(val) | |||
| def __getitem__(self, indices): | |||
| @@ -11,6 +11,10 @@ class Instance(object): | |||
| """ | |||
| def __init__(self, **fields): | |||
| """ | |||
| :param fields: 可能是一维或者二维的 list or np.array | |||
| """ | |||
| self.fields = fields | |||
| def add_field(self, field_name, field): | |||
| @@ -32,5 +36,5 @@ class Instance(object): | |||
| def __repr__(self): | |||
| s = '\'' | |||
| return "{" + ",\n".join( | |||
| "\'" + field_name + "\': " + str(self.fields[field_name]) +\ | |||
| "\'" + field_name + "\': " + str(self.fields[field_name]) + \ | |||
| f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" | |||
| @@ -858,9 +858,22 @@ class ConllPOSReader(object): | |||
| ds.append(Instance(words=char_seq, | |||
| tag=pos_seq)) | |||
| return ds | |||
def get_one(self, sample):
    """Extract the token and POS-tag sequences from one sentence.

    :param sample: list of rows, one per token; each row is an indexable
        sequence of column values. Column 1 is read as the token text,
        column 3 as its tag, and column 6 is used as a validity marker.
    :return: ``(text, pos_tags)`` as two parallel lists, or ``None`` when the
        sample is empty or any row has ``'_'`` in column 6.
    """
    if not sample:
        return None
    text = []
    pos_tags = []
    for row in sample:
        # Only columns 1, 3 and 6 are needed. Column 7 used to be read into an
        # unused local, which made rows with fewer than 8 columns fail.
        if row[6] == '_':
            return None
        text.append(row[1])
        pos_tags.append(row[3])
    return text, pos_tags
| class ConllxDataLoader(object): | |||
| def load(self, path): | |||
| @@ -879,7 +892,12 @@ class ConllxDataLoader(object): | |||
| datalist.append(sample) | |||
| data = [self.get_one(sample) for sample in datalist] | |||
| return list(filter(lambda x: x is not None, data)) | |||
| data_list = list(filter(lambda x: x is not None, data)) | |||
| ds = DataSet() | |||
| for example in data_list: | |||
| ds.append(Instance(words=example[0], tag=example[1])) | |||
| return ds | |||
| def get_one(self, sample): | |||
| sample = list(map(list, zip(*sample))) | |||
| @@ -10,7 +10,7 @@ eval_sort_key = 'accuracy' | |||
| [model] | |||
| rnn_hidden_units = 300 | |||
| word_emb_dim = 100 | |||
| word_emb_dim = 300 | |||
| dropout = 0.5 | |||
| use_crf = true | |||
| print_every_step = 10 | |||
| @@ -8,16 +8,16 @@ import torch | |||
| # in order to run fastNLP without installation | |||
| sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
| from fastNLP.api.pipeline import Pipeline | |||
| from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor | |||
| from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor, SetInputProcessor, IndexerProcessor | |||
| from fastNLP.core.metrics import SpanFPreRecMetric | |||
| from fastNLP.core.trainer import Trainer | |||
| from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||
| from fastNLP.models.sequence_modeling import AdvSeqLabel | |||
| from fastNLP.io.dataset_loader import ZhConllPOSReader | |||
| from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader | |||
| from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | |||
| cfgfile = './pos_tag.cfg' | |||
| pickle_path = "save" | |||
| @@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id): | |||
| return embedding_tensor | |||
| def train(checkpoint=None): | |||
| def train(train_data_path, dev_data_path, checkpoint=None): | |||
| # load config | |||
| train_param = ConfigSection() | |||
| model_param = ConfigSection() | |||
| @@ -43,24 +43,36 @@ def train(checkpoint=None): | |||
| print("config loaded") | |||
| # Data Loader | |||
| dataset = ZhConllPOSReader().load("/home/hyan/train.conllx") | |||
| print("loading training set...") | |||
| dataset = ConllxDataLoader().load(train_data_path) | |||
| print("loading dev set...") | |||
| dev_data = ConllxDataLoader().load(dev_data_path) | |||
| print(dataset) | |||
| print("dataset transformed") | |||
| print("================= dataset ready =====================") | |||
| dataset.rename_field("tag", "truth") | |||
| dev_data.rename_field("tag", "truth") | |||
| vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") | |||
| tag_proc = VocabIndexerProcessor("truth") | |||
| seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) | |||
| set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth") | |||
| vocab_proc(dataset) | |||
| tag_proc(dataset) | |||
| seq_len_proc(dataset) | |||
| # index dev set | |||
| word_vocab, tag_vocab = vocab_proc.vocab, tag_proc.vocab | |||
| dev_data.apply(lambda ins: [word_vocab.to_index(w) for w in ins["words"]], new_field_name="word_seq") | |||
| dev_data.apply(lambda ins: [tag_vocab.to_index(w) for w in ins["truth"]], new_field_name="truth") | |||
| dev_data.apply(lambda ins: len(ins["word_seq"]), new_field_name="word_seq_origin_len") | |||
| # set input & target | |||
| dataset.set_input("word_seq", "word_seq_origin_len", "truth") | |||
| dev_data.set_input("word_seq", "word_seq_origin_len", "truth") | |||
| dataset.set_target("truth", "word_seq_origin_len") | |||
| print("processors defined") | |||
| dev_data.set_target("truth", "word_seq_origin_len") | |||
| # dataset.set_is_target(tag_ids=True) | |||
| model_param["vocab_size"] = vocab_proc.get_vocab_size() | |||
| @@ -71,7 +83,7 @@ def train(checkpoint=None): | |||
| if checkpoint is None: | |||
| # pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) | |||
| pre_trained = None | |||
| model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained) | |||
| model = AdvSeqLabel(model_param, id2words=None, emb=pre_trained) | |||
| print(model) | |||
| else: | |||
| model = torch.load(checkpoint) | |||
| @@ -80,33 +92,71 @@ def train(checkpoint=None): | |||
| trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||
| target="truth", | |||
| seq_lens="word_seq_origin_len"), | |||
| dev_data=dataset, metric_key="f", | |||
| use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save") | |||
| dev_data=dev_data, metric_key="f", | |||
| use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save_0") | |||
| trainer.train(load_best_model=True) | |||
| # save model & pipeline | |||
| model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||
| id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") | |||
| pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) | |||
| pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) | |||
| save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | |||
| torch.save(save_dict, "model_pp.pkl") | |||
| print("pipeline saved") | |||
| torch.save(model, "./save/best_model.pkl") | |||
def run_test(test_path):
    """Evaluate the pipeline stored in ``model_pp.pkl`` on a test file.

    The test set is read with ``ZhConllPOSReader``. The saved pipeline is
    extended in front with an ``IndexerProcessor`` that indexes the ``tag``
    field into ``truth`` and is then run over the data. Predictions and gold
    sequences are right-padded with 0 to the longest prediction before being
    scored with ``SpanFPreRecMetric``.

    :param test_path: path of the test file passed to the reader.
    :return: dict with keys ``"F1"``, ``"precision"`` and ``"recall"``, each a
        percentage rounded to two decimals.
    """
    test_data = ZhConllPOSReader().load(test_path)

    with open("model_pp.pkl", "rb") as pkl_file:
        saved = torch.load(pkl_file)
    tag_vocab = saved["tag_vocab"]
    pipeline = saved["pipeline"]

    tag_indexer = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth",
                                   is_input=False)
    pipeline.pipeline = [tag_indexer] + pipeline.pipeline
    pipeline(test_data)
    test_data.set_target("truth")

    arrays = test_data.field_arrays
    prediction = arrays["predict"].content
    truth = arrays["truth"].content
    seq_len = arrays["word_seq_origin_len"].content

    # padding by hand: rows are rewritten in place, padded up to the longest prediction
    max_length = max(len(seq) for seq in prediction)
    for idx, predicted in enumerate(prediction):
        prediction[idx] = list(predicted) + [0] * (max_length - len(predicted))
        gold = truth[idx]
        truth[idx] = list(gold) + [0] * (max_length - len(gold))

    evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth",
                                  seq_lens="word_seq_origin_len")
    pred_dict = {"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}
    target_dict = {"truth": torch.Tensor(truth)}
    evaluator(pred_dict, target_dict)

    test_result = evaluator.get_metric()
    report_keys = (("F1", "f"), ("precision", "pre"), ("recall", "rec"))
    return {label: round(test_result[key] * 100, 2) for label, key in report_keys}
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("--train", type=str, help="training conll file", default="/home/zyfeng/data/sample.conllx") | |||
| parser.add_argument("--dev", type=str, help="dev conll file", default="/home/zyfeng/data/sample.conllx") | |||
| parser.add_argument("--test", type=str, help="test conll file", default=None) | |||
| parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") | |||
| parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") | |||
| args = parser.parse_args() | |||
| if args.restart is True: | |||
| # 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl | |||
| if args.checkpoint is None: | |||
| raise RuntimeError("Please provide the checkpoint. -cp ") | |||
| train(args.checkpoint) | |||
| if args.test is not None: | |||
| print(run_test(args.test)) | |||
| else: | |||
| # 一次训练 python train_pos_tag.py | |||
| train() | |||
| if args.restart is True: | |||
| # 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl | |||
| if args.checkpoint is None: | |||
| raise RuntimeError("Please provide the checkpoint. -cp ") | |||
| train(args.train, args.dev, args.checkpoint) | |||
| else: | |||
| # 一次训练 python train_pos_tag.py | |||
| train(args.train, args.dev) | |||
| @@ -89,3 +89,12 @@ class TestCase1(unittest.TestCase): | |||
| self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
| self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||
| self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||
def test_list_of_numpy_to_tensor(self):
    """Iterate a Batch over a DataSet whose fields are 1-D numpy arrays of two lengths.

    The test only prints each batch and makes no assertions.
    """
    short_instances = [Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)]
    long_instances = [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]
    ds = DataSet(short_instances + long_instances)
    ds.set_input("x")
    ds.set_target("y")
    batches = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
    for x, y in batches:
        print(x, y)
| @@ -6,15 +6,29 @@ from fastNLP.core.fieldarray import FieldArray | |||
| from fastNLP.core.instance import Instance | |||
| class TestDataSet(unittest.TestCase): | |||
| class TestDataSetInit(unittest.TestCase): | |||
| """初始化DataSet的办法有以下几种: | |||
| 1) 用dict: | |||
| 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
| 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||
| 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||
| 2) 用list of Instance: | |||
| 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||
| 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||
| 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||
| 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||
| 只接受纯list或者最外层ndarray | |||
| """ | |||
def test_init_v1(self):
    """Build a DataSet from a list of Instances holding 1-D lists."""
    instances = [Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40
    ds = DataSet(instances)
    arrays = ds.field_arrays
    self.assertTrue("x" in arrays and "y" in arrays)
    expected_x = [[1, 2, 3, 4], ] * 40
    expected_y = [[5, 6], ] * 40
    self.assertEqual(ds.field_arrays["x"].content, expected_x)
    self.assertEqual(ds.field_arrays["y"].content, expected_y)
| def test_init_v2(self): | |||
| # 用dict | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
| self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
| @@ -28,6 +42,8 @@ class TestDataSet(unittest.TestCase): | |||
| with self.assertRaises(ValueError): | |||
| _ = DataSet(0.00001) | |||
| class TestDataSetMethods(unittest.TestCase): | |||
| def test_append(self): | |||
| dd = DataSet() | |||
| for _ in range(3): | |||
| @@ -42,13 +42,13 @@ class TestFieldArray(unittest.TestCase): | |||
| self.assertEqual(fa.pytype, str) | |||
| def test_support_np_array(self): | |||
| fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True) | |||
| self.assertEqual(fa.dtype, np.ndarray) | |||
| self.assertEqual(fa.pytype, np.ndarray) | |||
| fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) | |||
| self.assertEqual(fa.dtype, np.float64) | |||
| self.assertEqual(fa.pytype, float) | |||
| fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | |||
| self.assertEqual(fa.dtype, np.ndarray) | |||
| self.assertEqual(fa.pytype, np.ndarray) | |||
| self.assertEqual(fa.dtype, np.float64) | |||
| self.assertEqual(fa.pytype, float) | |||
| fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | |||
| # in this case, pytype is actually a float. We do not care about it. | |||
| @@ -1,8 +1,8 @@ | |||
| from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||
| import fastNLP | |||
| import unittest | |||
| import fastNLP | |||
| from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||
| data_file = """ | |||
| 1 The _ DET DT _ 3 det _ _ | |||
| 2 new _ ADJ JJ _ 3 amod _ _ | |||
| @@ -41,6 +41,7 @@ data_file = """ | |||
| """ | |||
| def init_data(): | |||
| ds = fastNLP.DataSet() | |||
| v = {'word_seq': fastNLP.Vocabulary(), | |||
| @@ -60,18 +61,19 @@ def init_data(): | |||
| data.append(line) | |||
| for name in ['word_seq', 'pos_seq', 'label_true']: | |||
| ds.apply(lambda x: ['<st>']+list(x[name]), new_field_name=name) | |||
| ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name) | |||
| ds.apply(lambda x: v[name].add_word_lst(x[name])) | |||
| for name in ['word_seq', 'pos_seq', 'label_true']: | |||
| ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | |||
| ds.apply(lambda x: [0]+list(map(int, x['arc_true'])), new_field_name='arc_true') | |||
| ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true') | |||
| ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | |||
| ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | |||
| ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | |||
| return ds, v['word_seq'], v['pos_seq'], v['label_true'] | |||
| class TestBiaffineParser(unittest.TestCase): | |||
| def test_train(self): | |||
| ds, v1, v2, v3 = init_data() | |||
| @@ -84,5 +86,6 @@ class TestBiaffineParser(unittest.TestCase): | |||
| n_epochs=10, use_cuda=False, use_tqdm=False) | |||
| trainer.train(load_best_model=False) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| unittest.main() | |||
| @@ -1,91 +0,0 @@ | |||
| import unittest | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP import Tester | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.core.losses import CrossEntropyLoss | |||
| from fastNLP.core.metrics import AccuracyMetric | |||
| from fastNLP.models import CNNText | |||
| class TestTutorial(unittest.TestCase): | |||
| def test_tutorial(self): | |||
| # 从csv读取数据到DataSet | |||
| sample_path = "test/data_for_tests/tutorial_sample_dataset.csv" | |||
| dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), | |||
| sep='\t') | |||
| print(len(dataset)) | |||
| print(dataset[0]) | |||
| dataset.append(Instance(raw_sentence='fake data', label='0')) | |||
| dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||
| # label转int | |||
| dataset.apply(lambda x: int(x['label']), new_field_name='label') | |||
| # 使用空格分割句子 | |||
| def split_sent(ins): | |||
| return ins['raw_sentence'].split() | |||
| dataset.apply(split_sent, new_field_name='words') | |||
| # 增加长度信息 | |||
| dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') | |||
| print(len(dataset)) | |||
| print(dataset[0]) | |||
| # DataSet.drop(func)筛除数据 | |||
| dataset.drop(lambda x: x['seq_len'] <= 3) | |||
| print(len(dataset)) | |||
| # 设置DataSet中,哪些field要转为tensor | |||
| # set target,loss或evaluate中的golden,计算loss,模型评估时使用 | |||
| dataset.set_target("label") | |||
| # set input,模型forward时使用 | |||
| dataset.set_input("words") | |||
| # 分出测试集、训练集 | |||
| test_data, train_data = dataset.split(0.5) | |||
| print(len(test_data)) | |||
| print(len(train_data)) | |||
| # 构建词表, Vocabulary.add(word) | |||
| vocab = Vocabulary(min_freq=2) | |||
| train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||
| vocab.build_vocab() | |||
| # index句子, Vocabulary.to_index(word) | |||
| train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||
| test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||
| print(test_data[0]) | |||
| model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
| from fastNLP import Trainer | |||
| from copy import deepcopy | |||
| # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 | |||
| train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 | |||
| train_data.rename_field('label', 'label_seq') | |||
| test_data.rename_field('words', 'word_seq') | |||
| test_data.rename_field('label', 'label_seq') | |||
| # 实例化Trainer,传入模型和数据,进行训练 | |||
| copy_model = deepcopy(model) | |||
| overfit_trainer = Trainer(train_data=test_data, model=copy_model, | |||
| loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||
| metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, | |||
| dev_data=test_data, save_path="./save") | |||
| overfit_trainer.train() | |||
| trainer = Trainer(train_data=train_data, model=model, | |||
| loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||
| metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, | |||
| dev_data=test_data, save_path="./save") | |||
| trainer.train() | |||
| print('Train finished!') | |||
| # 使用fastNLP的Tester测试脚本 | |||
| tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
| batch_size=4) | |||
| acc = tester.test() | |||
| print(acc) | |||
| @@ -0,0 +1,432 @@ | |||
| import unittest | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.core.losses import CrossEntropyLoss | |||
| from fastNLP.core.metrics import AccuracyMetric | |||
| class TestTutorial(unittest.TestCase): | |||
def test_fastnlp_10min_tutorial(self):
    """Smoke-test the 10-minute tutorial flow.

    Reads the sample CSV, preprocesses it with ``DataSet.apply``, builds a
    vocabulary, iterates one batch, trains ``CNNText`` twice (once to overfit
    the held-out half, once on the training half) and runs ``Tester``.
    The test makes no assertions; it only checks that the pipeline runs.
    """
    from copy import deepcopy

    from fastNLP import Tester
    from fastNLP import Trainer
    from fastNLP.core.batch import Batch
    from fastNLP.core.sampler import RandomSampler
    from fastNLP.models import CNNText

    # Load the tab-separated sample file into a DataSet.
    csv_path = "tutorials/sample_data/tutorial_sample_dataset.csv"
    dataset = DataSet.read_csv(csv_path, headers=('raw_sentence', 'label'), sep='\t')
    print(len(dataset))
    print(dataset[0])
    print(dataset[-3])

    extra_instance = Instance(raw_sentence='fake data', label='0')
    dataset.append(extra_instance)

    def lowercase_sentence(ins):
        # Lower-case the raw sentence text.
        return ins['raw_sentence'].lower()

    def label_as_int(ins):
        # Labels are read from the CSV as strings; convert them to int.
        return int(ins['label'])

    def sentence_length(ins):
        # Number of tokens in the sentence.
        return len(ins['words'])

    def is_too_short(ins):
        # Instances with three or fewer tokens are removed.
        return ins['seq_len'] <= 3

    dataset.apply(lowercase_sentence, new_field_name='raw_sentence')
    dataset.apply(label_as_int, new_field_name='label')
    # Split each sentence on whitespace.
    dataset.apply(lambda ins: ins['raw_sentence'].split(), new_field_name='words')
    dataset.apply(sentence_length, new_field_name='seq_len')
    print(len(dataset))
    print(dataset[0])

    # DataSet.drop(func) filters instances out of the DataSet.
    dataset.drop(is_too_short)
    print(len(dataset))

    # Choose which fields are turned into tensors:
    # the target is the gold value used by the loss and the metric,
    # the inputs are what the model's forward receives.
    dataset.set_target("label")
    dataset.set_input("words", "seq_len")

    # Split into a test half and a training half.
    halves = dataset.split(0.5)
    eval_set, train_set = halves
    print(len(eval_set))
    print(len(train_set))

    # Build the vocabulary from the training half only.
    vocab = Vocabulary(min_freq=2)
    train_set.apply(lambda ins: [vocab.add(token) for token in ins['words']])
    vocab.build_vocab()

    def words_to_indices(ins):
        # Replace every token by its vocabulary index.
        return [vocab.to_index(token) for token in ins['words']]

    for subset in (train_set, eval_set):
        subset.apply(words_to_indices, new_field_name='words')
    print(eval_set[0])

    # The same preprocessing tools can be used on their own
    # (for example in reinforcement-learning or GAN projects) through Batch.
    batches = Batch(dataset=train_set, batch_size=2, sampler=RandomSampler())
    for inputs, targets in batches:
        print("batch_x has: ", inputs)
        print("batch_y has: ", targets)
        break

    model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)

    # Rename the fields so that they match the parameter names of the model's forward.
    for subset in (train_set, eval_set):
        subset.rename_field('words', 'word_seq')
        subset.rename_field('label', 'label_seq')

    loss = CrossEntropyLoss(pred="output", target="label_seq")
    metric = AccuracyMetric(pred="predict", target="label_seq")

    # First fit a copy of the model on the evaluation half
    # (a sanity check that the model implementation can learn at all).
    model_copy = deepcopy(model)
    Trainer(
        model=model_copy,
        train_data=eval_set,
        dev_data=eval_set,
        loss=loss,
        metrics=metric,
        save_path=None,
        batch_size=32,
        n_epochs=5,
    ).train()

    # Then train on the training half and validate on the evaluation half.
    Trainer(
        model=model,
        train_data=train_set,
        dev_data=eval_set,
        loss=CrossEntropyLoss(pred="output", target="label_seq"),
        metrics=AccuracyMetric(pred="predict", target="label_seq"),
        save_path=None,
        batch_size=32,
        n_epochs=5,
    ).train()
    print('Train finished!')

    # Evaluate the trained model on the evaluation half with Tester.
    tester = Tester(
        data=eval_set,
        model=model,
        metrics=AccuracyMetric(pred="predict", target="label_seq"),
        batch_size=4,
    )
    print(tester.test())
def test_fastnlp_1min_tutorial(self):
    """Smoke-test the one-minute tutorial (tutorials/fastnlp_1min_tutorial.ipynb).

    Loads the sample CSV, derives the target and input fields, indexes the
    words with a vocabulary built from the training split and trains
    ``CNNText`` with default ``Trainer`` settings. No assertions are made.
    """
    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
    from fastNLP import Vocabulary
    from fastNLP.models import CNNText

    csv_path = "tutorials/sample_data/tutorial_sample_dataset.csv"
    ds = DataSet.read_csv(csv_path, headers=('raw_sentence', 'label'), sep='\t')
    print(ds[1])

    def lowercase_sentence(ins):
        # Lower-case the raw sentence text.
        return ins['raw_sentence'].lower()

    def label_as_int(ins):
        # Labels come out of the CSV as strings; convert them to int.
        return int(ins['label'])

    def tokenize(ins):
        # Split the sentence on whitespace.
        return ins['raw_sentence'].split()

    ds.apply(lowercase_sentence, new_field_name='raw_sentence')
    ds.apply(label_as_int, new_field_name='label_seq', is_target=True)
    ds.apply(tokenize, new_field_name='words', is_input=True)

    # Split into training and validation sets.
    split_result = ds.split(0.3)
    train_set, dev_set = split_result
    print("Train size: ", len(train_set))
    print("Test size: ", len(dev_set))

    # Collect vocabulary counts from the training set only.
    vocab = Vocabulary(min_freq=2)
    train_set.apply(lambda ins: [vocab.add(token) for token in ins['words']])

    def words_to_indices(ins):
        # Replace every token by its vocabulary index.
        return [vocab.to_index(token) for token in ins['words']]

    for subset in (train_set, dev_set):
        subset.apply(words_to_indices, new_field_name='word_seq', is_input=True)

    model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)

    Trainer(
        model=model,
        train_data=train_set,
        dev_data=dev_set,
        loss=CrossEntropyLoss(),
        metrics=AccuracyMetric(),
    ).train()
    print('Train finished!')
def test_fastnlp_advanced_tutorial(self):
    """Smoke-test the advanced tutorial end to end.

    The tutorial reads its data files through relative paths, so the test
    changes into ``tutorials/fastnlp_advanced_tutorial`` for its duration.
    The previous working directory is restored in a ``finally`` clause so
    that a failure here cannot leave later tests running in the wrong
    directory. The test makes no assertions about model quality; it only
    checks that the pipeline runs.
    """
    import copy
    import os

    import torch

    from fastNLP import AccuracyMetric
    from fastNLP import Adam
    from fastNLP import CrossEntropyLoss
    from fastNLP import DataSet
    from fastNLP import Instance
    from fastNLP import Tester
    from fastNLP import Trainer
    from fastNLP import Vocabulary
    from fastNLP.io.config_io import ConfigSection, ConfigLoader
    from fastNLP.models import CNNText
    from fastNLP.models import ESIM

    def read_lines(path):
        # Return all lines of a text file, line endings included.
        with open(path) as f:
            return f.readlines()

    def index_with(vocabulary, *subsets):
        # Replace the tokens of 'premise' and 'hypothesis' by vocabulary indices,
        # one subset after another, premise before hypothesis.
        for subset in subsets:
            for field in ('premise', 'hypothesis'):
                subset.apply(lambda x, field=field: [vocabulary.to_index(word) for word in x[field]],
                             new_field_name=field)

    original_cwd = os.getcwd()
    os.chdir("tutorials/fastnlp_advanced_tutorial")
    try:
        # ### Instance
        # An Instance is one sample made of one or more named fields,
        # given at construction time as "field_name=field_value".
        instance = Instance(premise='an premise example .', hypothesis='an hypothesis example.', label=1)

        data_set = DataSet([instance] * 5)
        data_set.append(instance)
        _ = data_set[-2:]  # exercise slicing

        # A label whose type (str) differs from the existing field (int).
        # Whether append accepts it is not asserted here: an error raised
        # by append is deliberately ignored.
        instance2 = Instance(premise='the second premise example .', hypothesis='the second hypothesis example.',
                             label='1')
        try:
            data_set.append(instance2)
        except Exception:
            pass
        _ = data_set[-2:]

        # An instance whose field name is wrong ('premises') is expected
        # not to be appendable.
        instance3 = Instance(premises='the third premise example .', hypothesis='the third hypothesis example.',
                             label=1)
        try:
            data_set.append(instance3)
        except Exception:
            print('cannot append instance')
        _ = data_set[-2:]

        # Besides text, a tensor can be the value of a field.
        tensor_ins = Instance(image=torch.randn(5, 5), label=0)
        ds = DataSet()
        ds.append(tensor_ins)

        # Read a csv-like file (one example per line) into a DataSet.
        dataset = DataSet.read_csv('tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\t')
        # Exercise size lookup and the supported kinds of indexing:
        # a single index, a slice and a negative index.
        _ = len(dataset)
        _ = dataset[0]
        _ = type(dataset[0])
        _ = dataset[0: 3]
        _ = dataset[-1]

        # Read the three parallel files (one example per line each).
        premise = read_lines('premise')
        hypothesis = read_lines('hypothesis')
        label = read_lines('label')
        assert len(premise) == len(hypothesis) and len(hypothesis) == len(label)

        # Assemble the DataSet instance by instance.
        data_set = DataSet()
        for p, h, raw_label in zip(premise, hypothesis, label):
            p = p.strip()  # remove surrounding whitespace, including the newline
            h = h.strip()
            data_set.append(Instance(premise=p, hypothesis=h, truth=raw_label))
        _ = data_set[0]

        # ### Further DataSet operations
        # After construction the content can still be transformed through DataSet.apply().

        # Lower-case the premise text.
        data_set.apply(lambda x: x['premise'].lower(), new_field_name='premise')
        _ = data_set[-2:]

        # Convert the label to int.
        data_set.apply(lambda x: int(x['truth']), new_field_name='truth')
        _ = data_set[-2:]

        # Split the sentences on whitespace.
        data_set.apply(lambda x: x['premise'].split(), new_field_name='premise')
        data_set.apply(lambda x: x['hypothesis'].split(), new_field_name='hypothesis')
        _ = data_set[-2:]

        # Filter out instances whose premise has six or fewer tokens.
        data_set.drop(lambda x: len(x['premise']) <= 6)
        _ = len(data_set)

        # Add length information: a list of ones per token.
        data_set.apply(lambda x: [1] * len(x['premise']), new_field_name='premise_len')
        data_set.apply(lambda x: [1] * len(x['hypothesis']), new_field_name='hypothesis_len')
        _ = data_set[-1]

        # Mark the input fields and the target field.
        data_set.set_input("premise", "premise_len", "hypothesis", "hypothesis_len")
        data_set.set_target("truth")

        # Rename a field.
        data_set.rename_field('truth', 'label')
        _ = data_set[-1]

        # Split into training, validation and test sets.
        train_data, vad_data = data_set.split(0.5)
        dev_data, test_data = vad_data.split(0.4)
        _ = len(train_data), len(dev_data), len(test_data)

        # Deep-copy the splits before indexing so that a second vocabulary
        # can be applied to the still un-indexed tokens.
        train_data_2, dev_data_2 = copy.deepcopy(train_data), copy.deepcopy(dev_data)

        # Vocabulary with at most 10000 entries and a minimum word frequency of 2;
        # '<unk>' stands for unknown words and '<pad>' for padding.
        vocab = Vocabulary(max_size=10000, min_freq=2, unknown='<unk>', padding='<pad>')

        # Build the vocabulary from the training set.
        train_data.apply(lambda x: [vocab.add(word) for word in x['premise']])
        train_data.apply(lambda x: [vocab.add(word) for word in x['hypothesis']])
        vocab.build_vocab()

        # Index the sentences with the vocabulary.
        index_with(vocab, train_data, dev_data, test_data)
        _ = train_data[-1], dev_data[-1], test_data[-1]

        # Load a word list from vocab.txt, one entry per line.
        vocabs = [line.strip() for line in read_lines('vocab.txt')]

        # Build a second Vocabulary from that list, then set the unknown
        # and padding token texts afterwards.
        vocab_bert = Vocabulary(unknown=None, padding=None)
        vocab_bert.add_word_lst(vocabs)
        vocab_bert.build_vocab()
        vocab_bert.unknown = '[UNK]'
        vocab_bert.padding = '[PAD]'

        # Index the copied splits with the second vocabulary.
        index_with(vocab_bert, train_data_2, dev_data_2)
        _ = train_data_2[-1], dev_data_2[-1]

        # Step 1: load the model hyper-parameters from the config file.
        args = ConfigSection()
        ConfigLoader().load_config("./data/config", {"esim_model": args})
        args["vocab_size"] = len(vocab)

        # Step 2: build the ESIM model from the loaded parameters.
        model = ESIM(**args.data)

        # Another example: a CNN text classifier, constructed only to check
        # that it can be instantiated.
        _ = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)

        trainer = Trainer(
            train_data=train_data,
            model=model,
            loss=CrossEntropyLoss(pred='pred', target='label'),
            metrics=AccuracyMetric(),
            n_epochs=5,
            batch_size=16,
            print_every=-1,
            validate_every=-1,
            dev_data=dev_data,
            use_cuda=True,
            optimizer=Adam(lr=1e-3, weight_decay=0),
            check_code_level=-1,
            metric_key='acc',
            use_tqdm=False,
        )
        trainer.train()

        tester = Tester(
            data=test_data,
            model=model,
            metrics=AccuracyMetric(),
            batch_size=args["batch_size"],
        )
        tester.test()
    finally:
        # Always return to the directory the test started in, even on failure.
        os.chdir(original_cwd)