2. metric增加SpanFMetric,可以用于计算sequence labelling的performance 3. 分词复现任务根据新版接口做了部分调整。tags/v0.3.0^2
| @@ -9,8 +9,8 @@ from fastNLP.core.dataset import DataSet | |||
| from fastNLP.api.model_zoo import load_url | |||
| from fastNLP.api.processor import ModelProcessor | |||
| from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader | |||
| from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader | |||
| from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | |||
| from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader | |||
| from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.sampler import SequentialSampler | |||
| @@ -95,7 +95,7 @@ class POS(API): | |||
| pipeline.append(tag_proc) | |||
| pp = Pipeline(pipeline) | |||
| reader = ConlluPOSReader() | |||
| reader = ConllPOSReader() | |||
| te_dataset = reader.load(filepath) | |||
| evaluator = SeqLabelEvaluator2('word_seq_origin_len') | |||
| @@ -168,7 +168,7 @@ class CWS(API): | |||
| pipeline.insert(1, tag_proc) | |||
| pp = Pipeline(pipeline) | |||
| reader = ConlluCWSReader() | |||
| reader = ConllCWSReader() | |||
| # te_filename = '/home/hyan/ctb3/test.conllx' | |||
| te_dataset = reader.load(filepath) | |||
| @@ -11,6 +11,11 @@ from fastNLP.core.vocabulary import Vocabulary | |||
| class Processor(object): | |||
| def __init__(self, field_name, new_added_field_name): | |||
| """ | |||
| :param field_name: 处理哪个field | |||
| :param new_added_field_name: 如果为None,则认为是field_name,即覆盖原有的field | |||
| """ | |||
| self.field_name = field_name | |||
| if new_added_field_name is None: | |||
| self.new_added_field_name = field_name | |||
| @@ -92,6 +97,11 @@ class FullSpaceToHalfSpaceProcessor(Processor): | |||
| class PreAppendProcessor(Processor): | |||
| """ | |||
| 向某个field的起始增加data(应该为str类型)。该field需要为list类型。即新增的field为 | |||
| [data] + instance[field_name] | |||
| """ | |||
| def __init__(self, data, field_name, new_added_field_name=None): | |||
| super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | |||
| self.data = data | |||
| @@ -102,6 +112,10 @@ class PreAppendProcessor(Processor): | |||
| class SliceProcessor(Processor): | |||
| """ | |||
| 从某个field中只取部分内容。等价于instance[field_name][start:end:step] | |||
| """ | |||
| def __init__(self, start, end, step, field_name, new_added_field_name=None): | |||
| super(SliceProcessor, self).__init__(field_name, new_added_field_name) | |||
| for o in (start, end, step): | |||
| @@ -114,7 +128,17 @@ class SliceProcessor(Processor): | |||
| class Num2TagProcessor(Processor): | |||
| """ | |||
| 将一句话中的数字转换为某个tag。 | |||
| """ | |||
| def __init__(self, tag, field_name, new_added_field_name=None): | |||
| """ | |||
| :param tag: str, 将数字转换为该tag | |||
| :param field_name: | |||
| :param new_added_field_name: | |||
| """ | |||
| super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) | |||
| self.tag = tag | |||
| self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)' | |||
| @@ -135,6 +159,10 @@ class Num2TagProcessor(Processor): | |||
| class IndexerProcessor(Processor): | |||
| """ | |||
| 给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 | |||
| ['我', '是', xxx] | |||
| """ | |||
| def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | |||
| assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | |||
| @@ -163,19 +191,19 @@ class IndexerProcessor(Processor): | |||
| class VocabProcessor(Processor): | |||
| """Build vocabulary with a field in the data set. | |||
| """ | |||
| 传入若干个DataSet以建立vocabulary。 | |||
| """ | |||
| def __init__(self, field_name): | |||
| def __init__(self, field_name, min_freq=1, max_size=None): | |||
| super(VocabProcessor, self).__init__(field_name, None) | |||
| self.vocab = Vocabulary() | |||
| self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) | |||
| def process(self, *datasets): | |||
| for dataset in datasets: | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| for ins in dataset: | |||
| self.vocab.update(ins[self.field_name]) | |||
| dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||
| def get_vocab(self): | |||
| self.vocab.build_vocab() | |||
| @@ -183,6 +211,10 @@ class VocabProcessor(Processor): | |||
| class SeqLenProcessor(Processor): | |||
| """ | |||
| 根据某个field新增一个sequence length的field。取该field的第一维 | |||
| """ | |||
| def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | |||
| super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||
| self.is_input = is_input | |||
| @@ -195,10 +227,15 @@ class SeqLenProcessor(Processor): | |||
| return dataset | |||
| from fastNLP.core.utils import _build_args | |||
| class ModelProcessor(Processor): | |||
| def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | |||
| """ | |||
| 迭代模型并将结果的padding drop掉 | |||
| 传入一个model,在process()时传入一个dataset,该processor会通过Batch将DataSet的内容输出给model.predict或者model.forward. | |||
| model输出的内容会被增加到dataset中,field_name由model输出决定。如果生成的内容维度不是(Batch_size, )与 | |||
| (Batch_size, 1),则使用seqence length这个field进行unpad | |||
| TODO 这个类需要删除对seq_lens的依赖。 | |||
| :param seq_len_field_name: | |||
| :param batch_size: | |||
| @@ -211,13 +248,18 @@ class ModelProcessor(Processor): | |||
| def process(self, dataset): | |||
| self.model.eval() | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) | |||
| data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler()) | |||
| batch_output = defaultdict(list) | |||
| if hasattr(self.model, "predict"): | |||
| predict_func = self.model.predict | |||
| else: | |||
| predict_func = self.model.forward | |||
| with torch.no_grad(): | |||
| for batch_x, _ in data_iterator: | |||
| prediction = self.model.predict(**batch_x) | |||
| seq_lens = batch_x[self.seq_len_field_name].cpu().numpy().tolist() | |||
| refined_batch_x = _build_args(predict_func, **batch_x) | |||
| prediction = predict_func(**refined_batch_x) | |||
| seq_lens = batch_x[self.seq_len_field_name].tolist() | |||
| for key, value in prediction.items(): | |||
| tmp_batch = [] | |||
| @@ -246,6 +288,10 @@ class ModelProcessor(Processor): | |||
| class Index2WordProcessor(Processor): | |||
| """ | |||
| 将DataSet中某个为index的field根据vocab转换为str | |||
| """ | |||
| def __init__(self, vocab, field_name, new_added_field_name): | |||
| super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | |||
| self.vocab = vocab | |||
| @@ -266,5 +312,5 @@ class SetIsTargetProcessor(Processor): | |||
| def process(self, dataset): | |||
| set_dict = {name: self.default for name in dataset.get_all_fields().keys()} | |||
| set_dict.update(self.field_dict) | |||
| dataset.set_target(**set_dict) | |||
| dataset.set_target(*set_dict.keys()) | |||
| return dataset | |||
| @@ -254,7 +254,7 @@ class DataSet(object): | |||
| :return results: if new_field_name is not passed, returned values of the function over all instances. | |||
| """ | |||
| results = [func(ins) for ins in self._inner_iter()] | |||
| if len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
| if len(list(filter(lambda x: x is not None, results))) == 0 and not (new_field_name is None): # all None | |||
| raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
| extra_param = {} | |||
| @@ -10,6 +10,7 @@ from fastNLP.core.utils import _build_args | |||
| from fastNLP.core.utils import _check_arg_dict_list | |||
| from fastNLP.core.utils import get_func_signature | |||
| from fastNLP.core.utils import seq_lens_to_masks | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| class MetricBase(object): | |||
| @@ -62,11 +63,6 @@ class MetricBase(object): | |||
| f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | |||
| f"initialization parameters, or change its signature.") | |||
| # evaluate should not have varargs. | |||
| # if func_spect.varargs: | |||
| # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " | |||
| # f"positional argument.).") | |||
| def get_metric(self, reset=True): | |||
| raise NotImplemented | |||
| @@ -91,10 +87,9 @@ class MetricBase(object): | |||
| This method will call self.evaluate method. | |||
| Before calling self.evaluate, it will first check the validity of output_dict, target_dict | |||
| (1) whether self.evaluate has varargs, which is not supported. | |||
| (2) whether params needed by self.evaluate is not included in output_dict,target_dict. | |||
| (3) whether params needed by self.evaluate duplicate in pred_dict, target_dict | |||
| (4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) | |||
| (1) whether params needed by self.evaluate is not included in output_dict,target_dict. | |||
| (2) whether params needed by self.evaluate duplicate in pred_dict, target_dict | |||
| (3) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) | |||
| Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | |||
| target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | |||
| will be conducted.) | |||
| @@ -275,6 +270,369 @@ class AccuracyMetric(MetricBase): | |||
| self.total = 0 | |||
| return evaluate_result | |||
| def bmes_tag_to_spans(tags, ignore_labels=None): | |||
| """ | |||
| :param tags: List[str], | |||
| :param ignore_labels: List[str], 在该list中的label将被忽略 | |||
| :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
| """ | |||
| ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
| spans = [] | |||
| prev_bmes_tag = None | |||
| for idx, tag in enumerate(tags): | |||
| tag = tag.lower() | |||
| bmes_tag, label = tag[:1], tag[2:] | |||
| if bmes_tag in ('b', 's'): | |||
| spans.append((label, [idx, idx])) | |||
| elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: | |||
| spans[-1][1][1] = idx | |||
| else: | |||
| spans.append((label, [idx, idx])) | |||
| prev_bmes_tag = bmes_tag | |||
| return [(span[0], (span[1][0], span[1][1])) | |||
| for span in spans | |||
| if span[0] not in ignore_labels | |||
| ] | |||
| def bio_tag_to_spans(tags, ignore_labels=None): | |||
| """ | |||
| :param tags: List[str], | |||
| :param ignore_labels: List[str], 在该list中的label将被忽略 | |||
| :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
| """ | |||
| ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
| spans = [] | |||
| prev_bio_tag = None | |||
| for idx, tag in enumerate(tags): | |||
| tag = tag.lower() | |||
| bio_tag, label = tag[:1], tag[2:] | |||
| if bio_tag == 'b': | |||
| spans.append((label, [idx, idx])) | |||
| elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]: | |||
| spans[-1][1][1] = idx | |||
| elif bio_tag == 'o': # o tag does not count | |||
| pass | |||
| else: | |||
| spans.append((label, [idx, idx])) | |||
| prev_bio_tag = bio_tag | |||
| return [(span[0], (span[1][0], span[1][1])) | |||
| for span in spans | |||
| if span[0] not in ignore_labels | |||
| ] | |||
| class SpanFPreRecMetric(MetricBase): | |||
| """ | |||
| 在序列标注问题中,以span的方式计算F, pre, rec. | |||
| 最后得到的metric结果为 | |||
| { | |||
| 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||
| 'pre': xxx, | |||
| 'rec':xxx | |||
| } | |||
| 若only_gross=False, 即还会返回各个label的metric统计值 | |||
| { | |||
| 'f': xxx, | |||
| 'pre': xxx, | |||
| 'rec':xxx, | |||
| 'f-label': xxx, | |||
| 'pre-label': xxx, | |||
| 'rec-label':xxx, | |||
| ... | |||
| } | |||
| """ | |||
| def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, | |||
| only_gross=True, f_type='micro', beta=1): | |||
| """ | |||
| :param tag_vocab: Vocabulary, 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
| 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||
| :param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 | |||
| :param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 | |||
| :param seq_lens: str, 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 | |||
| :param encoding_type: str, 目前支持bio, bmes | |||
| :param ignore_labels, List[str]. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 | |||
| 个label | |||
| :param only_gross, bool. 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 | |||
| label的f1, pre, rec | |||
| :param f_type, str. 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': | |||
| 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
| :param beta, float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | |||
| 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
| """ | |||
| encoding_type = encoding_type.lower() | |||
| if encoding_type not in ('bio', 'bmes'): | |||
| raise ValueError("Only support 'bio' or 'bmes' type.") | |||
| if not isinstance(tag_vocab, Vocabulary): | |||
| raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
| if f_type not in ('micro', 'macro'): | |||
| raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
| self.encoding_type = encoding_type | |||
| if self.encoding_type == 'bmes': | |||
| self.tag_to_span_func = bmes_tag_to_spans | |||
| elif self.encoding_type == 'bio': | |||
| self.tag_to_span_func = bio_tag_to_spans | |||
| self.ignore_labels = ignore_labels | |||
| self.f_type = f_type | |||
| self.beta = beta | |||
| self.beta_square = self.beta**2 | |||
| self.only_gross = only_gross | |||
| super().__init__() | |||
| self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||
| self.tag_vocab = tag_vocab | |||
| self._true_positives = defaultdict(int) | |||
| self._false_positives = defaultdict(int) | |||
| self._false_negatives = defaultdict(int) | |||
| def evaluate(self, pred, target, seq_lens): | |||
| """ | |||
| A lot of design idea comes from allennlp's measure | |||
| :param pred: | |||
| :param target: | |||
| :param seq_lens: | |||
| :return: | |||
| """ | |||
| if not isinstance(pred, torch.Tensor): | |||
| raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(pred)}.") | |||
| if not isinstance(target, torch.Tensor): | |||
| raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(target)}.") | |||
| if not isinstance(seq_lens, torch.Tensor): | |||
| raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(seq_lens)}.") | |||
| num_classes = pred.size(-1) | |||
| if (target >= num_classes).any(): | |||
| raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||
| "id >= {}, the number of classes.".format(num_classes)) | |||
| if pred.size() == target.size() and len(target.size()) == 2: | |||
| pass | |||
| elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||
| pred = pred.argmax(dim=-1) | |||
| else: | |||
| raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||
| f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
| f"{pred.size()[:-1]}, got {target.size()}.") | |||
| batch_size = pred.size(0) | |||
| for i in range(batch_size): | |||
| pred_tags = pred[i, :seq_lens[i]].tolist() | |||
| gold_tags = target[i, :seq_lens[i]].tolist() | |||
| pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | |||
| gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | |||
| pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | |||
| gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | |||
| for span in pred_spans: | |||
| if span in gold_spans: | |||
| self._true_positives[span[0]] += 1 | |||
| gold_spans.remove(span) | |||
| else: | |||
| self._false_positives[span[0]] += 1 | |||
| for span in gold_spans: | |||
| self._false_negatives[span[0]] += 1 | |||
| def get_metric(self, reset=True): | |||
| evaluate_result = {} | |||
| if not self.only_gross or self.f_type=='macro': | |||
| tags = set(self._false_negatives.keys()) | |||
| tags.update(set(self._false_positives.keys())) | |||
| tags.update(set(self._true_positives.keys())) | |||
| f_sum = 0 | |||
| pre_sum = 0 | |||
| rec_sum = 0 | |||
| for tag in tags: | |||
| tp = self._true_positives[tag] | |||
| fn = self._false_negatives[tag] | |||
| fp = self._false_positives[tag] | |||
| f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) | |||
| f_sum += f | |||
| pre_sum += pre | |||
| rec_sum + rec | |||
| if not self.only_gross and tag!='': # tag!=''防止无tag的情况 | |||
| f_key = 'f-{}'.format(tag) | |||
| pre_key = 'pre-{}'.format(tag) | |||
| rec_key = 'rec-{}'.format(tag) | |||
| evaluate_result[f_key] = f | |||
| evaluate_result[pre_key] = pre | |||
| evaluate_result[rec_key] = rec | |||
| if self.f_type == 'macro': | |||
| evaluate_result['f'] = f_sum/len(tags) | |||
| evaluate_result['pre'] = pre_sum/len(tags) | |||
| evaluate_result['rec'] = rec_sum/len(tags) | |||
| if self.f_type == 'micro': | |||
| f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | |||
| sum(self._false_negatives.values()), | |||
| sum(self._false_positives.values())) | |||
| evaluate_result['f'] = f | |||
| evaluate_result['pre'] = pre | |||
| evaluate_result['rec'] = rec | |||
| if reset: | |||
| self._true_positives = defaultdict(int) | |||
| self._false_positives = defaultdict(int) | |||
| self._false_negatives = defaultdict(int) | |||
| return evaluate_result | |||
| def _compute_f_pre_rec(self, tp, fn, fp): | |||
| """ | |||
| :param tp: int, true positive | |||
| :param fn: int, false negative | |||
| :param fp: int, false positive | |||
| :return: (f, pre, rec) | |||
| """ | |||
| pre = tp / (fp + tp + 1e-13) | |||
| rec = tp / (fn + tp + 1e-13) | |||
| f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) | |||
| return f, pre, rec | |||
| class BMESF1PreRecMetric(MetricBase): | |||
| """ | |||
| 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||
| next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||
| | | next_B | next_M | next_E | next_S | end | | |||
| |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||
| | start | 合法 | next_M=B | next_E=S | 合法 | - | | |||
| | cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | |||
| | cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | |||
| | cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
| | cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
| 举例: | |||
| prediction为BSEMS,会被认为是SSSSS. | |||
| 本Metric不检验target的合法性,请务必保证target的合法性。 | |||
| pred的形状应该为(batch_size, max_len) 或 (batch_size, max_len, 4)。 | |||
| target形状为 (batch_size, max_len) | |||
| seq_lens形状为 (batch_size, ) | |||
| """ | |||
| def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None): | |||
| """ | |||
| 需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。 | |||
| :param b_idx: int, Begin标签所对应的tag idx. | |||
| :param m_idx: int, Middle标签所对应的tag idx. | |||
| :param e_idx: int, End标签所对应的tag idx. | |||
| :param s_idx: int, Single标签所对应的tag idx | |||
| :param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 | |||
| :param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 | |||
| :param seq_lens: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_lens'取数据。 | |||
| """ | |||
| super().__init__() | |||
| self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||
| self.yt_wordnum = 0 | |||
| self.yp_wordnum = 0 | |||
| self.corr_num = 0 | |||
| self.b_idx = b_idx | |||
| self.m_idx = m_idx | |||
| self.e_idx = e_idx | |||
| self.s_idx = s_idx | |||
| # 还原init处介绍的矩阵 | |||
| self._valida_matrix = { | |||
| -1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx | |||
| self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], | |||
| self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], | |||
| self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||
| self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||
| } | |||
| def _validate_tags(self, tags): | |||
| """ | |||
| 给定一个tag的Tensor,返回合法tag | |||
| :param tags: Tensor, shape: (seq_len, ) | |||
| :return: 返回修改为合法tag的list | |||
| """ | |||
| assert len(tags)!=0 | |||
| assert isinstance(tags, torch.Tensor) and len(tags.size())==1 | |||
| padded_tags = [-1, *tags.tolist(), -1] | |||
| for idx in range(len(padded_tags)-1): | |||
| cur_tag = padded_tags[idx] | |||
| if cur_tag not in self._valida_matrix: | |||
| cur_tag = self.s_idx | |||
| if padded_tags[idx+1] not in self._valida_matrix: | |||
| padded_tags[idx+1] = self.s_idx | |||
| next_tag = padded_tags[idx+1] | |||
| shift_tag = self._valida_matrix[cur_tag][next_tag] | |||
| if shift_tag[0]!=-1: | |||
| padded_tags[idx+shift_tag[0]] = shift_tag[1] | |||
| return padded_tags[1:-1] | |||
| def evaluate(self, pred, target, seq_lens): | |||
| if not isinstance(pred, torch.Tensor): | |||
| raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(pred)}.") | |||
| if not isinstance(target, torch.Tensor): | |||
| raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(target)}.") | |||
| if not isinstance(seq_lens, torch.Tensor): | |||
| raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
| f"got {type(seq_lens)}.") | |||
| if pred.size() == target.size() and len(target.size()) == 2: | |||
| pass | |||
| elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||
| pred = pred.argmax(dim=-1) | |||
| else: | |||
| raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||
| f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
| f"{pred.size()[:-1]}, got {target.size()}.") | |||
| for idx in range(len(pred)): | |||
| seq_len = seq_lens[idx] | |||
| target_tags = target[idx][:seq_len].tolist() | |||
| pred_tags = pred[idx][:seq_len] | |||
| pred_tags = self._validate_tags(pred_tags) | |||
| start_idx = 0 | |||
| for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)): | |||
| if t_tag in (self.s_idx, self.e_idx): | |||
| self.yt_wordnum += 1 | |||
| corr_flag = True | |||
| for i in range(start_idx, t_idx+1): | |||
| if target_tags[i]!=pred_tags[i]: | |||
| corr_flag = False | |||
| if corr_flag: | |||
| self.corr_num += 1 | |||
| start_idx = t_idx + 1 | |||
| if p_tag in (self.s_idx, self.e_idx): | |||
| self.yp_wordnum += 1 | |||
| def get_metric(self, reset=True): | |||
| P = self.corr_num / (self.yp_wordnum + 1e-12) | |||
| R = self.corr_num / (self.yt_wordnum + 1e-12) | |||
| F = 2 * P * R / (P + R + 1e-12) | |||
| evaluate_result = {'f': round(F, 6), 'pre':round(P, 6), 'rec': round(R, 6)} | |||
| if reset: | |||
| self.yp_wordnum = 0 | |||
| self.yt_wordnum = 0 | |||
| self.corr_num = 0 | |||
| return evaluate_result | |||
| def _prepare_metrics(metrics): | |||
| """ | |||
| @@ -31,9 +31,8 @@ class Trainer(object): | |||
| """ | |||
| def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||
| validate_every=-1, dev_data=None, use_cuda=False, save_path=None, | |||
| optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | |||
| metric_key=None, sampler=RandomSampler(), use_tqdm=True): | |||
| validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | |||
| check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False): | |||
| """ | |||
| :param DataSet train_data: the training data | |||
| @@ -19,26 +19,149 @@ def seq_len_to_byte_mask(seq_lens): | |||
| mask = broadcast_arange.lt(seq_lens.float().view(-1, 1)) | |||
| return mask | |||
| def allowed_transitions(id2label, encoding_type='bio'): | |||
| """ | |||
| :param id2label: dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
| "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | |||
| :param encoding_type: str, 支持"bio", "bmes"。 | |||
| :return:List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||
| 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||
| start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
| """ | |||
| num_tags = len(id2label) | |||
| start_idx = num_tags | |||
| end_idx = num_tags + 1 | |||
| encoding_type = encoding_type.lower() | |||
| allowed_trans = [] | |||
| id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')] | |||
| def split_tag_label(from_label): | |||
| from_label = from_label.lower() | |||
| if from_label in ['start', 'end']: | |||
| from_tag = from_label | |||
| from_label = '' | |||
| else: | |||
| from_tag = from_label[:1] | |||
| from_label = from_label[2:] | |||
| return from_tag, from_label | |||
| for from_id, from_label in id_label_lst: | |||
| if from_label in ['<pad>', '<unk>']: | |||
| continue | |||
| from_tag, from_label = split_tag_label(from_label) | |||
| for to_id, to_label in id_label_lst: | |||
| if to_label in ['<pad>', '<unk>']: | |||
| continue | |||
| to_tag, to_label = split_tag_label(to_label) | |||
| if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
| allowed_trans.append((from_id, to_id)) | |||
| return allowed_trans | |||
| def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
| """ | |||
| :param encoding_type: str, 支持"BIO", "BMES"。 | |||
| :param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
| :param from_label: str, 比如"PER", "LOC"等label | |||
| :param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
| :param to_label: str, 比如"PER", "LOC"等label | |||
| :return: bool,能否跃迁 | |||
| """ | |||
| if to_tag=='start' or from_tag=='end': | |||
| return False | |||
| encoding_type = encoding_type.lower() | |||
| if encoding_type == 'bio': | |||
| """ | |||
| 第一行是to_tag, 第一列是from_tag. y任意条件下可转,-只有在label相同时可转,n不可转 | |||
| +-------+---+---+---+-------+-----+ | |||
| | | B | I | O | start | end | | |||
| +-------+---+---+---+-------+-----+ | |||
| | B | y | - | y | n | y | | |||
| +-------+---+---+---+-------+-----+ | |||
| | I | y | - | y | n | y | | |||
| +-------+---+---+---+-------+-----+ | |||
| | O | y | n | y | n | y | | |||
| +-------+---+---+---+-------+-----+ | |||
| | start | y | n | y | n | n | | |||
| +-------+---+---+---+-------+-----+ | |||
| | end | n | n | n | n | n | | |||
| +-------+---+---+---+-------+-----+ | |||
| """ | |||
| if from_tag == 'start': | |||
| return to_tag in ('b', 'o') | |||
| elif from_tag in ['b', 'i']: | |||
| return any([to_tag in ['end', 'b', 'o'], to_tag=='i' and from_label==to_label]) | |||
| elif from_tag == 'o': | |||
| return to_tag in ['end', 'b', 'o'] | |||
| else: | |||
| raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | |||
| elif encoding_type == 'bmes': | |||
| """ | |||
| 第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | | B | M | E | S | start | end | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | B | n | - | - | n | n | n | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | M | n | - | - | n | n | n | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | E | y | n | n | y | n | y | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | S | y | n | n | y | n | y | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | start | y | n | n | y | n | n | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | end | n | n | n | n | n | n | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| """ | |||
| if from_tag == 'start': | |||
| return to_tag in ['b', 's'] | |||
| elif from_tag == 'b': | |||
| return to_tag in ['m', 'e'] and from_label==to_label | |||
| elif from_tag == 'm': | |||
| return to_tag in ['m', 'e'] and from_label==to_label | |||
| elif from_tag in ['e', 's']: | |||
| return to_tag in ['b', 's', 'end'] | |||
| else: | |||
| raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | |||
| else: | |||
| raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||
| class ConditionalRandomField(nn.Module): | |||
| def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None): | |||
| def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): | |||
| """ | |||
| :param tag_size: int, num of tags | |||
| :param include_start_end_trans: bool, whether to include start/end tag | |||
| :param num_tags: int, 标签的数量。 | |||
| :param include_start_end_trans: bool, 是否包含起始tag | |||
| :param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]]. 允许的跃迁,可以通过allowed_transitions()得到。 | |||
| 如果为None,则所有跃迁均为合法 | |||
| :param initial_method: | |||
| """ | |||
| super(ConditionalRandomField, self).__init__() | |||
| self.include_start_end_trans = include_start_end_trans | |||
| self.tag_size = tag_size | |||
| self.num_tags = num_tags | |||
| # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | |||
| self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size)) | |||
| self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) | |||
| if self.include_start_end_trans: | |||
| self.start_scores = nn.Parameter(torch.randn(tag_size)) | |||
| self.end_scores = nn.Parameter(torch.randn(tag_size)) | |||
| self.start_scores = nn.Parameter(torch.randn(num_tags)) | |||
| self.end_scores = nn.Parameter(torch.randn(num_tags)) | |||
| if allowed_transitions is None: | |||
| constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||
| else: | |||
| constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 | |||
| for from_tag_id, to_tag_id in allowed_transitions: | |||
| constrain[from_tag_id, to_tag_id] = 0 | |||
| self._constrain = nn.Parameter(constrain, requires_grad=False) | |||
| # self.reset_parameter() | |||
| initial_parameter(self, initial_method) | |||
| def reset_parameter(self): | |||
| nn.init.xavier_normal_(self.trans_m) | |||
| if self.include_start_end_trans: | |||
| @@ -49,7 +172,7 @@ class ConditionalRandomField(nn.Module): | |||
| """ | |||
| Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
| sum of the likelihoods across all possible state sequences. | |||
| :param logits:FloatTensor, max_len x batch_size x tag_size | |||
| :param logits:FloatTensor, max_len x batch_size x num_tags | |||
| :param mask:ByteTensor, max_len x batch_size | |||
| :return:FloatTensor, batch_size | |||
| """ | |||
| @@ -72,7 +195,7 @@ class ConditionalRandomField(nn.Module): | |||
| def _glod_score(self, logits, tags, mask): | |||
| """ | |||
| Compute the score for the gold path. | |||
| :param logits: FloatTensor, max_len x batch_size x tag_size | |||
| :param logits: FloatTensor, max_len x batch_size x num_tags | |||
| :param tags: LongTensor, max_len x batch_size | |||
| :param mask: ByteTensor, max_len x batch_size | |||
| :return:FloatTensor, batch_size | |||
| @@ -99,7 +222,7 @@ class ConditionalRandomField(nn.Module): | |||
| def forward(self, feats, tags, mask): | |||
| """ | |||
| Calculate the neg log likelihood | |||
| :param feats:FloatTensor, batch_size x max_len x tag_size | |||
| :param feats:FloatTensor, batch_size x max_len x num_tags | |||
| :param tags:LongTensor, batch_size x max_len | |||
| :param mask:ByteTensor batch_size x max_len | |||
| :return:FloatTensor, batch_size | |||
| @@ -112,13 +235,20 @@ class ConditionalRandomField(nn.Module): | |||
| return all_path_score - gold_path_score | |||
| def viterbi_decode(self, data, mask, get_score=False): | |||
| def viterbi_decode(self, data, mask, get_score=False, unpad=False): | |||
| """ | |||
| Given a feats matrix, return best decode path and best score. | |||
| :param data:FloatTensor, batch_size x max_len x tag_size | |||
| :param data:FloatTensor, batch_size x max_len x num_tags | |||
| :param mask:ByteTensor batch_size x max_len | |||
| :param get_score: bool, whether to output the decode score. | |||
| :return: scores, paths | |||
| :param unpad: bool, 是否将结果unpad, | |||
| 如果False, 返回的是batch_size x max_len的tensor, | |||
| 如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个 | |||
| List[int]的长度是这个sample的有效长度 | |||
| :return: 如果get_score为False,返回结果根据unpadding变动 | |||
| 如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float] | |||
| 为每个seqence的解码分数。 | |||
| """ | |||
| batch_size, seq_len, n_tags = data.size() | |||
| data = data.transpose(0, 1).data # L, B, H | |||
| @@ -127,19 +257,23 @@ class ConditionalRandomField(nn.Module): | |||
| # dp | |||
| vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
| vscore = data[0] | |||
| transitions = self._constrain.data.clone() | |||
| transitions[:n_tags, :n_tags] += self.trans_m.data | |||
| if self.include_start_end_trans: | |||
| vscore += self.start_scores.view(1, -1) | |||
| transitions[n_tags, :n_tags] += self.start_scores.data | |||
| transitions[:n_tags, n_tags+1] += self.end_scores.data | |||
| vscore += transitions[n_tags, :n_tags] | |||
| trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
| for i in range(1, seq_len): | |||
| prev_score = vscore.view(batch_size, n_tags, 1) | |||
| cur_score = data[i].view(batch_size, 1, n_tags) | |||
| trans_score = self.trans_m.view(1, n_tags, n_tags).data | |||
| score = prev_score + trans_score + cur_score | |||
| best_score, best_dst = score.max(1) | |||
| vpath[i] = best_dst | |||
| vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) | |||
| if self.include_start_end_trans: | |||
| vscore += self.end_scores.view(1, -1) | |||
| vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
| # backtrace | |||
| batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | |||
| @@ -154,7 +288,13 @@ class ConditionalRandomField(nn.Module): | |||
| for i in range(seq_len - 1): | |||
| last_tags = vpath[idxes[i], batch_idx, last_tags] | |||
| ans[idxes[i+1], batch_idx] = last_tags | |||
| ans = ans.transpose(0, 1) | |||
| if unpad: | |||
| paths = [] | |||
| for idx, seq_len in enumerate(lens): | |||
| paths.append(ans[idx, :seq_len+1].tolist()) | |||
| else: | |||
| paths = ans | |||
| if get_score: | |||
| return ans_score, ans.transpose(0, 1) | |||
| return ans.transpose(0, 1) | |||
| return paths, ans_score.tolist() | |||
| return paths | |||
| @@ -6,6 +6,13 @@ from fastNLP.io.dataset_loader import DataSetLoader | |||
| def cut_long_sentence(sent, max_sample_length=200): | |||
| """ | |||
| 将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length | |||
| :param sent: str. | |||
| :param max_sample_length: int. | |||
| :return: list of str. | |||
| """ | |||
| sent_no_space = sent.replace(' ', '') | |||
| cutted_sentence = [] | |||
| if len(sent_no_space) > max_sample_length: | |||
| @@ -127,12 +134,26 @@ class POSCWSReader(DataSetLoader): | |||
| return dataset | |||
| class ConlluCWSReader(object): | |||
| # 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 | |||
| class ConllCWSReader(object): | |||
| def __init__(self): | |||
| pass | |||
| def load(self, path, cut_long_sent=False): | |||
| """ | |||
| 返回的DataSet只包含raw_sentence这个field,内容为str。 | |||
| 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||
| 1 编者按 编者按 NN O 11 nmod:topic | |||
| 2 : : PU O 11 punct | |||
| 3 7月 7月 NT DATE 4 compound:nn | |||
| 4 12日 12日 NT DATE 11 nmod:tmod | |||
| 5 , , PU O 11 punct | |||
| 1 这 这 DT O 3 det | |||
| 2 款 款 M O 1 mark:clf | |||
| 3 飞行 飞行 NN O 8 nsubj | |||
| 4 从 从 P O 5 case | |||
| 5 外型 外型 NN O 8 nmod:prep | |||
| """ | |||
| datalist = [] | |||
| with open(path, 'r', encoding='utf-8') as f: | |||
| sample = [] | |||
| @@ -150,10 +171,10 @@ class ConlluCWSReader(object): | |||
| ds = DataSet() | |||
| for sample in datalist: | |||
| # print(sample) | |||
| res = self.get_one(sample) | |||
| res = self.get_char_lst(sample) | |||
| if res is None: | |||
| continue | |||
| line = ' '.join(res) | |||
| line = ' '.join(res) | |||
| if cut_long_sent: | |||
| sents = cut_long_sentence(line) | |||
| else: | |||
| @@ -163,7 +184,7 @@ class ConlluCWSReader(object): | |||
| return ds | |||
| def get_one(self, sample): | |||
| def get_char_lst(self, sample): | |||
| if len(sample)==0: | |||
| return None | |||
| text = [] | |||
| @@ -9,7 +9,7 @@ from reproduction.chinese_word_segment.utils import seq_lens_to_mask | |||
| class CWSBiLSTMEncoder(BaseModel): | |||
| def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||
| hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1): | |||
| hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1): | |||
| super().__init__() | |||
| self.input_size = 0 | |||
| @@ -68,6 +68,7 @@ class CWSBiLSTMEncoder(BaseModel): | |||
| if not bigrams is None: | |||
| bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1) | |||
| x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2) | |||
| x_tensor = self.embedding_drop(x_tensor) | |||
| sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True) | |||
| packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True) | |||
| @@ -120,10 +121,24 @@ class CWSBiLSTMSegApp(BaseModel): | |||
| from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||
| from fastNLP.modules.decoder.CRF import allowed_transitions | |||
| class CWSBiLSTMCRF(BaseModel): | |||
| def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||
| hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4): | |||
| hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1, tag_size=4): | |||
| """ | |||
| 默认使用BMES的标注方式 | |||
| :param vocab_num: | |||
| :param embed_dim: | |||
| :param bigram_vocab_num: | |||
| :param bigram_embed_dim: | |||
| :param num_bigram_per_char: | |||
| :param hidden_size: | |||
| :param bidirectional: | |||
| :param embed_drop_p: | |||
| :param num_layers: | |||
| :param tag_size: | |||
| """ | |||
| super(CWSBiLSTMCRF, self).__init__() | |||
| self.tag_size = tag_size | |||
| @@ -133,10 +148,12 @@ class CWSBiLSTMCRF(BaseModel): | |||
| size_layer = [hidden_size, 200, tag_size] | |||
| self.decoder_model = MLP(size_layer) | |||
| self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False) | |||
| allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes') | |||
| self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False, | |||
| allowed_transitions=allowed_trans) | |||
| def forward(self, chars, tags, seq_lens, bigrams=None): | |||
| def forward(self, chars, target, seq_lens, bigrams=None): | |||
| device = self.parameters().__next__().device | |||
| chars = chars.to(device).long() | |||
| if not bigrams is None: | |||
| @@ -147,7 +164,7 @@ class CWSBiLSTMCRF(BaseModel): | |||
| masks = seq_lens_to_mask(seq_lens) | |||
| feats = self.encoder_model(chars, bigrams, seq_lens) | |||
| feats = self.decoder_model(feats) | |||
| losses = self.crf(feats, tags, masks) | |||
| losses = self.crf(feats, target, masks) | |||
| pred_dict = {} | |||
| pred_dict['seq_lens'] = seq_lens | |||
| @@ -168,5 +185,5 @@ class CWSBiLSTMCRF(BaseModel): | |||
| feats = self.decoder_model(feats) | |||
| probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||
| return {'pred_tags': probs} | |||
| return {'pred': probs} | |||
| @@ -2,7 +2,6 @@ | |||
| import re | |||
| from fastNLP.core.field import SeqLabelField | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.api.processor import Processor | |||
| @@ -11,7 +10,10 @@ from reproduction.chinese_word_segment.process.span_converter import SpanConvert | |||
| _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | |||
| class SpeicalSpanProcessor(Processor): | |||
| # 这个类会将句子中的special span转换为对应的内容。 | |||
| """ | |||
| 将DataSet中field_name使用span_converter替换掉。 | |||
| """ | |||
| def __init__(self, field_name, new_added_field_name=None): | |||
| super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name) | |||
| @@ -20,11 +22,12 @@ class SpeicalSpanProcessor(Processor): | |||
| def process(self, dataset): | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| for ins in dataset: | |||
| def inner_proc(ins): | |||
| sentence = ins[self.field_name] | |||
| for span_converter in self.span_converters: | |||
| sentence = span_converter.find_certain_span_and_replace(sentence) | |||
| ins[self.new_added_field_name] = sentence | |||
| return sentence | |||
| dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
| return dataset | |||
| @@ -34,17 +37,22 @@ class SpeicalSpanProcessor(Processor): | |||
| self.span_converters.append(converter) | |||
| class CWSCharSegProcessor(Processor): | |||
| """ | |||
| 将DataSet中field_name这个field分成一个个的汉字,即原来可能为"复旦大学 fudan", 分成['复', '旦', '大', '学', | |||
| ' ', 'f', 'u', ...] | |||
| """ | |||
| def __init__(self, field_name, new_added_field_name): | |||
| super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name) | |||
| def process(self, dataset): | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| for ins in dataset: | |||
| def inner_proc(ins): | |||
| sentence = ins[self.field_name] | |||
| chars = self._split_sent_into_chars(sentence) | |||
| ins[self.new_added_field_name] = chars | |||
| return chars | |||
| dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
| return dataset | |||
| @@ -73,6 +81,10 @@ class CWSCharSegProcessor(Processor): | |||
| class CWSTagProcessor(Processor): | |||
| """ | |||
| 为分词生成tag。该class为Base class。 | |||
| """ | |||
| def __init__(self, field_name, new_added_field_name=None): | |||
| super(CWSTagProcessor, self).__init__(field_name, new_added_field_name) | |||
| @@ -107,18 +119,22 @@ class CWSTagProcessor(Processor): | |||
| def process(self, dataset): | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| for ins in dataset: | |||
| def inner_proc(ins): | |||
| sentence = ins[self.field_name] | |||
| tag_list = self._generate_tag(sentence) | |||
| ins[self.new_added_field_name] = tag_list | |||
| dataset.set_target(**{self.new_added_field_name:True}) | |||
| dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||
| return tag_list | |||
| dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
| dataset.set_target(self.new_added_field_name) | |||
| return dataset | |||
| def _tags_from_word_len(self, word_len): | |||
| raise NotImplementedError | |||
| class CWSBMESTagProcessor(CWSTagProcessor): | |||
| """ | |||
| 通过DataSet中的field_name这个field生成相应的BMES的tag。 | |||
| """ | |||
| def __init__(self, field_name, new_added_field_name=None): | |||
| super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name) | |||
| @@ -137,6 +153,10 @@ class CWSBMESTagProcessor(CWSTagProcessor): | |||
| return tag_list | |||
| class CWSSegAppTagProcessor(CWSTagProcessor): | |||
| """ | |||
| 通过DataSet中的field_name这个field生成相应的SegApp的tag。 | |||
| """ | |||
| def __init__(self, field_name, new_added_field_name=None): | |||
| super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name) | |||
| @@ -151,6 +171,10 @@ class CWSSegAppTagProcessor(CWSTagProcessor): | |||
| class BigramProcessor(Processor): | |||
| """ | |||
| 这是生成bigram的基类。 | |||
| """ | |||
| def __init__(self, field_name, new_added_fielf_name=None): | |||
| super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | |||
| @@ -158,22 +182,31 @@ class BigramProcessor(Processor): | |||
| def process(self, dataset): | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| for ins in dataset: | |||
| def inner_proc(ins): | |||
| characters = ins[self.field_name] | |||
| bigrams = self._generate_bigram(characters) | |||
| ins[self.new_added_field_name] = bigrams | |||
| return bigrams | |||
| dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
| return dataset | |||
| def _generate_bigram(self, characters): | |||
| pass | |||
| class Pre2Post2BigramProcessor(BigramProcessor): | |||
| def __init__(self, field_name, new_added_fielf_name=None): | |||
| """ | |||
| 该bigram processor生成bigram的方式如下 | |||
| 原汉字list为l = ['a', 'b', 'c'],会被padding为L=['SOS', 'SOS', 'a', 'b', 'c', 'EOS', 'EOS'],生成bigram list为 | |||
| [L[idx-2], L[idx-1], L[idx+1], L[idx+2], L[idx-2]L[idx], L[idx-1]L[idx], L[idx]L[idx+1], L[idx]L[idx+2], ....] | |||
| 即每个汉字,会有八个bigram, 对于上例中'a'的bigram为 | |||
| ['SOS', 'SOS', 'b', 'c', 'SOSa', 'SOSa', 'ab', 'ac'] | |||
| 返回的bigram是一个list,但其实每8个元素是一个汉字的bigram信息。 | |||
| """ | |||
| def __init__(self, field_name, new_added_field_name=None): | |||
| super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | |||
| super(BigramProcessor, self).__init__(field_name, new_added_field_name) | |||
| def _generate_bigram(self, characters): | |||
| bigrams = [] | |||
| @@ -197,20 +230,102 @@ class Pre2Post2BigramProcessor(BigramProcessor): | |||
| # 这里需要建立vocabulary了,但是遇到了以下的问题 | |||
| # (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用 | |||
| # Processor了 | |||
| # TODO 如何将建立vocab和index这两步统一了? | |||
| class VocabIndexerProcessor(Processor): | |||
| """ | |||
| 根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 | |||
| new_added_field_name, 则覆盖原有的field_name. | |||
| """ | |||
| def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||
| verbose=1): | |||
| """ | |||
| :param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | |||
| :param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. | |||
| :param min_freq: 创建的Vocabulary允许的单词最少出现次数. | |||
| :param max_size: 创建的Vocabulary允许的最大的单词数量 | |||
| :param verbose: 0, 不输出任何信息;1,输出信息 | |||
| """ | |||
| super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||
| self.min_freq = min_freq | |||
| self.max_size = max_size | |||
| self.verbose =verbose | |||
| def construct_vocab(self, *datasets): | |||
| """ | |||
| 使用传入的DataSet创建vocabulary | |||
| :param datasets: DataSet类型的数据,用于构建vocabulary | |||
| :return: | |||
| """ | |||
| self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||
| for dataset in datasets: | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||
| self.vocab.build_vocab() | |||
| if self.verbose: | |||
| print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||
| def process(self, *datasets, only_index_dataset=None): | |||
| """ | |||
| 若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary | |||
| 后,则会index datasets与only_index_dataset。 | |||
| :param datasets: DataSet类型的数据 | |||
| :param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 | |||
| :return: | |||
| """ | |||
| if len(datasets)==0 and not hasattr(self,'vocab'): | |||
| raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.") | |||
| if not hasattr(self, 'vocab'): | |||
| self.construct_vocab(*datasets) | |||
| else: | |||
| if self.verbose: | |||
| print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||
| to_index_datasets = [] | |||
| if len(datasets)!=0: | |||
| for dataset in datasets: | |||
| assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||
| to_index_datasets.append(dataset) | |||
| if not (only_index_dataset is None): | |||
| if isinstance(only_index_dataset, list): | |||
| for dataset in only_index_dataset: | |||
| assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||
| to_index_datasets.append(dataset) | |||
| elif isinstance(only_index_dataset, DataSet): | |||
| to_index_datasets.append(only_index_dataset) | |||
| else: | |||
| raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||
| for dataset in to_index_datasets: | |||
| assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||
| dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||
| new_field_name=self.new_added_field_name) | |||
| def set_vocab(self, vocab): | |||
| assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||
| self.vocab = vocab | |||
| def delete_vocab(self): | |||
| del self.vocab | |||
| def get_vocab_size(self): | |||
| return len(self.vocab) | |||
| class VocabProcessor(Processor): | |||
| def __init__(self, field_name, min_count=1, max_vocab_size=None): | |||
| def __init__(self, field_name, min_freq=1, max_size=None): | |||
| super(VocabProcessor, self).__init__(field_name, None) | |||
| self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size) | |||
| self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) | |||
| def process(self, *datasets): | |||
| for dataset in datasets: | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| for ins in dataset: | |||
| tokens = ins[self.field_name] | |||
| self.vocab.update(tokens) | |||
| dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||
| def get_vocab(self): | |||
| self.vocab.build_vocab() | |||
| @@ -220,19 +335,6 @@ class VocabProcessor(Processor): | |||
| return len(self.vocab) | |||
| class SeqLenProcessor(Processor): | |||
| def __init__(self, field_name, new_added_field_name='seq_lens'): | |||
| super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||
| def process(self, dataset): | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| for ins in dataset: | |||
| length = len(ins[self.field_name]) | |||
| ins[self.new_added_field_name] = length | |||
| dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||
| return dataset | |||
| class SegApp2OutputProcessor(Processor): | |||
| def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||
| super(SegApp2OutputProcessor, self).__init__(None, None) | |||
| @@ -258,7 +360,32 @@ class SegApp2OutputProcessor(Processor): | |||
| class BMES2OutputProcessor(Processor): | |||
| def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||
| """ | |||
| 按照BMES标注方式推测生成的tag。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||
| next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||
| | | next_B | next_M | next_E | next_S | end | | |||
| |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||
| | start | 合法 | next_M=B | next_E=S | 合法 | - | | |||
| | cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | |||
| | cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | |||
| | cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
| | cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
| 举例: | |||
| prediction为BSEMS,会被认为是SSSSS. | |||
| """ | |||
| def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output', | |||
| b_idx = 0, m_idx = 1, e_idx = 2, s_idx = 3): | |||
| """ | |||
| :param chars_field_name: character所对应的field | |||
| :param tag_field_name: 预测对应的field | |||
| :param new_added_field_name: 转换后的内容所在field | |||
| :param b_idx: int, Begin标签所对应的tag idx. | |||
| :param m_idx: int, Middle标签所对应的tag idx. | |||
| :param e_idx: int, End标签所对应的tag idx. | |||
| :param s_idx: int, Single标签所对应的tag idx | |||
| """ | |||
| super(BMES2OutputProcessor, self).__init__(None, None) | |||
| self.chars_field_name = chars_field_name | |||
| @@ -266,19 +393,55 @@ class BMES2OutputProcessor(Processor): | |||
| self.new_added_field_name = new_added_field_name | |||
| self.b_idx = b_idx | |||
| self.m_idx = m_idx | |||
| self.e_idx = e_idx | |||
| self.s_idx = s_idx | |||
| # 还原init处介绍的矩阵 | |||
| self._valida_matrix = { | |||
| -1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx | |||
| self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], | |||
| self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], | |||
| self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||
| self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||
| } | |||
| def _validate_tags(self, tags): | |||
| """ | |||
| 给定一个tag的List,返回合法tag | |||
| :param tags: Tensor, shape: (seq_len, ) | |||
| :return: 返回修改为合法tag的list | |||
| """ | |||
| assert len(tags)!=0 | |||
| padded_tags = [-1, *tags, -1] | |||
| for idx in range(len(padded_tags)-1): | |||
| cur_tag = padded_tags[idx] | |||
| if cur_tag not in self._valida_matrix: | |||
| cur_tag = self.s_idx | |||
| if padded_tags[idx+1] not in self._valida_matrix: | |||
| padded_tags[idx+1] = self.s_idx | |||
| next_tag = padded_tags[idx+1] | |||
| shift_tag = self._valida_matrix[cur_tag][next_tag] | |||
| if shift_tag[0]!=-1: | |||
| padded_tags[idx+shift_tag[0]] = shift_tag[1] | |||
| return padded_tags[1:-1] | |||
| def process(self, dataset): | |||
| assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
| for ins in dataset: | |||
| def inner_proc(ins): | |||
| pred_tags = ins[self.tag_field_name] | |||
| pred_tags = self._validate_tags(pred_tags) | |||
| chars = ins[self.chars_field_name] | |||
| words = [] | |||
| start_idx = 0 | |||
| for idx, tag in enumerate(pred_tags): | |||
| if tag==3: | |||
| # 当前没有考虑将原文替换回去 | |||
| if tag==self.s_idx: | |||
| words.extend(chars[start_idx:idx+1]) | |||
| start_idx = idx + 1 | |||
| elif tag==2: | |||
| elif tag==self.e_idx: | |||
| words.append(''.join(chars[start_idx:idx+1])) | |||
| start_idx = idx + 1 | |||
| ins[self.new_added_field_name] = ' '.join(words) | |||
| return ' '.join(words) | |||
| dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
| @@ -24,8 +24,8 @@ def cut_long_sentence(sent, max_sample_length=200): | |||
| return cutted_sentence | |||
| class ConlluPOSReader(object): | |||
| # 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 | |||
| class ConllPOSReader(object): | |||
| # 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。 | |||
| def __init__(self): | |||
| pass | |||
| @@ -70,6 +70,70 @@ class ConlluPOSReader(object): | |||
| return ds | |||
| class ZhConllPOSReader(object): | |||
| # 中文colln格式reader | |||
| def __init__(self): | |||
| pass | |||
| def load(self, path): | |||
| """ | |||
| 返回的DataSet, 包含以下的field | |||
| words:list of str, | |||
| tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | |||
| 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||
| 1 编者按 编者按 NN O 11 nmod:topic | |||
| 2 : : PU O 11 punct | |||
| 3 7月 7月 NT DATE 4 compound:nn | |||
| 4 12日 12日 NT DATE 11 nmod:tmod | |||
| 5 , , PU O 11 punct | |||
| 1 这 这 DT O 3 det | |||
| 2 款 款 M O 1 mark:clf | |||
| 3 飞行 飞行 NN O 8 nsubj | |||
| 4 从 从 P O 5 case | |||
| 5 外型 外型 NN O 8 nmod:prep | |||
| """ | |||
| datalist = [] | |||
| with open(path, 'r', encoding='utf-8') as f: | |||
| sample = [] | |||
| for line in f: | |||
| if line.startswith('\n'): | |||
| datalist.append(sample) | |||
| sample = [] | |||
| elif line.startswith('#'): | |||
| continue | |||
| else: | |||
| sample.append(line.split('\t')) | |||
| if len(sample) > 0: | |||
| datalist.append(sample) | |||
| ds = DataSet() | |||
| for sample in datalist: | |||
| # print(sample) | |||
| res = self.get_one(sample) | |||
| if res is None: | |||
| continue | |||
| char_seq = [] | |||
| pos_seq = [] | |||
| for word, tag in zip(res[0], res[1]): | |||
| char_seq.extend(list(word)) | |||
| if len(word)==1: | |||
| pos_seq.append('S-{}'.format(tag)) | |||
| elif len(word)>1: | |||
| pos_seq.append('B-{}'.format(tag)) | |||
| for _ in range(len(word)-2): | |||
| pos_seq.append('M-{}'.format(tag)) | |||
| pos_seq.append('E-{}'.format(tag)) | |||
| else: | |||
| raise ValueError("Zero length of word detected.") | |||
| ds.append(Instance(words=char_seq, | |||
| tag=pos_seq)) | |||
| return ds | |||
| def get_one(self, sample): | |||
| if len(sample)==0: | |||
| return None | |||
| @@ -84,6 +148,6 @@ class ConlluPOSReader(object): | |||
| return text, pos_tags | |||
| if __name__ == '__main__': | |||
| reader = ConlluPOSReader() | |||
| reader = ZhConllPOSReader() | |||
| d = reader.load('/home/hyan/train.conllx') | |||
| print('reader') | |||
| print(d) | |||
| @@ -4,6 +4,7 @@ import numpy as np | |||
| import torch | |||
| from fastNLP.core.metrics import AccuracyMetric | |||
| from fastNLP.core.metrics import BMESF1PreRecMetric | |||
| from fastNLP.core.metrics import pred_topk, accuracy_topk | |||
| @@ -132,6 +133,234 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| return | |||
| self.assertTrue(True, False), "No exception catches." | |||
| class SpanF1PreRecMetric(unittest.TestCase): | |||
| def test_case1(self): | |||
| from fastNLP.core.metrics import bmes_tag_to_spans | |||
| from fastNLP.core.metrics import bio_tag_to_spans | |||
| bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] | |||
| bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] | |||
| expect_bmes_res = set() | |||
| expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)), | |||
| ('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))]) | |||
| expect_bio_res = set() | |||
| expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)), | |||
| ('6', (4, 4)), ('7', (6, 6))]) | |||
| self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) | |||
| self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) | |||
| # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||
| # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||
| # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans | |||
| # for i in range(1000): | |||
| # strs = list(map(str, np.random.randint(100, size=1000))) | |||
| # bmes = list('bmes'.upper()) | |||
| # bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))] | |||
| # bio = list('bio'.upper()) | |||
| # bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] | |||
| # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||
| # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||
| def test_case2(self): | |||
| # 测试不带label的 | |||
| from fastNLP.core.metrics import bmes_tag_to_spans | |||
| from fastNLP.core.metrics import bio_tag_to_spans | |||
| bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] | |||
| bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] | |||
| expect_bmes_res = set() | |||
| expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))]) | |||
| expect_bio_res = set() | |||
| expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))]) | |||
| self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) | |||
| self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) | |||
| # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||
| # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||
| # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans | |||
| # for i in range(1000): | |||
| # bmes = list('bmes'.upper()) | |||
| # bmes_strs = np.random.choice(bmes, size=1000) | |||
| # bio = list('bio'.upper()) | |||
| # bio_strs = np.random.choice(bio, size=100) | |||
| # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||
| # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||
| def tese_case3(self): | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from collections import Counter | |||
| from fastNLP.core.metrics import SpanFPreRecMetric | |||
| # 与allennlp测试能否正确计算f metric | |||
| # | |||
| def generate_allen_tags(encoding_type, number_labels=4): | |||
| vocab = {} | |||
| for i in range(number_labels): | |||
| label = str(i) | |||
| for tag in encoding_type: | |||
| if tag == 'O': | |||
| if tag not in vocab: | |||
| vocab['O'] = len(vocab) + 1 | |||
| continue | |||
| vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count | |||
| return vocab | |||
| number_labels = 4 | |||
| # bio tag | |||
| fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
| fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) | |||
| fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
| bio_sequence = torch.FloatTensor( | |||
| [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, | |||
| 0.0470, 0.0971], | |||
| [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||
| 0.7987, -0.3970], | |||
| [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||
| 0.6880, 1.4348], | |||
| [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||
| -1.6876, -0.8917], | |||
| [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||
| 1.4217, 0.2622]], | |||
| [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||
| 1.3592, -0.8973], | |||
| [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||
| -0.4025, -0.3417], | |||
| [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||
| 0.2861, -0.3966], | |||
| [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||
| 0.0213, 1.4777], | |||
| [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||
| 1.3024, 0.2001]]] | |||
| ) | |||
| bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], | |||
| [5., 6., 8., 6., 0.]]) | |||
| fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target}) | |||
| expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775, | |||
| 'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0, | |||
| 'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, | |||
| 'f': 0.12499999999994846} | |||
| self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) | |||
| #bmes tag | |||
| bmes_sequence = torch.FloatTensor( | |||
| [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, | |||
| -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, | |||
| -0.3505, -0.6002], | |||
| [0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641, | |||
| 1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002, | |||
| 0.4439, 0.2514], | |||
| [-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696, | |||
| -1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753, | |||
| 0.5802, -0.0516], | |||
| [-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121, | |||
| 4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390, | |||
| -0.9242, -0.0333], | |||
| [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, | |||
| 0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, | |||
| -0.3779, -0.3195]], | |||
| [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, | |||
| 0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, | |||
| -0.1103, 0.4417], | |||
| [-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049, | |||
| -0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631, | |||
| -0.6386, 0.2797], | |||
| [0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947, | |||
| 0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522, | |||
| 1.1237, -1.5385], | |||
| [0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977, | |||
| -0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376, | |||
| -0.0591, -0.2461], | |||
| [-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097, | |||
| 2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, | |||
| -0.7344, -1.2046]]] | |||
| ) | |||
| bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.], | |||
| [ 6., 15., 6., 15., 5.]]) | |||
| fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||
| fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||
| fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||
| fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||
| expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, | |||
| 'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, | |||
| 'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, | |||
| 'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, | |||
| 'pre': 0.499999999999995, 'rec': 0.499999999999995} | |||
| self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) | |||
| # 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 | |||
| # from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary | |||
| # from allennlp.training.metrics import SpanBasedF1Measure | |||
| # allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)}, | |||
| # non_padded_namespaces=['tags']) | |||
| # allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags') | |||
| # bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1)) | |||
| # bio_target = torch.randint(2 * number_labels + 1, size=(2, 20)) | |||
| # allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20)) | |||
| # fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
| # fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) | |||
| # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
| # | |||
| # def convert_allen_res_to_fastnlp_res(metric_result): | |||
| # allen_result = {} | |||
| # key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"} | |||
| # for key, value in metric_result.items(): | |||
| # if key in key_map: | |||
| # key = key_map[key] | |||
| # else: | |||
| # label = key.split('-')[-1] | |||
| # if key.startswith('f1'): | |||
| # key = 'f-{}'.format(label) | |||
| # else: | |||
| # key = '{}-{}'.format(key[:3], label) | |||
| # allen_result[key] = value | |||
| # return allen_result | |||
| # | |||
| # # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric())) | |||
| # # print(fastnlp_bio_metric.get_metric()) | |||
| # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()), | |||
| # fastnlp_bio_metric.get_metric()) | |||
| # | |||
| # allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)}) | |||
| # allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES') | |||
| # fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||
| # fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||
| # fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||
| # bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels)) | |||
| # bmes_target = torch.randint(4 * number_labels, size=(2, 20)) | |||
| # allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20)) | |||
| # fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||
| # | |||
| # # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric())) | |||
| # # print(fastnlp_bmes_metric.get_metric()) | |||
| # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), | |||
| # fastnlp_bmes_metric.get_metric()) | |||
| class TestBMESF1PreRecMetric(unittest.TestCase): | |||
| def test_case1(self): | |||
| seq_lens = torch.LongTensor([4, 2]) | |||
| pred = torch.randn(2, 4, 4) | |||
| target = torch.LongTensor([[0, 1, 2, 3], | |||
| [3, 3, 0, 0]]) | |||
| pred_dict = {'pred': pred} | |||
| target_dict = {'target': target, 'seq_lens': seq_lens} | |||
| metric = BMESF1PreRecMetric() | |||
| metric(pred_dict, target_dict) | |||
| metric.get_metric() | |||
| def test_case2(self): | |||
| # 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1} | |||
| seq_lens = torch.LongTensor([4, 2]) | |||
| target = torch.LongTensor([[0, 1, 2, 3], | |||
| [3, 3, 0, 0]]) | |||
| pred_dict = {'pred': target} | |||
| target_dict = {'target': target, 'seq_lens': seq_lens} | |||
| metric = BMESF1PreRecMetric() | |||
| metric(pred_dict, target_dict) | |||
| self.assertDictEqual(metric.get_metric(), {'f1': 0.999999, 'precision': 1.0, 'recall': 1.0}) | |||
| class TestUsefulFunctions(unittest.TestCase): | |||
| # 测试metrics.py中一些看上去挺有用的函数 | |||
| @@ -0,0 +1,104 @@ | |||
| import unittest | |||
| class TestCRF(unittest.TestCase): | |||
| def test_case1(self): | |||
| # 检查allowed_transitions()能否正确使用 | |||
| from fastNLP.modules.decoder.CRF import allowed_transitions | |||
| id2label = {0: 'B', 1: 'I', 2:'O'} | |||
| expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
| (2, 4), (3, 0), (3, 2)} | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||
| id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||
| id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||
| allowed_transitions(id2label) | |||
| labels = ['O'] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BI': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx:label for idx, label in enumerate(labels)} | |||
| expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
| (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
| (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||
| labels = [] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BMES': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx: label for idx, label in enumerate(labels)} | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
| (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
| (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||
| def test_case2(self): | |||
| # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 | |||
| pass | |||
| # import torch | |||
| # from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask | |||
| # | |||
| # labels = ['O'] | |||
| # for label in ['X', 'Y']: | |||
| # for tag in 'BI': | |||
| # labels.append('{}-{}'.format(tag, label)) | |||
| # id2label = {idx: label for idx, label in enumerate(labels)} | |||
| # num_tags = len(id2label) | |||
| # | |||
| # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||
| # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), | |||
| # include_start_end_transitions=False) | |||
| # batch_size = 3 | |||
| # logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() | |||
| # trans_m = allen_CRF.transitions | |||
| # seq_lens = torch.randint(1, 20, size=(batch_size,)) | |||
| # seq_lens[-1] = 20 | |||
| # mask = seq_len_to_byte_mask(seq_lens) | |||
| # allen_res = allen_CRF.viterbi_tags(logits, mask) | |||
| # | |||
| # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||
| # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | |||
| # fast_CRF.trans_m = trans_m | |||
| # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||
| # # score equal | |||
| # self.assertListEqual([score for _, score in allen_res], fast_res[1]) | |||
| # # seq equal | |||
| # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||
| # | |||
| # | |||
| # labels = [] | |||
| # for label in ['X', 'Y']: | |||
| # for tag in 'BMES': | |||
| # labels.append('{}-{}'.format(tag, label)) | |||
| # id2label = {idx: label for idx, label in enumerate(labels)} | |||
| # num_tags = len(id2label) | |||
| # | |||
| # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||
| # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), | |||
| # include_start_end_transitions=False) | |||
| # batch_size = 3 | |||
| # logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() | |||
| # trans_m = allen_CRF.transitions | |||
| # seq_lens = torch.randint(1, 20, size=(batch_size,)) | |||
| # seq_lens[-1] = 20 | |||
| # mask = seq_len_to_byte_mask(seq_lens) | |||
| # allen_res = allen_CRF.viterbi_tags(logits, mask) | |||
| # | |||
| # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||
| # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
| # encoding_type='BMES')) | |||
| # fast_CRF.trans_m = trans_m | |||
| # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||
| # # score equal | |||
| # self.assertListEqual([score for _, score in allen_res], fast_res[1]) | |||
| # # seq equal | |||
| # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||