
!9039 register FastGelu for activation

From: @shibeiji
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot (Gitee), 5 years ago
commit 0aa63f21c5
16 changed files with 430 additions and 2 deletions
  1. mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.cc (+12, -0)
  2. mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.h (+6, -0)
  3. mindspore/core/abstract/infer_functions.h (+4, -0)
  4. mindspore/core/base/core_ops.h (+2, -0)
  5. mindspore/nn/layer/activation.py (+35, -0)
  6. mindspore/ops/_grad/grad_nn_ops.py (+12, -0)
  7. mindspore/ops/_op_impl/tbe/__init__.py (+2, -0)
  8. mindspore/ops/_op_impl/tbe/fast_gelu.py (+37, -0)
  9. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py (+41, -0)
  10. mindspore/ops/_selected_ops.py (+6, -0)
  11. mindspore/ops/operations/__init__.py (+2, -1)
  12. mindspore/ops/operations/_grad_ops.py (+18, -0)
  13. mindspore/ops/operations/nn_ops.py (+39, -0)
  14. model_zoo/official/nlp/bert/src/config.py (+1, -1)
  15. tests/st/ops/ascend/test_tbe_ops/test_fast_gelu.py (+122, -0)
  16. tests/st/ops/ascend/test_tbe_ops/test_fast_gelu_grad_sens.py (+91, -0)

mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.cc (+12, -0)

@@ -112,4 +112,16 @@ INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y
ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(GeluGrad, prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad))

// FastGelu
INPUT_MAP(FastGelu) = {{1, INPUT_DESC(x)}};
ATTR_MAP(FastGelu) = EMPTY_ATTR_MAP;
OUTPUT_MAP(FastGelu) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(FastGelu, prim::kPrimFastGelu->name(), ADPT_DESC(FastGelu))

// FastGeluGrad
INPUT_MAP(FastGeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}};
ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeluGrad->name(), ADPT_DESC(FastGeluGrad))
}  // namespace mindspore::transform

mindspore/ccsrc/transform/graph_ir/op_declare/nonlinear_fuc_ops_declare.h (+6, -0)

@@ -50,6 +50,12 @@ DECLARE_OP_USE_OUTPUT(Gelu)
DECLARE_OP_ADAPTER(GeluGrad)
DECLARE_OP_USE_OUTPUT(GeluGrad)


DECLARE_OP_ADAPTER(FastGelu)
DECLARE_OP_USE_OUTPUT(FastGelu)

DECLARE_OP_ADAPTER(FastGeluGrad)
DECLARE_OP_USE_OUTPUT(FastGeluGrad)

DECLARE_OP_ADAPTER(Relu)
DECLARE_OP_USE_OUTPUT(Relu)




mindspore/core/abstract/infer_functions.h (+4, -0)

@@ -65,6 +65,10 @@ AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &pri
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFastGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFastGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


mindspore/core/base/core_ops.h (+2, -0)

@@ -171,6 +171,8 @@ inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("C
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu");
inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad");
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");


mindspore/nn/layer/activation.py (+35, -0)

@@ -31,6 +31,7 @@ __all__ = ['Softmax',
'ReLU6',
'Tanh',
'GELU',
'FastGelu',
'Sigmoid',
'PReLU',
'get_activation',
@@ -357,6 +358,39 @@ class GELU(Cell):
return self.gelu(x)




class FastGelu(Cell):
r"""
Fast Gaussian Error Linear Unit activation function.

Applies the FastGelu function to each element of the input. The input is a Tensor with any valid shape.

FastGelu is defined as:
:math:`FastGelu(x_i) = \frac {x_i} {1 + \exp(-1.702 * \left| x_i \right|)} *
\exp(0.851 * (x_i - \left| x_i \right|))`, where :math:`x_i` is the element of the input.

Inputs:
- **input_data** (Tensor) - The input of FastGelu with data type of float16 or float32.

Outputs:
Tensor, with the same type and shape as `input_data`.

Examples:
>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
>>> fast_gelu = nn.FastGelu()
>>> output = fast_gelu(input_x)
>>> print(output)
[[-1.5420423e-01 3.9955849e+00 -9.7664278e-06]
[ 1.9356585e+00 -1.0070159e-03 8.9999981e+00]]
"""

def __init__(self):
super(FastGelu, self).__init__()
self.fast_gelu = _selected_ops.FastGelu()

def construct(self, x):
return self.fast_gelu(x)


class Sigmoid(Cell):
r"""
Sigmoid activation function.
@@ -582,6 +616,7 @@ _activation = {
'relu6': ReLU6,
'tanh': Tanh,
'gelu': GELU,
'fast_gelu': FastGelu,
'elu': ELU,
'sigmoid': Sigmoid,
'prelu': PReLU,
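
The docstring above spells out the FastGelu formula; below is a minimal NumPy sketch (illustration only, not part of this commit, and fast_gelu_ref is a hypothetical name) for sanity-checking the documented example values:

# Hypothetical NumPy reference of the FastGelu formula documented above.
import numpy as np

def fast_gelu_ref(x):
    # FastGelu(x) = x / (1 + exp(-1.702*|x|)) * exp(0.851*(x - |x|))
    # Algebraically this equals x * sigmoid(1.702 * x), written in an
    # overflow-safe form (both exponent arguments stay <= 0).
    abs_x = np.abs(x)
    return x / (1.0 + np.exp(-1.702 * abs_x)) * np.exp(0.851 * (x - abs_x))

x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
print(fast_gelu_ref(x))
# Expected to be close to the docstring output, e.g. fast_gelu_ref(-1.0) ~ -0.15420423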


mindspore/ops/_grad/grad_nn_ops.py (+12, -0)

@@ -563,6 +563,18 @@ def get_bprop_gelu(self):
return bprop




@bprop_getters.register(P.FastGelu)
def get_bprop_fast_gelu(self):
"""Grad definition for `FastGelu` operation."""
input_grad = G.FastGeluGrad()

def bprop(x, out, dout):
dx = input_grad(dout, x)
return (dx,)

return bprop


@bprop_getters.register(P.FusedBatchNorm)
def get_bprop_fused_batch_norm(self):
"""Grad definition for `FusedBatchNorm` operation."""


mindspore/ops/_op_impl/tbe/__init__.py (+2, -0)

@@ -58,6 +58,8 @@ from .confusion_mul_grad import _confusion_mul_grad_tbe
from .dropout_do_mask import _dropout_do_mask_tbe
from .gelu import _gelu_tbe
from .gelu_grad import _gelu_grad_tbe
from .fast_gelu import _fast_gelu_tbe
from .fast_gelu_grad import _fast_gelu_grad_tbe
from .max_pool import _max_pool_tbe
from .max_pool_grad import _max_pool_grad_tbe
from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_tbe


mindspore/ops/_op_impl/tbe/fast_gelu.py (+37, -0)

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""FastGelu op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fast_gelu_op_info = TBERegOp("FastGelu") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fast_gelu.so") \
.compute_cost(10) \
.kernel_name("fast_gelu") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("formatAgnostic") \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info()


@op_info_register(fast_gelu_op_info)
def _fast_gelu_tbe():
"""FastGelu TBE register"""
return

mindspore/ops/_op_impl/tbe/fast_gelu_grad.py (+41, -0)

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""FastGeluGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fast_gelu_grad_op_info = TBERegOp("FastGeluGrad") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("fast_gelu_grad.so") \
.compute_cost(10) \
.kernel_name("fast_gelu_grad") \
.partial_flag(True) \
.input(0, "dy", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.output(0, "z", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
.get_op_info()


@op_info_register(fast_gelu_grad_op_info)
def _fast_gelu_grad_tbe():
"""FastGeluGrad TBE register"""
return

mindspore/ops/_selected_ops.py (+6, -0)

@@ -84,6 +84,12 @@ class Gelu:
pass




@op_selector
class FastGelu:
def __call__(self, *args):
pass


@op_selector
class LayerNorm:
def __call__(self, *args):


mindspore/ops/operations/__init__.py (+2, -1)

@@ -63,7 +63,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
DepthwiseConv2dNative,
DropoutDoMask, Dropout,
DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
Gelu, Elu,
Gelu, FastGelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
LogSoftmax,
MaxPool, DataFormatDimMap,
@@ -165,6 +165,7 @@ __all__ = [
'Tile',
'BiasAdd',
'Gelu',
'FastGelu',
'Minimum',
'Maximum',
'StridedSlice',


mindspore/ops/operations/_grad_ops.py (+18, -0)

@@ -637,6 +637,24 @@ class GeluGrad(PrimitiveWithInfer):
return x_dtype return x_dtype




class FastGeluGrad(PrimitiveWithInfer):
"""Gradients of FastGelu operation."""

@prim_attr_register
def __init__(self):
"""init FastGeluGrad"""

def infer_shape(self, y_backprop_shape, x_shape):
return x_shape

def infer_dtype(self, y_backprop_dtype, x_dtype):
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("y_backprop", "x"),
(y_backprop_dtype, x_dtype)))
return x_dtype


class _PoolGrad(PrimitiveWithInfer):
"""Gradients of the max/avg pool operation."""




mindspore/ops/operations/nn_ops.py (+39, -0)

@@ -2957,6 +2957,45 @@ class Gelu(PrimitiveWithInfer):
return input_x




class FastGelu(PrimitiveWithInfer):
r"""
Fast Gaussian Error Linear Units activation function.

FastGelu is defined as follows:

.. math::
\text{output} = \frac {x} {1 + \exp(-1.702 * \left| x \right|)} * \exp(0.851 * (x - \left| x \right|)),

where :math:`x` is the element of the input.

Inputs:
- **input_x** (Tensor) - Input to compute the FastGelu with data type of float16 or float32.

Outputs:
Tensor, with the same type and shape as input.

Examples:
>>> tensor = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
>>> fast_gelu = P.FastGelu()
>>> output = fast_gelu(tensor)
>>> print(output)
[[-1.5420423e-01 3.9955849e+00 -9.7664278e-06]
[ 1.9356585e+00 -1.0070159e-03 8.9999981e+00]]
"""

@prim_attr_register
def __init__(self):
"""init FastGeLU"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])

def infer_shape(self, input_x):
return input_x

def infer_dtype(self, input_x):
validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
return input_x


class GetNext(PrimitiveWithInfer):
"""
Returns the next element in the dataset queue.


model_zoo/official/nlp/bert/src/config.py (+1, -1)

@@ -103,7 +103,7 @@ if cfg.bert_network == 'large':
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_act="fast_gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
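
With 'fast_gelu' added to the _activation table in activation.py (see the hunk above), the BERT large config can now select the new activation by name. A usage sketch, assuming the config string is resolved through mindspore.nn.get_activation:

# 'fast_gelu' should resolve to the nn.FastGelu cell registered above.
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

act = nn.get_activation('fast_gelu')
x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
print(act(x))  # expected to match the FastGelu docstring example output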


tests/st/ops/ascend/test_tbe_ops/test_fast_gelu.py (+122, -0)

@@ -0,0 +1,122 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.nn import FastGelu
from mindspore.train.model import Model

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


def fast_gelu_forward_me_impl(input_):
n = FastGelu()
n.set_train()
m = Model(n)
out = m.predict(input_)
return out.asnumpy()


def fast_gelu_forward_cmp(input_shape, data_type=np.float32):
input_np = np.random.randn(*input_shape).astype(data_type)
input_me = Tensor(input_np)
fast_gelu_forward_me_impl(input_me)


def test_fast_gelu_input_dim_0():
input_shape = [0]
with pytest.raises(ValueError):
fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_10240_1024():
input_shape = [10240, 1024]
fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_10240_768():
input_shape = [10240, 768]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_1024_3072():
input_shape = [1024, 3072]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_1024_4096():
input_shape = [1024, 4096]
fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_1280_1024():
input_shape = [1280, 1024]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_1280_768():
input_shape = [1280, 768]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_128_3072():
input_shape = [128, 3072]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_128_4096():
input_shape = [128, 4096]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_160_1024():
input_shape = [160, 1024]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_160_768():
input_shape = [160, 768]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_16384_3072():
input_shape = [16384, 3072]
fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_16384_4096():
input_shape = [16384, 4096]
fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_20_1024():
input_shape = [20, 1024]
fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_20480_1024():
input_shape = [20480, 1024]
fast_gelu_forward_cmp(input_shape)

tests/st/ops/ascend/test_tbe_ops/test_fast_gelu_grad_sens.py (+91, -0)

@@ -0,0 +1,91 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell, FastGelu
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Grad(Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network

def construct(self, input_, output_grad):
return self.grad(self.network)(input_, output_grad)


def fast_gelu_backward_me_impl(input_, output_grad):
n = FastGelu()
grad_with_sense = Grad(n)
grad_with_sense.set_train()
input_grad = grad_with_sense(input_, output_grad)
return input_grad


def fast_gelu_backward_cmp(input_shape):
input_np = np.random.randn(*input_shape).astype(np.float32)
input_me = Tensor(input_np)

output_grad_shape = input_shape
output_grad_np = np.random.randn(*output_grad_shape).astype(np.float32)
output_grad_me = Tensor(output_grad_np)

output_grad_me = fast_gelu_backward_me_impl(input_me, output_grad_me)
logger.info("---------me--------")
logger.info(output_grad_me)


# ---------- LARGE INPUT ---------------

class MEGeluLargeIn(Cell):
def __init__(self):
super(MEGeluLargeIn, self).__init__()
self.matmul = P.MatMul()
self.fast_gelu = P.FastGelu()

def construct(self, x1, x2):
x = self.matmul(x1, x2)
return self.fast_gelu(x)


class GradLargeIn(Cell):
def __init__(self, network):
super(GradLargeIn, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network

def construct(self, x1, x2, output_grad):
return self.grad(self.network)(x1, x2, output_grad)


def fast_gelu_backward_me_large_in_impl(x1, x2, output_grad):
n = FastGelu()
grad_with_sense = GradLargeIn(n)
grad_with_sense.set_train()
input_grad = grad_with_sense(x1, x2, output_grad)
return input_grad[0].asnumpy(), input_grad[1].asnumpy()


def test_grad_fast_gelu_input_10240_1024():
input_shape = [10240, 1024]
fast_gelu_backward_cmp(input_shape)
