| @@ -23,6 +23,7 @@ from .utils import _get_func_signature | |||
| from .utils import seq_len_to_mask | |||
| from .vocabulary import Vocabulary | |||
| from abc import abstractmethod | |||
| import warnings | |||
| class MetricBase(object): | |||
| @@ -492,6 +493,30 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
| return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||
| def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str): | |||
| """ | |||
| 检查vocab中的tag是否与encoding_type是匹配的 | |||
| :param vocab: target的Vocabulary | |||
| :param encoding_type: bio, bmes, bioes, bmeso | |||
| :return: | |||
| """ | |||
| tag_set = set() | |||
| for tag, idx in vocab: | |||
| if idx in (vocab.unknown_idx, vocab.padding_idx): | |||
| continue | |||
| tag = tag[:1].lower() | |||
| tag_set.add(tag) | |||
| tags = encoding_type | |||
| for tag in tag_set: | |||
| assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ | |||
| f"encoding_type." | |||
| tags = tags.replace(tag, '') # 删除该值 | |||
| if tags: # 如果不为空,说明出现了未使用的tag | |||
| warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||
| "encoding_type.") | |||
| class SpanFPreRecMetric(MetricBase): | |||
| r""" | |||
| 别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||
| @@ -546,6 +571,7 @@ class SpanFPreRecMetric(MetricBase): | |||
| raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
| self.encoding_type = encoding_type | |||
| _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
| if self.encoding_type == 'bmes': | |||
| self.tag_to_span_func = _bmes_tag_to_spans | |||
| elif self.encoding_type == 'bio': | |||
| @@ -338,6 +338,41 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
| for key, value in expected_metric.items(): | |||
| self.assertAlmostEqual(value, metric_value[key], places=5) | |||
| def test_encoding_type(self): | |||
| # 检查传入的tag_vocab与encoding_type不符合时,是否会报错 | |||
| vocabs = {} | |||
| import random | |||
| from itertools import product | |||
| for encoding_type in ['bio', 'bioes', 'bmeso']: | |||
| vocab = Vocabulary(unknown=None, padding=None) | |||
| for i in range(random.randint(10, 100)): | |||
| label = str(random.randint(1, 10)) | |||
| for tag in encoding_type: | |||
| if tag!='o': | |||
| vocab.add_word(f'{tag}-{label}') | |||
| else: | |||
| vocab.add_word('o') | |||
| vocabs[encoding_type] = vocab | |||
| for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): | |||
| with self.subTest(e1=e1, e2=e2): | |||
| if e1==e2: | |||
| metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2) | |||
| else: | |||
| s2 = set(e2) | |||
| s2.update(set(e1)) | |||
| if s2==set(e2): | |||
| continue | |||
| with self.assertRaises(AssertionError): | |||
| metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2) | |||
| for encoding_type in ['bio', 'bioes', 'bmeso']: | |||
| with self.assertRaises(AssertionError): | |||
| metric = SpanFPreRecMetric(vocabs[encoding_type], encoding_type='bmes') | |||
| with self.assertWarns(Warning): | |||
| vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) | |||
| metric = SpanFPreRecMetric(vocab, encoding_type='bmeso') | |||
| vocab = Vocabulary().add_word_lst(list('bmes')) | |||
| metric = SpanFPreRecMetric(vocab, encoding_type='bmeso') | |||
| class TestUsefulFunctions(unittest.TestCase): | |||
| # 测试metrics.py中一些看上去挺有用的函数 | |||