|
|
|
@@ -415,7 +415,7 @@ class GELU(Cell): |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(GELU, self).__init__() |
|
|
|
self.gelu = _selected_ops.Gelu() |
|
|
|
self.gelu = _selected_ops.GeLU() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
return self.gelu(x) |
|
|
|
@@ -458,7 +458,7 @@ class FastGelu(Cell): |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(FastGelu, self).__init__() |
|
|
|
self.fast_gelu = _selected_ops.FastGelu() |
|
|
|
self.fast_gelu = _selected_ops.FastGeLU() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
return self.fast_gelu(x) |
|
|
|
|