| @@ -151,16 +151,19 @@ class DataSet(object): | |||
| assert name in self.field_arrays | |||
| self.field_arrays[name].append(field) | |||
| def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False): | |||
| def add_field(self, name, fields, padder=None, is_input=False, is_target=False, ignore_type=False): | |||
| """Add a new field to the DataSet. | |||
| :param str name: the name of the field. | |||
| :param fields: a list of int, float, or other objects. | |||
| :param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 | |||
| :param padder: PadBase对象,如何对该Field进行padding。如果为None则使用 | |||
| :param bool is_input: whether this field is model input. | |||
| :param bool is_target: whether this field is label or target. | |||
| :param bool ignore_type: If True, do not perform type check. (Default: False) | |||
| """ | |||
| if padder is None: | |||
| padder = AutoPadder(pad_val=0) | |||
| if len(self.field_arrays) != 0: | |||
| if len(self) != len(fields): | |||
| raise RuntimeError(f"The field to append must have the same size as dataset. " | |||
| @@ -231,8 +234,8 @@ class DataSet(object): | |||
| raise KeyError("{} is not a valid field name.".format(name)) | |||
| def set_padder(self, field_name, padder): | |||
| """ | |||
| 为field_name设置padder | |||
| """为field_name设置padder | |||
| :param field_name: str, 设置field的padding方式为padder | |||
| :param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | |||
| :return: | |||
| @@ -242,8 +245,7 @@ class DataSet(object): | |||
| self.field_arrays[field_name].set_padder(padder) | |||
| def set_pad_val(self, field_name, pad_val): | |||
| """ | |||
| 为某个 | |||
| """为某个field设置对应的pad_val. | |||
| :param field_name: str,修改该field的pad_val | |||
| :param pad_val: int,该field的padder会以pad_val作为padding index | |||
| @@ -254,43 +256,60 @@ class DataSet(object): | |||
| self.field_arrays[field_name].set_pad_val(pad_val) | |||
| def get_input_name(self): | |||
| """Get all field names with `is_input` as True. | |||
| """返回所有is_input被设置为True的field名称 | |||
| :return field_names: a list of str | |||
| :return list, 里面的元素为被设置为input的field名称 | |||
| """ | |||
| return [name for name, field in self.field_arrays.items() if field.is_input] | |||
| def get_target_name(self): | |||
| """Get all field names with `is_target` as True. | |||
| """返回所有is_target被设置为True的field名称 | |||
| :return field_names: a list of str | |||
| :return list, 里面的元素为被设置为target的field名称 | |||
| """ | |||
| return [name for name, field in self.field_arrays.items() if field.is_target] | |||
| def apply(self, func, new_field_name=None, **kwargs): | |||
| """Apply a function to every instance of the DataSet. | |||
| :param func: a function that takes an instance as input. | |||
| :param str new_field_name: If not None, results of the function will be stored as a new field. | |||
| :param **kwargs: Accept parameters will be | |||
| (1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input. | |||
| (2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target. | |||
| :return results: if new_field_name is not passed, returned values of the function over all instances. | |||
| def apply_field(self, func, field_name, new_field_name=None, **kwargs): | |||
| """将DataSet中的每个instance中的`field_name`这个field传给func,并获取它的返回值. | |||
| :param func: Callable, input是instance的`field_name`这个field. | |||
| :param field_name: str, 传入func的是哪个field. | |||
| :param new_field_name: (str, None). 如果不是None,将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有 | |||
| 的field相同,则覆盖之前的field. | |||
| :param **kwargs: 合法的参数有以下三个 | |||
| (1) is_input: bool, 如果为True则将`new_field_name`这个field设置为input | |||
| (2) is_target: bool, 如果为True则将`new_field_name`这个field设置为target | |||
| (3) ignore_type: bool, 如果为True则将`new_field_name`这个field的ignore_type设置为true, 忽略其类型 | |||
| :return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 | |||
| """ | |||
| assert len(self)!=0, "Null dataset cannot use .apply()." | |||
| assert len(self)!=0, "Null DataSet cannot use apply()." | |||
| if field_name not in self: | |||
| raise KeyError("DataSet has no field named `{}`.".format(field_name)) | |||
| results = [] | |||
| idx = -1 | |||
| try: | |||
| for idx, ins in enumerate(self._inner_iter()): | |||
| results.append(func(ins)) | |||
| results.append(func(ins[field_name])) | |||
| except Exception as e: | |||
| if idx!=-1: | |||
| print("Exception happens at the `{}`th instance.".format(idx)) | |||
| raise e | |||
| # results = [func(ins) for ins in self._inner_iter()] | |||
| if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
| raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
| if new_field_name is not None: | |||
| self._add_apply_field(results, new_field_name, kwargs) | |||
| return results | |||
| def _add_apply_field(self, results, new_field_name, kwargs): | |||
| """将results作为加入到新的field中,field名称为new_field_name | |||
| :param results: List[], 一般是apply*()之后的结果 | |||
| :param new_field_name: str, 新加入的field的名称 | |||
| :param kwargs: dict, 用户apply*()时传入的自定义参数 | |||
| :return: | |||
| """ | |||
| extra_param = {} | |||
| if 'is_input' in kwargs: | |||
| extra_param['is_input'] = kwargs['is_input'] | |||
| @@ -298,56 +317,84 @@ class DataSet(object): | |||
| extra_param['is_target'] = kwargs['is_target'] | |||
| if 'ignore_type' in kwargs: | |||
| extra_param['ignore_type'] = kwargs['ignore_type'] | |||
| if new_field_name is not None: | |||
| if new_field_name in self.field_arrays: | |||
| # overwrite the field, keep same attributes | |||
| old_field = self.field_arrays[new_field_name] | |||
| if 'is_input' not in extra_param: | |||
| extra_param['is_input'] = old_field.is_input | |||
| if 'is_target' not in extra_param: | |||
| extra_param['is_target'] = old_field.is_target | |||
| if 'ignore_type' not in extra_param: | |||
| extra_param['ignore_type'] = old_field.ignore_type | |||
| self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||
| is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||
| else: | |||
| self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
| is_target=extra_param.get("is_target", None), | |||
| ignore_type=extra_param.get("ignore_type", False)) | |||
| if new_field_name in self.field_arrays: | |||
| # overwrite the field, keep same attributes | |||
| old_field = self.field_arrays[new_field_name] | |||
| if 'is_input' not in extra_param: | |||
| extra_param['is_input'] = old_field.is_input | |||
| if 'is_target' not in extra_param: | |||
| extra_param['is_target'] = old_field.is_target | |||
| if 'ignore_type' not in extra_param: | |||
| extra_param['ignore_type'] = old_field.ignore_type | |||
| self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||
| is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||
| else: | |||
| return results | |||
| self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
| is_target=extra_param.get("is_target", None), | |||
| ignore_type=extra_param.get("ignore_type", False)) | |||
| def apply(self, func, new_field_name=None, **kwargs): | |||
| """将DataSet中每个instance传入到func中,并获取它的返回值. | |||
| :param func: Callable, 参数是DataSet中的instance | |||
| :param new_field_name: (None, str). (1) None, 不创建新的field; (2) str,将func的返回值放入这个名为 | |||
| `new_field_name`的新field中,如果名称与已有的field相同,则覆盖之前的field; | |||
| :param kwargs: 合法的参数有以下三个 | |||
| (1) is_input: bool, 如果为True则将`new_field_name`的field设置为input | |||
| (2) is_target: bool, 如果为True则将`new_field_name`的field设置为target | |||
| (3) ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 | |||
| :return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 | |||
| """ | |||
| assert len(self)!=0, "Null DataSet cannot use apply()." | |||
| idx = -1 | |||
| try: | |||
| results = [] | |||
| for idx, ins in enumerate(self._inner_iter()): | |||
| results.append(func(ins)) | |||
| except Exception as e: | |||
| if idx!=-1: | |||
| print("Exception happens at the `{}`th instance.".format(idx)) | |||
| raise e | |||
| # results = [func(ins) for ins in self._inner_iter()] | |||
| if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
| raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
| if new_field_name is not None: | |||
| self._add_apply_field(results, new_field_name, kwargs) | |||
| return results | |||
| def drop(self, func, inplace=True): | |||
| """Drop instances if a condition holds. | |||
| """func接受一个instance,返回bool值,返回值为True时,该instance会被删除。 | |||
| :param func: a function that takes an Instance object as input, and returns bool. | |||
| The instance will be dropped if the function returns True. | |||
| :param inplace: bool, whether to drop inpalce. Otherwise a new dataset will be returned. | |||
| :param func: Callable, 接受一个instance作为参数,返回bool值。为True时删除该instance | |||
| :param inplace: bool, 是否在当前DataSet中直接删除instance。如果为False,返回值为一个删除了相应instance的新的DataSet | |||
| :return: DataSet. | |||
| """ | |||
| if inplace: | |||
| results = [ins for ins in self._inner_iter() if not func(ins)] | |||
| for name, old_field in self.field_arrays.items(): | |||
| self.field_arrays[name].content = [ins[name] for ins in results] | |||
| return self | |||
| else: | |||
| results = [ins for ins in self if not func(ins)] | |||
| data = DataSet(results) | |||
| for field_name, field in self.field_arrays.items(): | |||
| data.field_arrays[field_name].to(field) | |||
| return data | |||
| def split(self, dev_ratio): | |||
| """Split the dataset into training and development(validation) set. | |||
| def split(self, ratio): | |||
| """将DataSet按照ratio的比例拆分,返回两个DataSet | |||
| :param float dev_ratio: the ratio of test set in all data. | |||
| :return (train_set, dev_set): | |||
| train_set: the training set | |||
| dev_set: the development set | |||
| :param ratio: float, 0<ratio<1, 返回的第一个DataSet拥有ratio这么多数据,第二个DataSet拥有(1-ratio)这么多数据 | |||
| :return (DataSet, DataSet) | |||
| """ | |||
| assert isinstance(dev_ratio, float) | |||
| assert 0 < dev_ratio < 1 | |||
| assert isinstance(ratio, float) | |||
| assert 0 < ratio < 1 | |||
| all_indices = [_ for _ in range(len(self))] | |||
| np.random.shuffle(all_indices) | |||
| split = int(dev_ratio * len(self)) | |||
| split = int(ratio * len(self)) | |||
| dev_indices = all_indices[:split] | |||
| train_indices = all_indices[split:] | |||
| dev_set = DataSet() | |||
| @@ -398,26 +445,25 @@ class DataSet(object): | |||
| _dict[header].append(content) | |||
| return cls(_dict) | |||
| # def read_pos(self): | |||
| # return DataLoaderRegister.get_reader('read_pos') | |||
| def save(self, path): | |||
| """Save the DataSet object as pickle. | |||
| """保存DataSet. | |||
| :param str path: the path to the pickle | |||
| :param path: str, 将DataSet存在哪个路径 | |||
| """ | |||
| with open(path, 'wb') as f: | |||
| pickle.dump(self, f) | |||
| @staticmethod | |||
| def load(path): | |||
| """Load a DataSet object from pickle. | |||
| """从保存的DataSet pickle路径中读取DataSet | |||
| :param str path: the path to the pickle | |||
| :return data_set: | |||
| :param path: str, 读取路径 | |||
| :return DataSet: | |||
| """ | |||
| with open(path, 'rb') as f: | |||
| return pickle.load(f) | |||
| d = pickle.load(f) | |||
| assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | |||
| return d | |||
| def construct_dataset(sentences): | |||
| @@ -84,7 +84,7 @@ class AutoPadder(PadderBase): | |||
| for i, content in enumerate(contents): | |||
| array[i][:len(content)] = content | |||
| elif field_ele_dtype is None: | |||
| array = contents # 当ignore_type=True时,直接返回contents | |||
| array = np.array(contents) # 当ignore_type=True时,直接返回contents | |||
| else: # should only be str | |||
| array = np.array([content for content in contents]) | |||
| return array | |||
| @@ -290,9 +290,10 @@ class FieldArray(object): | |||
| return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||
| def append(self, val): | |||
| """Add a new item to the tail of FieldArray. | |||
| """将val增加到FieldArray中,若该field的ignore_type为True则直接append到这个field中;若ignore_type为False,且当前field为 | |||
| input或者target,则会检查传入的content是否与之前的内容在dimension, 元素的类型上是匹配的。 | |||
| :param val: int, float, str, or a list of one. | |||
| :param val: Any. | |||
| """ | |||
| if self.ignore_type is False: | |||
| if isinstance(val, list): | |||
| @@ -367,13 +368,14 @@ class FieldArray(object): | |||
| self.padder = deepcopy(padder) | |||
| def set_pad_val(self, pad_val): | |||
| """ | |||
| 修改padder的pad_val. | |||
| :param pad_val: int。 | |||
| """修改padder的pad_val. | |||
| :param pad_val: int。将该field的pad值设置为该值 | |||
| :return: | |||
| """ | |||
| if self.padder is not None: | |||
| self.padder.set_pad_val(pad_val) | |||
| return self | |||
| def __len__(self): | |||
| @@ -385,8 +387,7 @@ class FieldArray(object): | |||
| def to(self, other): | |||
| """ | |||
| 将other的属性复制给本fieldarray(必须通过fieldarray类型). 包含 is_input, is_target, padder, dtype, pytype, content_dim | |||
| ignore_type | |||
| 将other的属性复制给本FieldArray(other必须为FieldArray类型). 包含 is_input, is_target, padder, ignore_type | |||
| :param other: FieldArray | |||
| :return: | |||
| @@ -396,11 +397,10 @@ class FieldArray(object): | |||
| self.is_input = other.is_input | |||
| self.is_target = other.is_target | |||
| self.padder = other.padder | |||
| self.dtype = other.dtype | |||
| self.pytype = other.pytype | |||
| self.content_dim = other.content_dim | |||
| self.ignore_type = other.ignore_type | |||
| return self | |||
| def is_iterable(content): | |||
| try: | |||
| _ = (e for e in content) | |||
| @@ -24,7 +24,7 @@ def _prepare_cache_filepath(filepath): | |||
| if not os.path.exists(cache_dir): | |||
| os.makedirs(cache_dir) | |||
| # TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。 | |||
| def cache_results(cache_filepath, refresh=False, verbose=1): | |||
| def wrapper_(func): | |||
| signature = inspect.signature(func) | |||
| @@ -79,7 +79,7 @@ class SeqLabeling(BaseModel): | |||
| :return prediction: list of [decode path(list)] | |||
| """ | |||
| max_len = x.shape[1] | |||
| tag_seq = self.Crf.viterbi_decode(x, self.mask) | |||
| tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) | |||
| # pad prediction to equal length | |||
| if pad is True: | |||
| for pred in tag_seq: | |||
| @@ -2,12 +2,7 @@ import torch | |||
| from torch import nn | |||
| from fastNLP.modules.utils import initial_parameter | |||
| def log_sum_exp(x, dim=-1): | |||
| max_value, _ = x.max(dim=dim, keepdim=True) | |||
| res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||
| return res.squeeze(dim) | |||
| from fastNLP.modules.decoder.utils import log_sum_exp | |||
| def seq_len_to_byte_mask(seq_lens): | |||
| @@ -20,22 +15,27 @@ def seq_len_to_byte_mask(seq_lens): | |||
| return mask | |||
| def allowed_transitions(id2label, encoding_type='bio'): | |||
| def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||
| """ | |||
| 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | |||
| :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
| "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | |||
| :param id2label: Dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
| "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 | |||
| :param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | |||
| :return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||
| 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||
| start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
| :param include_start_end: bool, 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||
| 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | |||
| start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
| 为False, 返回的结果中不含与开始结尾相关的内容 | |||
| :return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | |||
| """ | |||
| num_tags = len(id2label) | |||
| start_idx = num_tags | |||
| end_idx = num_tags + 1 | |||
| encoding_type = encoding_type.lower() | |||
| allowed_trans = [] | |||
| id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')] | |||
| id_label_lst = list(id2label.items()) | |||
| if include_start_end: | |||
| id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | |||
| def split_tag_label(from_label): | |||
| from_label = from_label.lower() | |||
| if from_label in ['start', 'end']: | |||
| @@ -54,12 +54,12 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||
| if to_label in ['<pad>', '<unk>']: | |||
| continue | |||
| to_tag, to_label = split_tag_label(to_label) | |||
| if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
| if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
| allowed_trans.append((from_id, to_id)) | |||
| return allowed_trans | |||
| def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
| def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
| """ | |||
| :param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 | |||
| @@ -140,20 +140,22 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) | |||
| raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||
| else: | |||
| raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||
| raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) | |||
| class ConditionalRandomField(nn.Module): | |||
| """ | |||
| :param int num_tags: 标签的数量。 | |||
| :param bool include_start_end_trans: 是否包含起始tag | |||
| :param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 允许的跃迁,可以通过allowed_transitions()得到。 | |||
| 如果为None,则所有跃迁均为合法 | |||
| :param str initial_method: | |||
| """ | |||
| def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): | |||
| def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, | |||
| initial_method=None): | |||
| """条件随机场。 | |||
| 提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | |||
| :param num_tags: int, 标签的数量 | |||
| :param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。 | |||
| :param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int), | |||
| to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | |||
| allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | |||
| :param initial_method: str, 初始化方法。见initial_parameter | |||
| """ | |||
| super(ConditionalRandomField, self).__init__() | |||
| self.include_start_end_trans = include_start_end_trans | |||
| @@ -168,18 +170,12 @@ class ConditionalRandomField(nn.Module): | |||
| if allowed_transitions is None: | |||
| constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||
| else: | |||
| constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 | |||
| constrain = torch.new_full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) | |||
| for from_tag_id, to_tag_id in allowed_transitions: | |||
| constrain[from_tag_id, to_tag_id] = 0 | |||
| self._constrain = nn.Parameter(constrain, requires_grad=False) | |||
| # self.reset_parameter() | |||
| initial_parameter(self, initial_method) | |||
| def reset_parameter(self): | |||
| nn.init.xavier_normal_(self.trans_m) | |||
| if self.include_start_end_trans: | |||
| nn.init.normal_(self.start_scores) | |||
| nn.init.normal_(self.end_scores) | |||
| def _normalizer_likelihood(self, logits, mask): | |||
| """Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
| @@ -239,10 +235,11 @@ class ConditionalRandomField(nn.Module): | |||
| def forward(self, feats, tags, mask): | |||
| """ | |||
| Calculate the neg log likelihood | |||
| :param feats:FloatTensor, batch_size x max_len x num_tags | |||
| :param tags:LongTensor, batch_size x max_len | |||
| :param mask:ByteTensor batch_size x max_len | |||
| 用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | |||
| :param feats:FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
| :param tags:LongTensor, batch_size x max_len,标签矩阵。 | |||
| :param mask:ByteTensor batch_size x max_len,为0的位置认为是padding。 | |||
| :return:FloatTensor, batch_size | |||
| """ | |||
| feats = feats.transpose(0, 1) | |||
| @@ -253,28 +250,27 @@ class ConditionalRandomField(nn.Module): | |||
| return all_path_score - gold_path_score | |||
| def viterbi_decode(self, data, mask, get_score=False, unpad=False): | |||
| """Given a feats matrix, return best decode path and best score. | |||
| def viterbi_decode(self, feats, mask, unpad=False): | |||
| """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
| :param data:FloatTensor, batch_size x max_len x num_tags | |||
| :param mask:ByteTensor batch_size x max_len | |||
| :param get_score: bool, whether to output the decode score. | |||
| :param unpad: bool, 是否将结果unpad, | |||
| 如果False, 返回的是batch_size x max_len的tensor, | |||
| 如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个 | |||
| List[int]的长度是这个sample的有效长度 | |||
| :return: 如果get_score为False,返回结果根据unpadding变动 | |||
| 如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float] | |||
| 为每个seqence的解码分数。 | |||
| :param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
| :param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
| :param unpad: bool, 是否将结果删去padding, | |||
| False, 返回的是batch_size x max_len的tensor, | |||
| True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int] | |||
| 的长度是这个sample的有效长度。 | |||
| :return: 返回 (paths, scores)。 | |||
| paths: 是解码后的路径, 其值参照unpad参数. | |||
| scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
| """ | |||
| batch_size, seq_len, n_tags = data.size() | |||
| data = data.transpose(0, 1).data # L, B, H | |||
| batch_size, seq_len, n_tags = feats.size() | |||
| feats = feats.transpose(0, 1).data # L, B, H | |||
| mask = mask.transpose(0, 1).data.byte() # L, B | |||
| # dp | |||
| vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
| vscore = data[0] | |||
| vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
| vscore = feats[0] | |||
| transitions = self._constrain.data.clone() | |||
| transitions[:n_tags, :n_tags] += self.trans_m.data | |||
| if self.include_start_end_trans: | |||
| @@ -285,23 +281,24 @@ class ConditionalRandomField(nn.Module): | |||
| trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
| for i in range(1, seq_len): | |||
| prev_score = vscore.view(batch_size, n_tags, 1) | |||
| cur_score = data[i].view(batch_size, 1, n_tags) | |||
| cur_score = feats[i].view(batch_size, 1, n_tags) | |||
| score = prev_score + trans_score + cur_score | |||
| best_score, best_dst = score.max(1) | |||
| vpath[i] = best_dst | |||
| vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||
| vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||
| vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
| if self.include_start_end_trans: | |||
| vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
| # backtrace | |||
| batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | |||
| seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) | |||
| batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||
| seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||
| lens = (mask.long().sum(0) - 1) | |||
| # idxes [L, B], batched idx from seq_len-1 to 0 | |||
| idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | |||
| ans = data.new_empty((seq_len, batch_size), dtype=torch.long) | |||
| ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||
| ans_score, last_tags = vscore.max(1) | |||
| ans[idxes[0], batch_idx] = last_tags | |||
| for i in range(seq_len - 1): | |||
| @@ -0,0 +1,70 @@ | |||
| import torch | |||
| def log_sum_exp(x, dim=-1): | |||
| max_value, _ = x.max(dim=dim, keepdim=True) | |||
| res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||
| return res.squeeze(dim) | |||
| def viterbi_decode(feats, transitions, mask=None, unpad=False): | |||
| """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
| :param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
| :param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 | |||
| :param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
| :param unpad: bool, 是否将结果删去padding, | |||
| False, 返回的是batch_size x max_len的tensor, | |||
| True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是 | |||
| 这个sample的有效长度。 | |||
| :return: 返回 (paths, scores)。 | |||
| paths: 是解码后的路径, 其值参照unpad参数. | |||
| scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
| """ | |||
| batch_size, seq_len, n_tags = feats.size() | |||
| assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \ | |||
| "compatible." | |||
| feats = feats.transpose(0, 1).data # L, B, H | |||
| if mask is not None: | |||
| mask = mask.transpose(0, 1).data.byte() # L, B | |||
| else: | |||
| mask = feats.new_ones((seq_len, batch_size), dtype=torch.uint8) | |||
| # dp | |||
| vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
| vscore = feats[0] | |||
| vscore += transitions[n_tags, :n_tags] | |||
| trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
| for i in range(1, seq_len): | |||
| prev_score = vscore.view(batch_size, n_tags, 1) | |||
| cur_score = feats[i].view(batch_size, 1, n_tags) | |||
| score = prev_score + trans_score + cur_score | |||
| best_score, best_dst = score.max(1) | |||
| vpath[i] = best_dst | |||
| vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||
| vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||
| # backtrace | |||
| batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||
| seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||
| lens = (mask.long().sum(0) - 1) | |||
| # idxes [L, B], batched idx from seq_len-1 to 0 | |||
| idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len | |||
| ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||
| ans_score, last_tags = vscore.max(1) | |||
| ans[idxes[0], batch_idx] = last_tags | |||
| for i in range(seq_len - 1): | |||
| last_tags = vpath[idxes[i], batch_idx, last_tags] | |||
| ans[idxes[i + 1], batch_idx] = last_tags | |||
| ans = ans.transpose(0, 1) | |||
| if unpad: | |||
| paths = [] | |||
| for idx, seq_len in enumerate(lens): | |||
| paths.append(ans[idx, :seq_len + 1].tolist()) | |||
| else: | |||
| paths = ans | |||
| return paths, ans_score | |||
| @@ -183,7 +183,7 @@ class CWSBiLSTMCRF(BaseModel): | |||
| masks = seq_lens_to_mask(seq_lens) | |||
| feats = self.encoder_model(chars, bigrams, seq_lens) | |||
| feats = self.decoder_model(feats) | |||
| probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||
| paths, _ = self.crf.viterbi_decode(feats, masks) | |||
| return {'pred': probs, 'seq_lens':seq_lens} | |||
| return {'pred': paths, 'seq_lens':seq_lens} | |||
| @@ -72,9 +72,9 @@ class TransformerCWS(nn.Module): | |||
| feats = self.transformer(x, masks) | |||
| feats = self.fc2(feats) | |||
| probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||
| paths, _ = self.crf.viterbi_decode(feats, masks) | |||
| return {'pred': probs, 'seq_lens':seq_lens} | |||
| return {'pred': paths, 'seq_lens':seq_lens} | |||
| class NoamOpt(torch.optim.Optimizer): | |||