Browse Source

!11732 bert scripts modify to extract embedding tables in construct

From: @shibeiji
Reviewed-by: @linqingke,@c_34
Signed-off-by: @c_34
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
316e78fe6d
1 changed files with 1 additions and 2 deletions
  1. +1
    -2
      model_zoo/official/nlp/bert/src/bert_model.py

+ 1
- 2
model_zoo/official/nlp/bert/src/bert_model.py View File

@@ -805,7 +805,6 @@ class BertModel(nn.Cell):
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
use_one_hot=use_one_hot_embeddings)
self.embedding_tables = self.bert_embedding_lookup.embedding_table

self.bert_embedding_postprocessor = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
@@ -847,7 +846,7 @@ class BertModel(nn.Cell):
def construct(self, input_ids, token_type_ids, input_mask):
"""Bidirectional Encoder Representations from Transformers."""
# embedding
embedding_tables = self.embedding_tables
embedding_tables = self.bert_embedding_lookup.embedding_table
word_embeddings = self.bert_embedding_lookup(input_ids)
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
word_embeddings)


Loading…
Cancel
Save