diff --git a/model_zoo/official/nlp/bert/README.md b/model_zoo/official/nlp/bert/README.md
index 4108f19611..d8f9cd60f9 100644
--- a/model_zoo/official/nlp/bert/README.md
+++ b/model_zoo/official/nlp/bert/README.md
@@ -50,8 +50,15 @@ The backbone structure of BERT is transformer. For BERT_base, the transformer co
 # [Dataset](#contents)
-- Download the zhwiki or enwiki dataset for pre-training. Extract and refine texts in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format. Please refer to create_pretraining_data.py file in [BERT](https://github.com/google-research/bert) repository.
-- Download dataset for fine-tuning and evaluation such as CLUENER, TNEWS, SQuAD v1.1, etc. Convert dataset files from JSON format to TFRECORD format, please refer to run_classifier.py file in [BERT](https://github.com/google-research/bert) repository.
+- Create the pre-training dataset
+    - Download the [zhwiki](https://dumps.wikimedia.org/zhwiki/) or [enwiki](https://dumps.wikimedia.org/enwiki/) dataset for pre-training.
+    - Extract and refine the texts in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). The commands are as follows:
+        - `pip install wikiextractor`
+        - `python -m wikiextractor.WikiExtractor <dump file> -o <output directory> -b <bytes per output file>`
+    - Convert the dataset to TFRecord format. Please refer to the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository and download the corresponding vocab.txt. If `AttributeError: module 'tokenization' has no attribute 'FullTokenizer'` occurs, install the bert-tensorflow package.
+- Create the fine-tuning dataset
+    - Download datasets for fine-tuning and evaluation such as [CLUENER](https://github.com/CLUEbenchmark/CLUENER2020), [TNEWS](https://github.com/CLUEbenchmark/CLUE), the [SQuAD v1.1 train dataset](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json), the [SQuAD v1.1 eval dataset](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json), etc.
+    - Convert the dataset files from JSON format to TFRecord format. Please refer to the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository.
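A quick way to confirm the vocab.txt and tokenizer setup before running create_pretraining_data.py is a small load test. This is only a sketch: the vocab.txt path and the sample sentence are placeholders, and `tokenization` is the module shipped with the google-research/bert sources (or the bert-tensorflow package):

```python
# Sanity check for the tokenizer used by create_pretraining_data.py.
# If this fails with the AttributeError mentioned above, installing
# bert-tensorflow usually resolves it.
import tokenization  # from google-research/bert / bert-tensorflow

tokenizer = tokenization.FullTokenizer(vocab_file="vocab.txt",  # placeholder path
                                       do_lower_case=True)
print(tokenizer.tokenize("BERT pre-training text extracted from the wiki dump."))
```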
 # [Environment Requirements](#contents)
diff --git a/model_zoo/official/nlp/bert/README_CN.md b/model_zoo/official/nlp/bert/README_CN.md
index 98f0fb5e77..749eeaa30c 100644
--- a/model_zoo/official/nlp/bert/README_CN.md
+++ b/model_zoo/official/nlp/bert/README_CN.md
@@ -53,8 +53,15 @@ BERT的主干结构为Transformer。对于BERT_base,Transformer包含12个编
 # 数据集
-- 下载zhwiki或enwiki数据集进行预训练,使用[WikiExtractor](https://github.com/attardi/wikiextractor)提取和整理数据集中的文本,并将数据集转换为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码仓中的create_pretraining_data.py文件。
-- 下载数据集进行微调和评估,如CLUENER、TNEWS、SQuAD v1.1等。将数据集文件从JSON格式转换为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码仓中的run_classifier.py文件。
+- 生成预训练数据集
+    - 下载[zhwiki](https://dumps.wikimedia.org/zhwiki/)或[enwiki](https://dumps.wikimedia.org/enwiki/)数据集进行预训练。
+    - 使用[WikiExtractor](https://github.com/attardi/wikiextractor)提取和整理数据集中的文本，使用步骤如下：
+        - `pip install wikiextractor`
+        - `python -m wikiextractor.WikiExtractor <dump文件> -o <输出目录> -b <单个输出文件大小>`
+    - 将数据集转换为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码仓中的create_pretraining_data.py文件，同时下载对应的vocab.txt文件。如果出现AttributeError: module 'tokenization' has no attribute 'FullTokenizer'，请安装bert-tensorflow。
+- 生成下游任务数据集
+    - 下载数据集进行微调和评估，如[CLUENER](https://github.com/CLUEbenchmark/CLUENER2020)、[TNEWS](https://github.com/CLUEbenchmark/CLUE)、[SQuAD v1.1训练集](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)、[SQuAD v1.1验证集](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)等。
+    - 将数据集文件从JSON格式转换为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码仓中的run_classifier.py文件。
 # 环境要求
diff --git a/model_zoo/official/nlp/tinybert/README.md b/model_zoo/official/nlp/tinybert/README.md
index 539533459f..829c359c63 100644
--- a/model_zoo/official/nlp/tinybert/README.md
+++ b/model_zoo/official/nlp/tinybert/README.md
@@ -45,8 +45,15 @@ The backbone structure of TinyBERT is transformer, the transformer contains four
 # [Dataset](#contents)
-- Download the zhwiki or enwiki dataset for general distillation. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format, please refer to create_pretraining_data.py which in [BERT](https://github.com/google-research/bert) repository.
-- Download glue dataset for task distillation. Convert dataset files from json format to tfrecord format, please refer to run_classifier.py which in [BERT](https://github.com/google-research/bert) repository.
+- Create the dataset for the general distill phase
+    - Download the [zhwiki](https://dumps.wikimedia.org/zhwiki/) or [enwiki](https://dumps.wikimedia.org/enwiki/) dataset for general distillation.
+    - Extract and refine the texts in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). The commands are as follows:
+        - `pip install wikiextractor`
+        - `python -m wikiextractor.WikiExtractor <dump file> -o <output directory> -b <bytes per output file>`
+    - Convert the dataset to TFRecord format. Please refer to the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository and download the corresponding vocab.txt. If `AttributeError: module 'tokenization' has no attribute 'FullTokenizer'` occurs, install the bert-tensorflow package.
+- Create the dataset for the task distill phase
+    - Download the [GLUE](https://github.com/nyu-mll/GLUE-baselines) dataset for the task distill phase.
+    - Convert the dataset files from JSON format to TFRecord format. Please refer to the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository.
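Once a wiki dump has been converted, it can help to verify that the resulting TFRecord files read back with MindSpore before launching pre-training or general distillation. A minimal sketch; the file name and column list are assumptions (the columns mirror the feature names written by create_pretraining_data.py):

```python
# Read back one record from a converted pre-training TFRecord and print the
# shape of each column, to confirm the conversion produced what training expects.
import mindspore.dataset as ds

columns = ["input_ids", "input_mask", "segment_ids",
           "masked_lm_positions", "masked_lm_ids", "masked_lm_weights",
           "next_sentence_labels"]
data_set = ds.TFRecordDataset(["zhwiki_part0.tfrecord"], columns_list=columns)  # assumed file name
for row in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
    print({name: value.shape for name, value in row.items()})
    break
```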
 # [Environment Requirements](#contents)
diff --git a/model_zoo/official/nlp/tinybert/README_CN.md b/model_zoo/official/nlp/tinybert/README_CN.md
index 3aa3684ee3..0da04460e2 100644
--- a/model_zoo/official/nlp/tinybert/README_CN.md
+++ b/model_zoo/official/nlp/tinybert/README_CN.md
@@ -50,8 +50,15 @@ TinyBERT模型的主干结构是转换器,转换器包含四个编码器模块
 # 数据集
-- 下载zhwiki或enwiki数据集进行一般蒸馏。使用[WikiExtractor](https://github.com/attardi/wikiextractor)提取和整理数据集中的文本。如需将数据集转化为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码库中的create_pretraining_data.py文件。
-- 下载GLUE数据集进行任务蒸馏。将数据集由JSON格式转化为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码库中的run_classifier.py文件。
+- 生成通用蒸馏阶段数据集
+    - 下载[zhwiki](https://dumps.wikimedia.org/zhwiki/)或[enwiki](https://dumps.wikimedia.org/enwiki/)数据集用于通用蒸馏。
+    - 使用[WikiExtractor](https://github.com/attardi/wikiextractor)提取和整理数据集中的文本，使用步骤如下：
+        - `pip install wikiextractor`
+        - `python -m wikiextractor.WikiExtractor <dump文件> -o <输出目录> -b <单个输出文件大小>`
+    - 将数据集转换为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码仓中的create_pretraining_data.py文件，同时下载对应的vocab.txt文件。如果出现AttributeError: module 'tokenization' has no attribute 'FullTokenizer'，请安装bert-tensorflow。
+- 生成下游任务蒸馏阶段数据集
+    - 下载数据集进行微调和评估，如[GLUE](https://github.com/nyu-mll/GLUE-baselines)。
+    - 将数据集文件从JSON格式转换为TFRecord格式。详见[BERT](https://github.com/google-research/bert)代码仓中的run_classifier.py文件。
 # 环境要求
diff --git a/model_zoo/official/nlp/tinybert/src/dataset.py b/model_zoo/official/nlp/tinybert/src/dataset.py index 2d0c1861b4..2b023f6990 100644 --- a/model_zoo/official/nlp/tinybert/src/dataset.py +++ b/model_zoo/official/nlp/tinybert/src/dataset.py @@ -52,7 +52,7 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, data_set = ds.MindDataset(data_files, columns_list=columns_list, shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank) else: - data_set = ds.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, + data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, columns_list=columns_list, shuffle=shuffle, num_shards=device_num, shard_id=rank, shard_equal_rows=shard_equal_rows) if device_num == 1 and shuffle is True:
diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_model.py b/model_zoo/official/nlp/tinybert/src/tinybert_model.py index 3a3ef8f952..0c0e2cbabf 100644 --- a/model_zoo/official/nlp/tinybert/src/tinybert_model.py +++ b/model_zoo/official/nlp/tinybert/src/tinybert_model.py @@ -86,55 +86,6 @@ class BertConfig: self.dtype = dtype self.compute_type = compute_type - -class EmbeddingLookup(nn.Cell): - """ - A embeddings lookup table with a fixed dictionary and size. - - Args: - vocab_size (int): Size of the dictionary of embeddings. - embedding_size (int): The size of each embedding vector. - embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of - each embedding vector. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
- """ - def __init__(self, - vocab_size, - embedding_size, - embedding_shape, - use_one_hot_embeddings=False, - initializer_range=0.02): - super(EmbeddingLookup, self).__init__() - self.vocab_size = vocab_size - self.use_one_hot_embeddings = use_one_hot_embeddings - self.embedding_table = Parameter(initializer - (TruncatedNormal(initializer_range), - [vocab_size, embedding_size])) - self.expand = P.ExpandDims() - self.shape_flat = (-1,) - self.gather = P.Gather() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.array_mul = P.MatMul() - self.reshape = P.Reshape() - self.shape = tuple(embedding_shape) - - def construct(self, input_ids): - """embedding lookup""" - extended_ids = self.expand(input_ids, -1) - flat_ids = self.reshape(extended_ids, self.shape_flat) - if self.use_one_hot_embeddings: - one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) - output_for_reshape = self.array_mul( - one_hot_ids, self.embedding_table) - else: - output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) - output = self.reshape(output_for_reshape, self.shape) - return output, self.embedding_table - - class EmbeddingPostprocessor(nn.Cell): """ Postprocessors apply positional and token type embeddings to word embeddings. @@ -166,10 +117,10 @@ class EmbeddingPostprocessor(nn.Cell): self.token_type_vocab_size = token_type_vocab_size self.use_one_hot_embeddings = use_one_hot_embeddings self.max_position_embeddings = max_position_embeddings - self.embedding_table = Parameter(initializer - (TruncatedNormal(initializer_range), - [token_type_vocab_size, - embedding_size])) + self.token_type_embedding = nn.Embedding( + vocab_size=token_type_vocab_size, + embedding_size=embedding_size, + use_one_hot=use_one_hot_embeddings) self.shape_flat = (-1,) self.one_hot = P.OneHot() self.on_value = Tensor(1.0, mstype.float32) @@ -177,35 +128,28 @@ class EmbeddingPostprocessor(nn.Cell): self.array_mul = P.MatMul() self.reshape = P.Reshape() self.shape = tuple(embedding_shape) - self.layernorm = nn.LayerNorm((embedding_size,)) self.dropout = nn.Dropout(1 - dropout_prob) self.gather = P.Gather() self.use_relative_positions = use_relative_positions self.slice = P.StridedSlice() - self.full_position_embeddings = Parameter(initializer - (TruncatedNormal(initializer_range), - [max_position_embeddings, - embedding_size])) + _, seq, _ = self.shape + self.full_position_embedding = nn.Embedding( + vocab_size=max_position_embeddings, + embedding_size=embedding_size, + use_one_hot=False) + self.layernorm = nn.LayerNorm((embedding_size,)) + self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) + self.add = P.Add() def construct(self, token_type_ids, word_embeddings): - """embedding postprocessor""" + """Postprocessors apply positional and token type embeddings to word embeddings.""" output = word_embeddings if self.use_token_type: - flat_ids = self.reshape(token_type_ids, self.shape_flat) - if self.use_one_hot_embeddings: - one_hot_ids = self.one_hot(flat_ids, - self.token_type_vocab_size, self.on_value, self.off_value) - token_type_embeddings = self.array_mul(one_hot_ids, - self.embedding_table) - else: - token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0) - token_type_embeddings = self.reshape(token_type_embeddings, self.shape) - output += token_type_embeddings + token_type_embeddings = self.token_type_embedding(token_type_ids) + output = self.add(output, token_type_embeddings) if 
not self.use_relative_positions: - _, seq, width = self.shape - position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1)) - position_embeddings = self.reshape(position_embeddings, (1, seq, width)) - output += position_embeddings + position_embeddings = self.full_position_embedding(self.position_ids) + output = self.add(output, position_embeddings) output = self.layernorm(output) output = self.dropout(output) return output @@ -788,12 +732,10 @@ class BertModel(nn.Cell): self.last_idx = self.num_hidden_layers - 1 output_embedding_shape = [-1, self.seq_length, self.embedding_size] - self.bert_embedding_lookup = EmbeddingLookup( + self.bert_embedding_lookup = nn.Embedding( vocab_size=config.vocab_size, embedding_size=self.embedding_size, - embedding_shape=output_embedding_shape, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range) + use_one_hot=use_one_hot_embeddings) self.bert_embedding_postprocessor = EmbeddingPostprocessor( use_relative_positions=config.use_relative_positions, embedding_size=self.embedding_size, @@ -831,7 +773,8 @@ class BertModel(nn.Cell): def construct(self, input_ids, token_type_ids, input_mask): """bert model""" # embedding - word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) + embedding_tables = self.bert_embedding_lookup.embedding_table + word_embeddings = self.bert_embedding_lookup(input_ids) embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings) # attention mask [batch_size, seq_length, seq_length] attention_mask = self._create_attention_mask_from_input_mask(input_mask) @@ -883,12 +826,10 @@ class TinyBertModel(nn.Cell): self.last_idx = self.num_hidden_layers - 1 output_embedding_shape = [-1, self.seq_length, self.embedding_size] - self.tinybert_embedding_lookup = EmbeddingLookup( + self.tinybert_embedding_lookup = nn.Embedding( vocab_size=config.vocab_size, embedding_size=self.embedding_size, - embedding_shape=output_embedding_shape, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range) + use_one_hot=use_one_hot_embeddings) self.tinybert_embedding_postprocessor = EmbeddingPostprocessor( use_relative_positions=config.use_relative_positions, embedding_size=self.embedding_size, @@ -926,7 +867,8 @@ class TinyBertModel(nn.Cell): def construct(self, input_ids, token_type_ids, input_mask): """tiny bert model""" # embedding - word_embeddings, embedding_tables = self.tinybert_embedding_lookup(input_ids) + embedding_tables = self.tinybert_embedding_lookup.embedding_table + word_embeddings = self.tinybert_embedding_lookup(input_ids) embedding_output = self.tinybert_embedding_postprocessor(token_type_ids, word_embeddings) # attention mask [batch_size, seq_length, seq_length] @@ -969,12 +911,8 @@ class BertModelCLS(nn.Cell): self.dtype = config.dtype self.num_labels = num_labels self.phase_type = phase_type - if self.phase_type == "teacher": - self.dense = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, - has_bias=True).to_float(config.compute_type) - else: - self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, - has_bias=True).to_float(config.compute_type) + self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) self.dropout = nn.ReLU() def construct(self, input_ids, token_type_id, input_mask): @@ -982,10 +920,7 @@ class BertModelCLS(nn.Cell): _, pooled_output, 
_, seq_output, att_output = self.bert(input_ids, token_type_id, input_mask) cls = self.cast(pooled_output, self.dtype) cls = self.dropout(cls) - if self.phase_type == "teacher": - logits = self.dense(cls) - else: - logits = self.dense_1(cls) + logits = self.dense_1(cls) logits = self.cast(logits, self.dtype) log_probs = self.log_softmax(logits) if self._phase == 'train' or self.phase_type == "teacher":
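The tinybert_model.py change above drops the hand-written EmbeddingLookup cell in favour of mindspore.nn.Embedding and reads the weight back through its embedding_table attribute. A standalone sketch of that pattern, with illustrative sizes and names that are not taken from the repository:

```python
# nn.Embedding provides the lookup the removed EmbeddingLookup implemented by
# hand (gather, or one-hot matmul when use_one_hot=True, over an embedding table).
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

vocab_size, embedding_size, seq_length = 30522, 128, 8  # illustrative sizes

word_embedding = nn.Embedding(vocab_size=vocab_size,
                              embedding_size=embedding_size,
                              use_one_hot=False)
position_embedding = nn.Embedding(vocab_size=512,  # max_position_embeddings
                                  embedding_size=embedding_size,
                                  use_one_hot=False)

input_ids = Tensor(np.random.randint(0, vocab_size, (2, seq_length)).astype(np.int32))
position_ids = Tensor(np.arange(seq_length).reshape(-1, seq_length).astype(np.int32))

word_embeddings = word_embedding(input_ids)               # (2, seq_length, embedding_size)
embedding_output = word_embeddings + position_embedding(position_ids)
embedding_table = word_embedding.embedding_table          # the Parameter BertModel.construct now reads
print(embedding_output.shape, embedding_table.shape)
```

One behavioural note: the removed EmbeddingLookup initialised its table with TruncatedNormal(initializer_range), while nn.Embedding uses its own default initialiser unless an explicit embedding_table initialiser is passed, so initial weights are not guaranteed to be bit-identical.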
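The dataset.py tweak earlier in the diff passes None instead of an empty schema_dir, so TFRecordDataset falls back to the schema stored in the TFRecord file itself. A small sketch of the two call styles; the file name, schema path and column list are assumptions:

```python
# With schema=None, MindSpore infers the columns from the TFRecord meta data;
# an explicit JSON schema is only needed to override or restrict that.
import mindspore.dataset as ds

files = ["general_distill_part0.tfrecord"]            # assumed file name
columns = ["input_ids", "input_mask", "segment_ids"]

with_schema = ds.TFRecordDataset(files, "dataset_schema.json", columns_list=columns)
without_schema = ds.TFRecordDataset(files, None, columns_list=columns)  # what schema_dir == "" now maps to
print(with_schema.get_dataset_size(), without_schema.get_dataset_size())
```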