@@ -167,7 +167,7 @@ class BertAttentionMask(nn.Cell):
                  dtype=mstype.float32):
         super(BertAttentionMask, self).__init__()
         self.has_attention_mask = has_attention_mask
-        self.multiply_data = Tensor([-1000.0, ], dtype=dtype)
+        self.multiply_data = Tensor([-1000.0,], dtype=dtype)
         self.multiply = P.Mul()
 
         if self.has_attention_mask:
@@ -198,7 +198,7 @@ class BertAttentionMaskBackward(nn.Cell):
                  dtype=mstype.float32):
         super(BertAttentionMaskBackward, self).__init__()
         self.has_attention_mask = has_attention_mask
-        self.multiply_data = Tensor([-1000.0, ], dtype=dtype)
+        self.multiply_data = Tensor([-1000.0,], dtype=dtype)
         self.multiply = P.Mul()
         self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32))
         if self.has_attention_mask:
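The change itself is behavior-neutral: `[-1000.0, ]` and `[-1000.0,]` construct the same one-element tensor; only the whitespace before the closing bracket differs. For context on what `multiply_data` is for, both cells implement the usual additive attention mask: positions where `attention_mask` is 0 get a large negative constant (-1000.0 here) added to the raw attention scores, so softmax drives their weights toward zero. Below is a minimal sketch of that pattern, not part of the diff; the `MaskedScores` cell and its `construct` body are hypothetical illustrations, and `P.Add` is the current op name (older MindSpore releases spell it `P.TensorAdd`).

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P


class MaskedScores(nn.Cell):
    """Hypothetical cell illustrating the additive-mask pattern used above."""
    def __init__(self, dtype=mstype.float32):
        super(MaskedScores, self).__init__()
        # Same constant the test cells register: the bias applied to masked slots.
        self.multiply_data = Tensor([-1000.0,], dtype=dtype)
        self.multiply = P.Mul()
        self.sub = P.Sub()
        self.add = P.Add()
        self.cast = P.Cast()
        self.one = Tensor(np.ones([1]).astype(np.float32))

    def construct(self, attention_scores, attention_mask):
        # (1 - mask) is 1.0 exactly at the padded positions, 0.0 elsewhere.
        inverted = self.sub(self.one, self.cast(attention_mask, mstype.float32))
        # Scale by -1000.0 to build the additive bias, then apply it to the scores.
        adder = self.multiply(inverted, self.multiply_data)
        return self.add(adder, attention_scores)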