| @@ -1006,24 +1006,28 @@ class CMRC2018Metric(MetricBase): | |||
| self.total = 0 | |||
| self.f1 = 0 | |||
| def evaluate(self, answers, raw_chars, context_len, pred_start, pred_end): | |||
| def evaluate(self, answers, raw_chars, pred_start, pred_end, context_len=None): | |||
| """ | |||
| :param list[str] answers: 如[["答案1", "答案2", "答案3"], [...], ...] | |||
| :param list[str] raw_chars: [["这", "是", ...], [...]] | |||
| :param tensor pred_start: batch_size x length 或 batch_size, | |||
| :param tensor pred_end: batch_size x length 或 batch_size(是闭区间,包含end位置), | |||
| :param tensor context_len: context长度, batch_size | |||
| :param tensor pred_start: batch_size x length | |||
| :param tensor pred_end: batch_size x length | |||
| :return: | |||
| """ | |||
| batch_size, max_len = pred_start.size() | |||
| context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False) | |||
| pred_start.masked_fill_(context_mask, float('-inf')) | |||
| pred_end.masked_fill_(context_mask, float('-inf')) | |||
| max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True) # batch_size, | |||
| pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0) # 只能预测这之后的值 | |||
| pred_end.masked_fill_(pred_start_mask, float('-inf')) | |||
| pred_end_index = pred_end.argmax(dim=-1) + 1 | |||
| if pred_start.dim() > 1: | |||
| batch_size, max_len = pred_start.size() | |||
| context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False) | |||
| pred_start.masked_fill_(context_mask, float('-inf')) | |||
| pred_end.masked_fill_(context_mask, float('-inf')) | |||
| max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True) # batch_size, | |||
| pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0) # 只能预测这之后的值 | |||
| pred_end.masked_fill_(pred_start_mask, float('-inf')) | |||
| pred_end_index = pred_end.argmax(dim=-1) + 1 | |||
| else: | |||
| pred_start_index = pred_start | |||
| pred_end_index = pred_end + 1 | |||
| pred_ans = [] | |||
| for index, (start, end) in enumerate(zip(pred_start_index.flatten().tolist(), pred_end_index.tolist())): | |||
| pred_ans.append(''.join(raw_chars[index][start:end])) | |||
| @@ -68,7 +68,11 @@ class StaticEmbedding(TokenEmbedding): | |||
| :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
| :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||
| :param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | |||
| :param dict kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中找到的词语使用normalize。 | |||
| :param dict kwarngs: | |||
| bool only_train_min_freq: 仅对train中的词语使用min_freq筛选; | |||
| bool only_norm_found_vector: 是否仅对在预训练中找到的词语使用normalize; | |||
| bool only_use_pretrain_word: 仅使用出现在pretrain词表中的词语。如果该词没有在预训练的词表中出现则为unk。如果词表 | |||
| 不需要更新设置为True。 | |||
| """ | |||
| super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
| if embedding_dim > 0: | |||
| @@ -118,7 +122,8 @@ class StaticEmbedding(TokenEmbedding): | |||
| truncated_words_to_words[index] = truncated_vocab.to_index(word) | |||
| logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") | |||
| vocab = truncated_vocab | |||
| self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False) | |||
| self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False) | |||
| # 读取embedding | |||
| if lower: | |||
| @@ -249,12 +254,13 @@ class StaticEmbedding(TokenEmbedding): | |||
| logger.error("Error occurred at the {} line.".format(idx)) | |||
| raise e | |||
| logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
| for word, index in vocab: | |||
| if index not in matrix and not vocab._is_word_no_create_entry(word): | |||
| if found_unknown: # 如果有unkonwn,用unknown初始化 | |||
| matrix[index] = matrix[vocab.unknown_idx] | |||
| else: | |||
| matrix[index] = None | |||
| if not self.only_use_pretrain_word: # 如果只用pretrain中的值就不要为未找到的词创建entry了 | |||
| for word, index in vocab: | |||
| if index not in matrix and not vocab._is_word_no_create_entry(word): | |||
| if found_unknown: # 如果有unkonwn,用unknown初始化 | |||
| matrix[index] = matrix[vocab.unknown_idx] | |||
| else: | |||
| matrix[index] = None | |||
| # matrix中代表是需要建立entry的词 | |||
| vectors = self._randomly_init_embed(len(matrix), dim, init_method) | |||
| @@ -16,9 +16,9 @@ from ...core import Vocabulary | |||
| __all__ = ['CMRC2018BertPipe'] | |||
| def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars'): | |||
| def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): | |||
| """ | |||
| 处理data_bundle中的DataSet,将context与question进行tokenize,然后使用[SEP]将两者连接起来。 | |||
| 处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。 | |||
| 会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start | |||
| 与target_end是与raw_chars等长的。其中target_start和target_end是前闭后闭的区间。 | |||
| @@ -26,6 +26,7 @@ def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars') | |||
| :param DataBundle data_bundle: 类似["a", "b", "[SEP]", "c", ] | |||
| :return: | |||
| """ | |||
| tokenizer = get_tokenizer('cn-char', lang='cn') | |||
| for name in list(data_bundle.datasets.keys()): | |||
| ds = data_bundle.get_dataset(name) | |||
| data_bundle.delete_dataset(name) | |||
| @@ -87,8 +88,8 @@ class CMRC2018BertPipe(Pipe): | |||
| ".", "...", "...","...", "..." | |||
| raw_words列是context与question拼起来的结果,words是转为index的值, target_start当当前位置为答案的开头时为1,target_end当当前 | |||
| 位置为答案的结尾是为1;context_len指示的是words列中context的长度。 | |||
| raw_words列是context与question拼起来的结果(连接的地方加入了[SEP]),words是转为index的值, target_start为答案start的index,target_end为答案end的index | |||
| (闭区间);context_len指示的是words列中context的长度。 | |||
| 其中各列的meta信息如下: | |||
| +-------------+-------------+-----------+--------------+------------+-------+---------+ | |||
| @@ -119,8 +120,7 @@ class CMRC2018BertPipe(Pipe): | |||
| :param data_bundle: | |||
| :return: | |||
| """ | |||
| _tokenizer = get_tokenizer('cn-char', lang='cn') | |||
| data_bundle = _concat_clip(data_bundle, tokenizer=_tokenizer, max_len=self.max_len, concat_field_name='raw_chars') | |||
| data_bundle = _concat_clip(data_bundle, max_len=self.max_len, concat_field_name='raw_chars') | |||
| src_vocab = Vocabulary() | |||
| src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], | |||
| @@ -35,6 +35,79 @@ class TestLoad(unittest.TestCase): | |||
| words = torch.randint(1, 200, (batch, length)).long() | |||
| embed(words) | |||
| def test_only_use_pretrain_word(self): | |||
| def check_word_unk(words, vocab, embed): | |||
| for word in words: | |||
| self.assertListEqual(embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0], | |||
| embed(torch.LongTensor([1])).tolist()[0]) | |||
| def check_vector_equal(words, vocab, embed, embed_dict, lower=False): | |||
| for word in words: | |||
| index = vocab.to_index(word) | |||
| v1 = embed(torch.LongTensor([index])).tolist()[0] | |||
| if lower: | |||
| word = word.lower() | |||
| v2 = embed_dict[word] | |||
| for v1i, v2i in zip(v1, v2): | |||
| self.assertAlmostEqual(v1i, v2i, places=4) | |||
| embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/' | |||
| 'glove.6B.50d_test.txt') | |||
| # 测试是否只使用pretrain的word | |||
| vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile']) | |||
| vocab.add_word('of', no_create_entry=True) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
| 'glove.6B.50d_test.txt', | |||
| only_use_pretrain_word=True) | |||
| # notinfile应该被置为unk | |||
| check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict) | |||
| check_word_unk(['notinfile'], vocab, embed) | |||
| # 测试在大小写情况下的使用 | |||
| vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile']) | |||
| vocab.add_word('Of', no_create_entry=True) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
| 'glove.6B.50d_test.txt', | |||
| only_use_pretrain_word=True) | |||
| check_word_unk(['The', 'Of', 'notinfile'], vocab, embed) # 这些词应该找不到 | |||
| check_vector_equal(['a'], vocab, embed, embed_dict) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
| 'glove.6B.50d_test.txt', | |||
| only_use_pretrain_word=True, lower=True) | |||
| check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True) | |||
| check_word_unk(['notinfile'], vocab, embed) | |||
| # 测试min_freq | |||
| vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2']) | |||
| vocab.add_word('Of', no_create_entry=True) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
| 'glove.6B.50d_test.txt', | |||
| only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True) | |||
| check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True) | |||
| check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed) | |||
| def read_static_embed(fp): | |||
| """ | |||
| :param str fp: embedding的路径 | |||
| :return: {}, key是word, value是vector | |||
| """ | |||
| embed = {} | |||
| with open(fp, 'r') as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| parts = line.split() | |||
| vector = list(map(float, parts[1:])) | |||
| word = parts[0] | |||
| embed[word] = vector | |||
| return embed | |||
| class TestRandomSameEntry(unittest.TestCase): | |||
| def test_same_vector(self): | |||
| vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) | |||