
!14617 Fix Gelu in select ops

From: @liangzhibo
Reviewed-by: @ginfung, @zh_qh
Signed-off-by: @zh_qh
pull/14617/MERGE
Committed by mindspore-ci-bot
Parent commit: fc1c0e0952
3 changed files with 5 additions and 5 deletions:
  1. mindspore/nn/layer/activation.py (+2, -2)
  2. mindspore/ops/_selected_ops.py (+2, -2)
  3. model_zoo/research/nlp/gpt2/src/GPT2_model.py (+1, -1)

mindspore/nn/layer/activation.py (+2, -2)

@@ -415,7 +415,7 @@ class GELU(Cell):
 
     def __init__(self):
         super(GELU, self).__init__()
-        self.gelu = _selected_ops.Gelu()
+        self.gelu = _selected_ops.GeLU()
 
     def construct(self, x):
         return self.gelu(x)
@@ -458,7 +458,7 @@ class FastGelu(Cell):
 
     def __init__(self):
         super(FastGelu, self).__init__()
-        self.fast_gelu = _selected_ops.FastGelu()
+        self.fast_gelu = _selected_ops.FastGeLU()
 
     def construct(self, x):
         return self.fast_gelu(x)
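
For context, a minimal usage sketch of the nn.GELU cell touched by the hunks above (a sketch assuming a MindSpore build that includes the renamed selector op; the input values are illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# Apply GELU element-wise to a small float32 tensor.
gelu = nn.GELU()
x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32))
output = gelu(x)
print(output)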


mindspore/ops/_selected_ops.py (+2, -2)

@@ -73,13 +73,13 @@ class Tanh:
 
 
 @op_selector
-class Gelu:
+class GeLU:
     def __call__(self, *args):
         pass
 
 
 @op_selector
-class FastGelu:
+class FastGeLU:
     def __call__(self, *args):
         pass
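
The selector classes above are resolved by name, which is why they must track the renamed primitives exactly. A hypothetical sketch of how an op_selector-style decorator could dispatch on the class name (the real MindSpore implementation may differ):

import mindspore.ops.operations as P

def op_selector(cls):
    # Hypothetical: build a factory that resolves the backend primitive
    # from the decorated class's name, e.g. class GeLU -> P.GeLU.
    def factory(*args, **kwargs):
        # After the primitive rename, getattr(P, 'Gelu') would fail,
        # so the selector classes had to be renamed to match.
        return getattr(P, cls.__name__)(*args, **kwargs)
    return factory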



model_zoo/research/nlp/gpt2/src/GPT2_model.py (+1, -1)

@@ -499,7 +499,7 @@ class FeedForward(nn.Cell):
 
         self.layernorm = LayerNorm(in_channels=in_channels)
         self.residual_connect = ResidualConnection(dropout_prob=hidden_dropout)
-        self.gelu_act = P.Gelu()
+        self.gelu_act = P.GeLU()
         self.dropout = nn.Dropout(1 - hidden_dropout)
         self.use_dropout = hidden_dropout > 0
         self.reshape = P.Reshape()
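
A minimal sketch of calling the renamed primitive directly, as the GPT-2 feed-forward cell now does (values illustrative; assumes a MindSpore version after the Gelu to GeLU rename):

import numpy as np
from mindspore import Tensor
import mindspore.ops.operations as P

# Primitive form of GELU, as used inside construct() graphs.
gelu_act = P.GeLU()
hidden = Tensor(np.array([0.5, -0.3, 1.2], dtype=np.float32))
activated = gelu_act(hidden)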

