@@ -161,7 +161,6 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
     ├─dataset.py                          # data preprocessing
     ├─finetune_eval_config.py             # parameter configuration for finetuning
     ├─finetune_eval_model.py              # backbone code of network
-    ├─fused_layer_norm.py                 # layer normalization optimized for Ascend
     ├─sample_process.py                   # sample processing
     ├─utils.py                            # util function
   ├─pretrain_eval.py                      # train and eval net
@@ -25,7 +25,6 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
-from .fused_layer_norm import FusedLayerNorm


 class BertConfig:
@@ -78,8 +77,7 @@ class BertConfig:
                  input_mask_from_dataset=True,
                  token_type_ids_from_dataset=True,
                  dtype=mstype.float32,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         self.batch_size = batch_size
         self.seq_length = seq_length
         self.vocab_size = vocab_size
@@ -98,7 +96,6 @@ class BertConfig:
         self.use_relative_positions = use_relative_positions
         self.dtype = dtype
         self.compute_type = compute_type
-        self.enable_fused_layernorm = enable_fused_layernorm


 class EmbeddingLookup(nn.Cell):
@@ -245,19 +242,14 @@ class BertOutput(nn.Cell):
                  out_channels,
                  initializer_range=0.02,
                  dropout_prob=0.1,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertOutput, self).__init__()
         self.dense = nn.Dense(in_channels, out_channels,
                               weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.dropout_prob = dropout_prob
         self.add = P.TensorAdd()
-        if compute_type == mstype.float16:
-            self.layernorm = FusedLayerNorm((out_channels,),
-                                            use_batch_norm=enable_fused_layernorm).to_float(compute_type)
-        else:
-            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
+        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
         self.cast = P.Cast()

     def construct(self, hidden_status, input_tensor):
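With this hunk `BertOutput` always builds `nn.LayerNorm`; the float16-only `FusedLayerNorm` branch is gone. A minimal sketch of the resulting residual sublayer, assuming the MindSpore 1.x-era APIs used throughout this diff (`nn.Dropout` takes a keep probability, `P.TensorAdd` is the add op); the class name is illustrative:

```python
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal

class BertOutputSketch(nn.Cell):
    """Dense -> dropout -> residual add -> layernorm, as after this change."""
    def __init__(self, in_channels, out_channels, dropout_prob=0.1,
                 compute_type=mstype.float32):
        super(BertOutputSketch, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(0.02)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)  # 1 - p is the keep probability
        self.add = P.TensorAdd()
        # Unconditional nn.LayerNorm, for fp16 and fp32 graphs alike.
        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        return self.layernorm(output)
```

The behavioral change is thus confined to fp16 graphs: they now use the stock layer norm kernel instead of the batchnorm-based fusion defined in the file deleted below.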
@@ -615,8 +607,7 @@ class BertSelfAttention(nn.Cell):
                  initializer_range=0.02,
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertSelfAttention, self).__init__()
         if hidden_size % num_attention_heads != 0:
             raise ValueError("The hidden size (%d) is not a multiple of the number "
@@ -644,8 +635,7 @@ class BertSelfAttention(nn.Cell):
                                  out_channels=hidden_size,
                                  initializer_range=initializer_range,
                                  dropout_prob=hidden_dropout_prob,
-                                 compute_type=compute_type,
-                                 enable_fused_layernorm=enable_fused_layernorm)
+                                 compute_type=compute_type)
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)

@@ -687,8 +677,7 @@ class BertEncoderCell(nn.Cell):
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
                  hidden_act="gelu",
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertEncoderCell, self).__init__()
         self.attention = BertSelfAttention(
             batch_size=batch_size,
@@ -700,8 +689,7 @@ class BertEncoderCell(nn.Cell):
             initializer_range=initializer_range,
             hidden_dropout_prob=hidden_dropout_prob,
             use_relative_positions=use_relative_positions,
-            compute_type=compute_type,
-            enable_fused_layernorm=enable_fused_layernorm)
+            compute_type=compute_type)
         self.intermediate = nn.Dense(in_channels=hidden_size,
                                      out_channels=intermediate_size,
                                      activation=hidden_act,
@@ -710,8 +698,7 @@ class BertEncoderCell(nn.Cell):
                                  out_channels=hidden_size,
                                  initializer_range=initializer_range,
                                  dropout_prob=hidden_dropout_prob,
-                                 compute_type=compute_type,
-                                 enable_fused_layernorm=enable_fused_layernorm)
+                                 compute_type=compute_type)

     def construct(self, hidden_states, attention_mask):
         # self-attention
@@ -758,8 +745,7 @@ class BertTransformer(nn.Cell):
                  use_relative_positions=False,
                  hidden_act="gelu",
                  compute_type=mstype.float32,
-                 return_all_encoders=False,
-                 enable_fused_layernorm=False):
+                 return_all_encoders=False):
         super(BertTransformer, self).__init__()
         self.return_all_encoders = return_all_encoders

@@ -776,8 +762,7 @@ class BertTransformer(nn.Cell):
                                     hidden_dropout_prob=hidden_dropout_prob,
                                     use_relative_positions=use_relative_positions,
                                     hidden_act=hidden_act,
-                                    compute_type=compute_type,
-                                    enable_fused_layernorm=enable_fused_layernorm)
+                                    compute_type=compute_type)
             layers.append(layer)
         self.layers = nn.CellList(layers)

@@ -904,8 +889,7 @@ class BertModel(nn.Cell):
             use_relative_positions=config.use_relative_positions,
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
-            return_all_encoders=True,
-            enable_fused_layernorm=config.enable_fused_layernorm)
+            return_all_encoders=True)
         self.cast = P.Cast()
         self.dtype = config.dtype

@@ -1,122 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""fused layernorm"""
-import numpy as np
-from mindspore.ops import operations as P
-from mindspore.ops import functional as F
-from mindspore.common.parameter import Parameter
-from mindspore.common.initializer import initializer
-from mindspore.ops.primitive import constexpr
-import mindspore.common.dtype as mstype
-from mindspore.nn.cell import Cell
-
-__all__ = ['FusedLayerNorm']
-
-
-@constexpr
-def get_shape_for_norm(x_shape, begin_norm_axis):
-    print("input_shape: ", x_shape)
-    norm_shape = x_shape[begin_norm_axis:]
-    output_shape = (1, -1, 1, int(np.prod(norm_shape)))
-    print("output_shape: ", output_shape)
-    return output_shape
-
-
-class FusedLayerNorm(Cell):
-    r"""
-    Applies Layer Normalization over a mini-batch of inputs.
-
-    Layer normalization is widely used in recurrent neural networks. It applies
-    normalization over a mini-batch of inputs for each single training case as described
-    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
-    normalization, layer normalization performs exactly the same computation at training and
-    testing times. It can be described by the following formula, applied across all channels
-    and pixels of a single sample.
-
-    .. math::
-        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-
-    Args:
-        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
-            `begin_norm_axis ... R - 1`.
-        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
-            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
-        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
-            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
-            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
-            'he_uniform', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
-            'he_uniform', etc. Default: 'zeros'.
-        use_batch_norm (bool): Whether to use batchnorm to process the input. Default: False.
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
-          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
-
-    Outputs:
-        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
-
-    Examples:
-        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
-        >>> shape1 = x.shape[1:]
-        >>> m = FusedLayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
-        >>> m(x)
-    """
-    def __init__(self,
-                 normalized_shape,
-                 begin_norm_axis=-1,
-                 begin_params_axis=-1,
-                 gamma_init='ones',
-                 beta_init='zeros',
-                 use_batch_norm=False):
-        super(FusedLayerNorm, self).__init__()
-        if not isinstance(normalized_shape, (tuple, list)):
-            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
-                            .format(normalized_shape, type(normalized_shape)))
-        self.normalized_shape = normalized_shape
-        self.begin_norm_axis = begin_norm_axis
-        self.begin_params_axis = begin_params_axis
-        self.gamma = Parameter(initializer(
-            gamma_init, normalized_shape), name="gamma")
-        self.beta = Parameter(initializer(
-            beta_init, normalized_shape), name="beta")
-        self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
-        self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
-        self.use_batch_norm = use_batch_norm
-
-    def construct(self, input_x):
-        """Applies Layer Normalization over a mini-batch of inputs"""
-        if self.use_batch_norm and self.training:
-            ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
-            zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
-            shape_x = F.shape(input_x)
-            norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
-            input_x = F.reshape(input_x, norm_shape)
-            output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
-            output = F.reshape(output, shape_x)
-            y = output * self.gamma + self.beta
-        else:
-            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
-        return y
-
-    def extend_repr(self):
-        """Display instance object as string."""
-        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
-            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
-        return s
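The deleted module's fast path rests on a reshape trick: `get_shape_for_norm` flattens all leading axes into the channel dimension of a 4-D tensor, so `P.BatchNorm` (which normalizes each channel over N, H, W) ends up computing per-sample layer-norm statistics; gamma and beta are then applied outside the kernel. A NumPy check of that equivalence, a minimal sketch with illustrative names, not code from this repo:

```python
import numpy as np

def layer_norm_ref(x, eps=1e-5):
    # Plain layer norm over the last axis, one statistic per sample row.
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def layer_norm_via_batch_norm(x, eps=1e-5):
    # Mirror of get_shape_for_norm: reshape to (1, -1, 1, prod(norm_shape)).
    # Each "channel" now holds exactly one sample row, so per-channel
    # batch-norm statistics equal per-sample layer-norm statistics.
    shape = x.shape
    x4 = x.reshape(1, -1, 1, shape[-1])
    mean = x4.mean(axis=(0, 2, 3), keepdims=True)
    var = x4.var(axis=(0, 2, 3), keepdims=True)
    # The deleted Cell applies gamma/beta after this, as output * gamma + beta.
    return ((x4 - mean) / np.sqrt(var + eps)).reshape(shape)

x = np.random.randn(32, 128, 768).astype(np.float32)
assert np.allclose(layer_norm_ref(x), layer_norm_via_batch_norm(x), atol=1e-4)
```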
@@ -113,7 +113,6 @@ For example, the dataset is cn-wiki-128, the schema file for general distill pha
     ├─__init__.py
     ├─assessment_method.py                # assessment method for evaluation
     ├─dataset.py                          # data processing
-    ├─fused_layer_norm.py                 # layer normalization optimized for Ascend
     ├─gd_config.py                        # parameter configuration for general distill phase
     ├─td_config.py                        # parameter configuration for task distill phase
     ├─tinybert_for_gd_td.py               # backbone code of network
@@ -229,7 +228,6 @@ Parameters for bert network:
     token_type_ids_from_dataset     use the token type ids loaded from dataset or not: True | False, default is True
     dtype                           data type of input: mstype.float16 | mstype.float32, default is mstype.float32
     compute_type                    compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
-    enable_fused_layernorm          use batchnorm instead of layernorm to improve performance, default is False
 ```
 ## [Training Process](#contents)
 ### Training
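Since `enable_fused_layernorm` is removed from both the parameter docs and `BertConfig`, configs that still pass it will now fail fast. A minimal sketch of building a config under the new signature (module path and values are illustrative):

```python
import mindspore.common.dtype as mstype
from src.tinybert_model import BertConfig  # assumed path; adjust to this repo's layout

bert_net_cfg = BertConfig(
    batch_size=32,
    dtype=mstype.float32,
    compute_type=mstype.float16,  # fp16 compute now always pairs with nn.LayerNorm
)
# BertConfig(..., enable_fused_layernorm=False) would raise a TypeError here.
```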
@@ -1,122 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""fused layernorm"""
-import numpy as np
-from mindspore.ops import operations as P
-from mindspore.ops import functional as F
-from mindspore.common.parameter import Parameter
-from mindspore.common.initializer import initializer
-from mindspore.ops.primitive import constexpr
-import mindspore.common.dtype as mstype
-from mindspore.nn.cell import Cell
-
-__all__ = ['FusedLayerNorm']
-
-
-@constexpr
-def get_shape_for_norm(x_shape, begin_norm_axis):
-    print("input_shape: ", x_shape)
-    norm_shape = x_shape[begin_norm_axis:]
-    output_shape = (1, -1, 1, int(np.prod(norm_shape)))
-    print("output_shape: ", output_shape)
-    return output_shape
-
-
-class FusedLayerNorm(Cell):
-    r"""
-    Applies Layer Normalization over a mini-batch of inputs.
-
-    Layer normalization is widely used in recurrent neural networks. It applies
-    normalization over a mini-batch of inputs for each single training case as described
-    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
-    normalization, layer normalization performs exactly the same computation at training and
-    testing times. It can be described by the following formula, applied across all channels
-    and pixels of a single sample.
-
-    .. math::
-        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-
-    Args:
-        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
-            `begin_norm_axis ... R - 1`.
-        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
-            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
-        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
-            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
-            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
-            'he_uniform', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
-            'he_uniform', etc. Default: 'zeros'.
-        use_batch_norm (bool): Whether to use batchnorm to process the input. Default: False.
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
-          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
-
-    Outputs:
-        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
-
-    Examples:
-        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
-        >>> shape1 = x.shape[1:]
-        >>> m = FusedLayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
-        >>> m(x)
-    """
-    def __init__(self,
-                 normalized_shape,
-                 begin_norm_axis=-1,
-                 begin_params_axis=-1,
-                 gamma_init='ones',
-                 beta_init='zeros',
-                 use_batch_norm=False):
-        super(FusedLayerNorm, self).__init__()
-        if not isinstance(normalized_shape, (tuple, list)):
-            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
-                            .format(normalized_shape, type(normalized_shape)))
-        self.normalized_shape = normalized_shape
-        self.begin_norm_axis = begin_norm_axis
-        self.begin_params_axis = begin_params_axis
-        self.gamma = Parameter(initializer(
-            gamma_init, normalized_shape), name="gamma")
-        self.beta = Parameter(initializer(
-            beta_init, normalized_shape), name="beta")
-        self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
-        self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
-        self.use_batch_norm = use_batch_norm
-
-    def construct(self, input_x):
-        """fusedlayernorm"""
-        if self.use_batch_norm and self.training:
-            ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
-            zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
-            shape_x = F.shape(input_x)
-            norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
-            input_x = F.reshape(input_x, norm_shape)
-            output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
-            output = F.reshape(output, shape_x)
-            y = output * self.gamma + self.beta
-        else:
-            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
-        return y
-
-    def extend_repr(self):
-        """Display instance object as string."""
-        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
-            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
-        return s
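`get_shape_for_norm` above is decorated with `@constexpr`: in graph mode the body runs in Python at compile time on constant shape tuples, so only the returned tuple enters the graph and the `print` calls fire during compilation, not once per training step. A minimal sketch of the same pattern (the function name here is hypothetical):

```python
import numpy as np
from mindspore.ops.primitive import constexpr

@constexpr
def flat_norm_shape(x_shape, begin_norm_axis):
    # Compile-time computation: x_shape is a constant tuple here, so plain
    # Python/NumPy is fine; the result is folded into the graph.
    return (1, -1, 1, int(np.prod(x_shape[begin_norm_axis:])))

# flat_norm_shape((32, 128, 768), -1) -> (1, -1, 1, 768)
```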
@@ -55,8 +55,7 @@ bert_teacher_net_cfg = BertConfig(
     input_mask_from_dataset=True,
     token_type_ids_from_dataset=True,
     dtype=mstype.float32,
-    compute_type=mstype.float16,
-    enable_fused_layernorm=False
+    compute_type=mstype.float16
 )
 bert_student_net_cfg = BertConfig(
     batch_size=32,
@@ -76,6 +75,5 @@ bert_student_net_cfg = BertConfig(
     input_mask_from_dataset=True,
     token_type_ids_from_dataset=True,
     dtype=mstype.float32,
-    compute_type=mstype.float16,
-    enable_fused_layernorm=False
+    compute_type=mstype.float16
 )
@@ -74,8 +74,7 @@ td_teacher_net_cfg = BertConfig(
     input_mask_from_dataset=True,
     token_type_ids_from_dataset=True,
     dtype=mstype.float32,
-    compute_type=mstype.float16,
-    enable_fused_layernorm=False
+    compute_type=mstype.float16
 )
 td_student_net_cfg = BertConfig(
     batch_size=32,
@@ -95,6 +94,5 @@ td_student_net_cfg = BertConfig(
     input_mask_from_dataset=True,
     token_type_ids_from_dataset=True,
     dtype=mstype.float32,
-    compute_type=mstype.float16,
-    enable_fused_layernorm=False
+    compute_type=mstype.float16
 )
@@ -25,7 +25,6 @@ from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore import context
-from .fused_layer_norm import FusedLayerNorm


 class BertConfig:
@@ -78,8 +77,7 @@ class BertConfig:
                  input_mask_from_dataset=True,
                  token_type_ids_from_dataset=True,
                  dtype=mstype.float32,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         self.batch_size = batch_size
         self.seq_length = seq_length
         self.vocab_size = vocab_size
@@ -98,7 +96,6 @@ class BertConfig:
         self.use_relative_positions = use_relative_positions
         self.dtype = dtype
         self.compute_type = compute_type
-        self.enable_fused_layernorm = enable_fused_layernorm


 class EmbeddingLookup(nn.Cell):
@@ -244,8 +241,7 @@ class BertOutput(nn.Cell):
                  out_channels,
                  initializer_range=0.02,
                  dropout_prob=0.1,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertOutput, self).__init__()
         self.dense = nn.Dense(in_channels, out_channels,
                               weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
@@ -256,11 +252,7 @@ class BertOutput(nn.Cell):
             self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
             self.compute_type = compute_type
         else:
-            if compute_type == mstype.float16:
-                self.layernorm = FusedLayerNorm((out_channels,),
-                                                use_batch_norm=enable_fused_layernorm).to_float(compute_type)
-            else:
-                self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
+            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
         self.cast = P.Cast()

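The surviving GPU branch above keeps the layer norm in float32 even when the rest of the cell runs in float16, which is why it stores `compute_type`. A sketch of the matching `construct` logic, with `self.is_gpu` as an assumed flag derived from `context.get_context('device_target')`; this is an illustration, not the verbatim repo code:

```python
import mindspore.common.dtype as mstype

def construct(self, hidden_status, input_tensor):
    output = self.dense(hidden_status)
    output = self.dropout(output)
    output = self.add(input_tensor, output)
    if self.is_gpu:  # assumed flag, set where the layernorm branch is chosen
        # GPU path: the layernorm was built with .to_float(mstype.float32),
        # so cast up for the norm, then back to the working precision.
        output = self.layernorm(self.cast(output, mstype.float32))
        output = self.cast(output, self.compute_type)
    else:
        output = self.layernorm(output)
    return output
```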
@@ -602,8 +594,7 @@ class BertSelfAttention(nn.Cell):
                  initializer_range=0.02,
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertSelfAttention, self).__init__()
         if hidden_size % num_attention_heads != 0:
             raise ValueError("The hidden size (%d) is not a multiple of the number "
@@ -628,8 +619,7 @@ class BertSelfAttention(nn.Cell):
                                  out_channels=hidden_size,
                                  initializer_range=initializer_range,
                                  dropout_prob=hidden_dropout_prob,
-                                 compute_type=compute_type,
-                                 enable_fused_layernorm=enable_fused_layernorm)
+                                 compute_type=compute_type)
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)

@@ -672,8 +662,7 @@ class BertEncoderCell(nn.Cell):
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
                  hidden_act="gelu",
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertEncoderCell, self).__init__()
         self.attention = BertSelfAttention(
             batch_size=batch_size,
@@ -685,8 +674,7 @@ class BertEncoderCell(nn.Cell):
             initializer_range=initializer_range,
             hidden_dropout_prob=hidden_dropout_prob,
             use_relative_positions=use_relative_positions,
-            compute_type=compute_type,
-            enable_fused_layernorm=enable_fused_layernorm)
+            compute_type=compute_type)
         self.intermediate = nn.Dense(in_channels=hidden_size,
                                      out_channels=intermediate_size,
                                      activation=hidden_act,
@@ -695,8 +683,7 @@ class BertEncoderCell(nn.Cell):
                                  out_channels=hidden_size,
                                  initializer_range=initializer_range,
                                  dropout_prob=hidden_dropout_prob,
-                                 compute_type=compute_type,
-                                 enable_fused_layernorm=enable_fused_layernorm)
+                                 compute_type=compute_type)
     def construct(self, hidden_states, attention_mask):
         """bert encoder cell"""
         # self-attention
@@ -743,8 +730,7 @@ class BertTransformer(nn.Cell):
                  use_relative_positions=False,
                  hidden_act="gelu",
                  compute_type=mstype.float32,
-                 return_all_encoders=False,
-                 enable_fused_layernorm=False):
+                 return_all_encoders=False):
         super(BertTransformer, self).__init__()
         self.return_all_encoders = return_all_encoders
         layers = []
@@ -760,8 +746,7 @@ class BertTransformer(nn.Cell):
                                     hidden_dropout_prob=hidden_dropout_prob,
                                     use_relative_positions=use_relative_positions,
                                     hidden_act=hidden_act,
-                                    compute_type=compute_type,
-                                    enable_fused_layernorm=enable_fused_layernorm)
+                                    compute_type=compute_type)
             layers.append(layer)
         self.layers = nn.CellList(layers)
         self.reshape = P.Reshape()
@@ -877,8 +862,7 @@ class BertModel(nn.Cell):
             use_relative_positions=config.use_relative_positions,
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
-            return_all_encoders=True,
-            enable_fused_layernorm=config.enable_fused_layernorm)
+            return_all_encoders=True)
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
@@ -981,8 +965,7 @@ class TinyBertModel(nn.Cell):
             use_relative_positions=config.use_relative_positions,
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
-            return_all_encoders=True,
-            enable_fused_layernorm=config.enable_fused_layernorm)
+            return_all_encoders=True)
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
@@ -82,8 +82,7 @@ def get_config(version='base', batch_size=1):
             input_mask_from_dataset=True,
             token_type_ids_from_dataset=True,
             dtype=mstype.float32,
-            compute_type=mstype.float16,
-            enable_fused_layernorm=False)
+            compute_type=mstype.float16)
     else:
         bert_config = BertConfig(batch_size=batch_size)
     return bert_config
@@ -82,8 +82,7 @@ def get_config(version='base', batch_size=1):
             input_mask_from_dataset=True,
             token_type_ids_from_dataset=True,
             dtype=mstype.float32,
-            compute_type=mstype.float16,
-            enable_fused_layernorm=False)
+            compute_type=mstype.float16)
     else:
         bert_config = BertConfig(batch_size=batch_size)
     return bert_config