diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.cc
index 01904536d3..51c8aff5dd 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.cc
@@ -112,4 +112,16 @@ INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y
 ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}};
 REG_ADPT_DESC(GeluGrad, prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad))
+
+// FastGelu
+INPUT_MAP(FastGelu) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(FastGelu) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(FastGelu) = {{0, OUTPUT_DESC(y)}};
+REG_ADPT_DESC(FastGelu, prim::kPrimFastGelu->name(), ADPT_DESC(FastGelu))
+
+// FastGeluGrad
+INPUT_MAP(FastGeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}};
+ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}};
+REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeluGrad->name(), ADPT_DESC(FastGeluGrad))
 }  // namespace mindspore::transform
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.h
index bcd2d12c77..deda0fa58c 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.h
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.h
@@ -50,6 +50,12 @@ DECLARE_OP_USE_OUTPUT(Gelu)
 DECLARE_OP_ADAPTER(GeluGrad)
 DECLARE_OP_USE_OUTPUT(GeluGrad)
 
+DECLARE_OP_ADAPTER(FastGelu)
+DECLARE_OP_USE_OUTPUT(FastGelu)
+
+DECLARE_OP_ADAPTER(FastGeluGrad)
+DECLARE_OP_USE_OUTPUT(FastGeluGrad)
+
 DECLARE_OP_ADAPTER(Relu)
 DECLARE_OP_USE_OUTPUT(Relu)
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 57d20c1675..4ac2b659ba 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -65,6 +65,10 @@ AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &pri
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplFastGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplFastGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 2de21c4a8d..32b699d785 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -171,6 +171,8 @@ inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("C
 inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
 inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
 inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
+inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu");
+inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad");
 inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
 inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
 inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py
index 53693e9fb9..f72a44d95b 100644
--- a/mindspore/nn/layer/activation.py
+++ b/mindspore/nn/layer/activation.py
@@ -31,6 +31,7 @@ __all__ = ['Softmax',
            'ReLU6',
            'Tanh',
            'GELU',
+           'FastGelu',
            'Sigmoid',
            'PReLU',
            'get_activation',
@@ -357,6 +358,39 @@ class GELU(Cell):
         return self.gelu(x)
 
 
+class FastGelu(Cell):
+    r"""
+    Fast Gaussian error linear unit activation function.
+
+    Applies the FastGelu function to each element of the input. The input is a Tensor with any valid shape.
+
+    FastGelu is defined as:
+    :math:`FastGelu(x_i) = \frac {x_i} {1 + \exp(-1.702 * \left| x_i \right|)} *
+    \exp(0.851 * (x_i - \left| x_i \right|))`, where :math:`x_i` is the element of the input.
+
+    Inputs:
+        - **input_data** (Tensor) - The input of FastGelu with data type of float16 or float32.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input_data`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> fast_gelu = nn.FastGelu()
+        >>> output = fast_gelu(input_x)
+        >>> print(output)
+        [[-1.5420423e-01  3.9955849e+00 -9.7664278e-06]
+         [ 1.9356585e+00 -1.0070159e-03  8.9999981e+00]]
+    """
+
+    def __init__(self):
+        super(FastGelu, self).__init__()
+        self.fast_gelu = _selected_ops.FastGelu()
+
+    def construct(self, x):
+        return self.fast_gelu(x)
+
+
 class Sigmoid(Cell):
     r"""
     Sigmoid activation function.
@@ -582,6 +616,7 @@ _activation = {
     'relu6': ReLU6,
     'tanh': Tanh,
     'gelu': GELU,
+    'fast_gelu': FastGelu,
     'elu': ELU,
     'sigmoid': Sigmoid,
     'prelu': PReLU,
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 1f87191de5..5aec089dfc 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -563,6 +563,18 @@ def get_bprop_gelu(self):
     return bprop
 
 
+@bprop_getters.register(P.FastGelu)
+def get_bprop_fast_gelu(self):
+    """Grad definition for `FastGelu` operation."""
+    input_grad = G.FastGeluGrad()
+
+    def bprop(x, out, dout):
+        dx = input_grad(dout, x)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(P.FusedBatchNorm)
 def get_bprop_fused_batch_norm(self):
     """Grad definition for `FusedBatchNorm` operation."""
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index b85c192bb2..990a020334 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -58,6 +58,8 @@ from .confusion_mul_grad import _confusion_mul_grad_tbe
 from .dropout_do_mask import _dropout_do_mask_tbe
 from .gelu import _gelu_tbe
 from .gelu_grad import _gelu_grad_tbe
+from .fast_gelu import _fast_gelu_tbe
+from .fast_gelu_grad import _fast_gelu_grad_tbe
 from .max_pool import _max_pool_tbe
 from .max_pool_grad import _max_pool_grad_tbe
 from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_tbe
diff --git a/mindspore/ops/_op_impl/tbe/fast_gelu.py b/mindspore/ops/_op_impl/tbe/fast_gelu.py
new file mode 100644
index 0000000000..f108f7af7f
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/fast_gelu.py
@@ -0,0 +1,37 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""FastGelu op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+fast_gelu_op_info = TBERegOp("FastGelu") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("fast_gelu.so") \
+    .compute_cost(10) \
+    .kernel_name("fast_gelu") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("formatAgnostic") \
+    .dtype_format(DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(fast_gelu_op_info)
+def _fast_gelu_tbe():
+    """FastGelu TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/fast_gelu_grad.py b/mindspore/ops/_op_impl/tbe/fast_gelu_grad.py
new file mode 100644
index 0000000000..18baa6ec48
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/fast_gelu_grad.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""FastGeluGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+fast_gelu_grad_op_info = TBERegOp("FastGeluGrad") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("fast_gelu_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("fast_gelu_grad") \
+    .partial_flag(True) \
+    .input(0, "dy", False, "required", "all") \
+    .input(1, "x", False, "required", "all") \
+    .output(0, "z", True, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
+    .get_op_info()
+
+
+@op_info_register(fast_gelu_grad_op_info)
+def _fast_gelu_grad_tbe():
+    """FastGeluGrad TBE register"""
+    return
diff --git a/mindspore/ops/_selected_ops.py b/mindspore/ops/_selected_ops.py
index 5e125025c9..2d816f58f8 100644
--- a/mindspore/ops/_selected_ops.py
+++ b/mindspore/ops/_selected_ops.py
@@ -84,6 +84,12 @@ class Gelu:
         pass
 
 
+@op_selector
+class FastGelu:
+    def __call__(self, *args):
+        pass
+
+
 @op_selector
 class LayerNorm:
     def __call__(self, *args):
         pass
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 8f59cb1bed..85c0e6c1dd 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -63,7 +63,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                      DepthwiseConv2dNative,
                      DropoutDoMask, Dropout, DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx,
                      BNTrainingReduce, BNTrainingUpdate,
-                     Gelu, Elu,
+                     Gelu, FastGelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
                      LogSoftmax,
                      MaxPool, DataFormatDimMap,
@@ -165,6 +165,7 @@ __all__ = [
     'Tile',
     'BiasAdd',
     'Gelu',
+    'FastGelu',
     'Minimum',
     'Maximum',
     'StridedSlice',
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 5395eefa74..e9147f9288 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -637,6 +637,24 @@ class GeluGrad(PrimitiveWithInfer):
         return x_dtype
 
 
+class FastGeluGrad(PrimitiveWithInfer):
+    """Gradients of FastGelu operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init FastGeluGrad"""
+
+    def infer_shape(self, y_backprop_shape, x_shape):
+        return x_shape
+
+    def infer_dtype(self, y_backprop_dtype, x_dtype):
+        tuple(map(partial(validator.check_tensor_dtype_valid,
+                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                  ("y_backprop", "x"),
+                  (y_backprop_dtype, x_dtype)))
+        return x_dtype
+
+
 class _PoolGrad(PrimitiveWithInfer):
     """Gradients of the max/avg pool operation."""
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index f62e5192a1..d76d2815c0 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2957,6 +2957,45 @@ class Gelu(PrimitiveWithInfer):
         return input_x
 
 
+class FastGelu(PrimitiveWithInfer):
+    r"""
+    Fast Gaussian Error Linear Units activation function.
+
+    FastGelu is defined as follows:
+
+    .. math::
+        \text{output} = \frac {x} {1 + \exp(-1.702 * \left| x \right|)} * \exp(0.851 * (x - \left| x \right|)),
+
+    where :math:`x` is the element of the input.
+
+    Inputs:
+        - **input_x** (Tensor) - Input to compute the FastGelu with data type of float16 or float32.
+
+    Outputs:
+        Tensor, with the same type and shape as input.
+
+    Examples:
+        >>> tensor = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> fast_gelu = P.FastGelu()
+        >>> output = fast_gelu(tensor)
+        >>> print(output)
+        [[-1.5420423e-01  3.9955849e+00 -9.7664278e-06]
+         [ 1.9356585e+00 -1.0070159e-03  8.9999981e+00]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init FastGelu"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, input_x):
+        return input_x
+
+    def infer_dtype(self, input_x):
+        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
+        return input_x
+
+
 class GetNext(PrimitiveWithInfer):
     """
     Returns the next element in the dataset queue.
diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py
index e5e8c8b49f..39b29a4540 100644
--- a/model_zoo/official/nlp/bert/src/config.py
+++ b/model_zoo/official/nlp/bert/src/config.py
@@ -103,7 +103,7 @@ if cfg.bert_network == 'large':
         num_hidden_layers=24,
         num_attention_heads=16,
         intermediate_size=4096,
-        hidden_act="gelu",
+        hidden_act="fast_gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=512,
diff --git a/tests/st/ops/ascend/test_tbe_ops/test_fast_gelu.py b/tests/st/ops/ascend/test_tbe_ops/test_fast_gelu.py
new file mode 100644
index 0000000000..21b2116b29
--- /dev/null
+++ b/tests/st/ops/ascend/test_tbe_ops/test_fast_gelu.py
@@ -0,0 +1,122 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+from mindspore import context
+from mindspore.common.tensor import Tensor
+from mindspore.nn import FastGelu
+from mindspore.train.model import Model
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+def fast_gelu_forward_me_impl(input_):
+    n = FastGelu()
+    n.set_train()
+    m = Model(n)
+    out = m.predict(input_)
+    return out.asnumpy()
+
+
+def fast_gelu_forward_cmp(input_shape, data_type=np.float32):
+    input_np = np.random.randn(*input_shape).astype(data_type)
+    input_me = Tensor(input_np)
+    fast_gelu_forward_me_impl(input_me)
+
+
+def test_fast_gelu_input_dim_0():
+    input_shape = [0]
+    with pytest.raises(ValueError):
+        fast_gelu_forward_cmp(input_shape)
+
+
+def test_fast_gelu_input_dim_10240_1024():
+    input_shape = [10240, 1024]
+    fast_gelu_forward_cmp(input_shape)
+
+
+def test_fast_gelu_input_dim_10240_768():
+    input_shape = [10240, 768]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_1024_3072():
+    input_shape = [1024, 3072]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_1024_4096():
+    input_shape = [1024, 4096]
+    fast_gelu_forward_cmp(input_shape)
+
+
+def test_fast_gelu_input_dim_1280_1024():
+    input_shape = [1280, 1024]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_1280_768():
+    input_shape = [1280, 768]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_128_3072():
+    input_shape = [128, 3072]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_128_4096():
+    input_shape = [128, 4096]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_160_1024():
+    input_shape = [160, 1024]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_160_768():
+    input_shape = [160, 768]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_16384_3072():
+    input_shape = [16384, 3072]
+    fast_gelu_forward_cmp(input_shape)
+
+
+def test_fast_gelu_input_dim_16384_4096():
+    input_shape = [16384, 4096]
+    fast_gelu_forward_cmp(input_shape)
+
+
+@pytest.mark.lower_bs
+def test_fast_gelu_input_dim_20_1024():
+    input_shape = [20, 1024]
+    fast_gelu_forward_cmp(input_shape)
+
+
+def test_fast_gelu_input_dim_20480_1024():
+    input_shape = [20480, 1024]
+    fast_gelu_forward_cmp(input_shape)
diff --git a/tests/st/ops/ascend/test_tbe_ops/test_fast_gelu_grad_sens.py b/tests/st/ops/ascend/test_tbe_ops/test_fast_gelu_grad_sens.py
new file mode 100755
index 0000000000..639bbe93d9
--- /dev/null
+++ b/tests/st/ops/ascend/test_tbe_ops/test_fast_gelu_grad_sens.py
@@ -0,0 +1,91 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+from mindspore import context
+from mindspore import log as logger
+from mindspore.common.tensor import Tensor
+from mindspore.nn import Cell, FastGelu
+from mindspore.ops import operations as P
+from mindspore.ops.composite import GradOperation
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Grad(Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, input_, output_grad):
+        return self.grad(self.network)(input_, output_grad)
+
+
+def fast_gelu_backward_me_impl(input_, output_grad):
+    n = FastGelu()
+    grad_with_sense = Grad(n)
+    grad_with_sense.set_train()
+    input_grad = grad_with_sense(input_, output_grad)
+    return input_grad
+
+
+def fast_gelu_backward_cmp(input_shape):
+    input_np = np.random.randn(*input_shape).astype(np.float32)
+    input_me = Tensor(input_np)
+
+    output_grad_shape = input_shape
+    output_grad_np = np.random.randn(*output_grad_shape).astype(np.float32)
+    output_grad_me = Tensor(output_grad_np)
+
+    output_grad_me = fast_gelu_backward_me_impl(input_me, output_grad_me)
+    logger.info("---------me--------")
+    logger.info(output_grad_me)
+
+
+# ---------- LARGE INPUT ---------------
+
+class MEGeluLargeIn(Cell):
+    def __init__(self):
+        super(MEGeluLargeIn, self).__init__()
+        self.matmul = P.MatMul()
+        self.fast_gelu = P.FastGelu()
+
+    def construct(self, x1, x2):
+        x = self.matmul(x1, x2)
+        return self.fast_gelu(x)
+
+
+class GradLargeIn(Cell):
+    def __init__(self, network):
+        super(GradLargeIn, self).__init__()
+        self.grad = GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, x1, x2, output_grad):
+        return self.grad(self.network)(x1, x2, output_grad)
+
+
+def fast_gelu_backward_me_large_in_impl(x1, x2, output_grad):
+    n = MEGeluLargeIn()
+    grad_with_sense = GradLargeIn(n)
+    grad_with_sense.set_train()
+    input_grad = grad_with_sense(x1, x2, output_grad)
+    return input_grad[0].asnumpy(), input_grad[1].asnumpy()
+
+
+def test_grad_fast_gelu_input_10240_1024():
+    input_shape = [10240, 1024]
+    fast_gelu_backward_cmp(input_shape)
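
For reference, outside the patch itself: the forward formula quoted in the new FastGelu docstrings can be sanity-checked off-device with a minimal NumPy sketch. The helper names below (fast_gelu_ref, fast_gelu_grad_ref) are illustrative only, and the gradient is estimated with central differences for checking purposes; the actual fast_gelu_grad TBE kernel evaluates a closed-form derivative.

import numpy as np


def fast_gelu_ref(x):
    # Forward formula from the docstrings:
    # fast_gelu(x) = x / (1 + exp(-1.702 * |x|)) * exp(0.851 * (x - |x|))
    abs_x = np.abs(x)
    return x / (1.0 + np.exp(-1.702 * abs_x)) * np.exp(0.851 * (x - abs_x))


def fast_gelu_grad_ref(dy, x, eps=1e-5):
    # dy * f'(x), with f'(x) estimated by central differences in float64.
    # Mirrors what the bprop returns, but is not the kernel's formula.
    x64 = np.asarray(x, dtype=np.float64)
    df = (fast_gelu_ref(x64 + eps) - fast_gelu_ref(x64 - eps)) / (2.0 * eps)
    return dy * df.astype(np.asarray(x).dtype)


if __name__ == "__main__":
    x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
    print(fast_gelu_ref(x))        # close to the docstring example output
    print(fast_gelu_grad_ref(np.ones_like(x), x))

Running this reproduces values close to the example output shown in the docstrings (small differences come from float precision on the Ascend kernel).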
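
A quick end-to-end usage sketch of the layer this patch adds, under the assumption that an Ascend backend is available (only a TBE kernel is registered here), showing both direct construction and lookup through the extended _activation registry:

import numpy as np
import mindspore
from mindspore import Tensor, context, nn

# FastGelu is Ascend-only in this patch, so target Ascend explicitly.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)

fast_gelu = nn.FastGelu()              # direct use of the new layer
print(fast_gelu(x))

act = nn.get_activation('fast_gelu')   # resolved via the 'fast_gelu' registry entry
print(act(x))

The same 'fast_gelu' string is what the BERT-large config now passes as hidden_act, which is how the model picks up the new activation.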