@@ -24,6 +24,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
+from mindspore import context
 from .fused_layer_norm import FusedLayerNorm
@@ -250,11 +251,16 @@ class BertOutput(nn.Cell):
                               weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.add = P.TensorAdd()
-        if compute_type == mstype.float16:
-            self.layernorm = FusedLayerNorm((out_channels,),
-                                            use_batch_norm=enable_fused_layernorm).to_float(compute_type)
+        self.is_gpu = context.get_context('device_target') == "GPU"
+        if self.is_gpu:
+            self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
+            self.compute_type = compute_type
         else:
-            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
+            if compute_type == mstype.float16:
+                self.layernorm = FusedLayerNorm((out_channels,),
+                                                use_batch_norm=enable_fused_layernorm).to_float(compute_type)
+            else:
+                self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
         self.cast = P.Cast()
@@ -264,6 +270,8 @@ class BertOutput(nn.Cell):
         output = self.dropout(output)
         output = self.add(input_tensor, output)
         output = self.layernorm(output)
+        if self.is_gpu:
+            output = self.cast(output, self.compute_type)
         return output
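
For readers reconstructing the net effect, below is a minimal, hedged sketch of BertOutput as it stands after the patch. Everything outside the diff hunks (the dense projection, the constructor signature, and the construct() body before the added cast) is an assumption drawn from the surrounding bert_model.py, not part of this change.

# Hedged sketch of BertOutput after this patch. Names outside the diff
# (self.dense, the parameter defaults, the rest of construct()) are
# assumptions drawn from the surrounding bert_model.py.
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal

from .fused_layer_norm import FusedLayerNorm  # project-local module, see imports hunk


class BertOutput(nn.Cell):
    """Dense projection, dropout, residual add and layer norm."""

    def __init__(self, in_channels, out_channels, initializer_range=0.02,
                 dropout_prob=0.1, compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)  # argument is keep_prob in this API
        self.add = P.TensorAdd()
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            # GPU: run layer norm in float32 and remember the dtype to cast back to.
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
            self.compute_type = compute_type
        else:
            if compute_type == mstype.float16:
                # Ascend float16: fused layer norm, as before the patch.
                self.layernorm = FusedLayerNorm((out_channels,),
                                                use_batch_norm=enable_fused_layernorm).to_float(compute_type)
            else:
                self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        if self.is_gpu:
            # Cast the float32 layer-norm output back to the network dtype.
            output = self.cast(output, self.compute_type)
        return output

The apparent design point: on GPU the float16 layer-norm path is replaced by a float32 one, since the mean/variance reductions in LayerNorm can lose precision in half precision, at the cost of one extra Cast per call; the Ascend fused-layernorm branch is left unchanged.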