From: @shibeiji
Tag: tags/v1.1.0
@@ -112,4 +112,16 @@ INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y
ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(GeluGrad, prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad))

// FastGelu
INPUT_MAP(FastGelu) = {{1, INPUT_DESC(x)}};
ATTR_MAP(FastGelu) = EMPTY_ATTR_MAP;
OUTPUT_MAP(FastGelu) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(FastGelu, prim::kPrimFastGelu->name(), ADPT_DESC(FastGelu))

// FastGeluGrad
INPUT_MAP(FastGeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}};
ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeluGrad->name(), ADPT_DESC(FastGeluGrad))
}  // namespace mindspore::transform
@@ -50,6 +50,12 @@ DECLARE_OP_USE_OUTPUT(Gelu)
DECLARE_OP_ADAPTER(GeluGrad)
DECLARE_OP_USE_OUTPUT(GeluGrad)
DECLARE_OP_ADAPTER(FastGelu)
DECLARE_OP_USE_OUTPUT(FastGelu)
DECLARE_OP_ADAPTER(FastGeluGrad)
DECLARE_OP_USE_OUTPUT(FastGeluGrad)
DECLARE_OP_ADAPTER(Relu)
DECLARE_OP_USE_OUTPUT(Relu)
@@ -65,6 +65,10 @@ AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &pri
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFastGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFastGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
@@ -171,6 +171,8 @@ inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("C
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu");
inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad");
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
@@ -31,6 +31,7 @@ __all__ = ['Softmax',
           'ReLU6',
           'Tanh',
           'GELU',
           'FastGelu',
           'Sigmoid',
           'PReLU',
           'get_activation',
@@ -357,6 +358,39 @@ class GELU(Cell):
        return self.gelu(x)


class FastGelu(Cell):
    r"""
    Fast Gaussian error linear unit activation function.

    Applies the FastGelu function to each element of the input. The input is a Tensor with any valid shape.

    FastGelu is defined as:
    :math:`FastGelu(x_i) = \frac {x_i} {1 + \exp(-1.702 * \left| x_i \right|)} *
    \exp(0.851 * (x_i - \left| x_i \right|))`, where :math:`x_i` is the element of the input.

    Inputs:
        - **input_data** (Tensor) - The input of FastGelu with data type of float16 or float32.

    Outputs:
        Tensor, with the same type and shape as the `input_data`.

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> fast_gelu = nn.FastGelu()
        >>> output = fast_gelu(input_x)
        >>> print(output)
        [[-1.5420423e-01 3.9955849e+00 -9.7664278e-06]
         [ 1.9356585e+00 -1.0070159e-03 8.9999981e+00]]
    """

    def __init__(self):
        super(FastGelu, self).__init__()
        self.fast_gelu = _selected_ops.FastGelu()

    def construct(self, x):
        return self.fast_gelu(x)
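Note (not part of the patch): the docstring formula can be sanity-checked with a small NumPy sketch. fast_gelu_ref is a hypothetical helper name, and the Ascend TBE kernel may compute an equivalent but numerically different form.

import numpy as np

def fast_gelu_ref(x):
    # Reference for the documented formula: x / (1 + exp(-1.702*|x|)) * exp(0.851*(x - |x|)).
    x = np.asarray(x, dtype=np.float32)
    return x / (1.0 + np.exp(-1.702 * np.abs(x))) * np.exp(0.851 * (x - np.abs(x)))

print(fast_gelu_ref(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]])))
# Should be close to the docstring example output.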

class Sigmoid(Cell):
    r"""
    Sigmoid activation function.
@@ -582,6 +616,7 @@ _activation = {
    'relu6': ReLU6,
    'tanh': Tanh,
    'gelu': GELU,
    'fast_gelu': FastGelu,
    'elu': ELU,
    'sigmoid': Sigmoid,
    'prelu': PReLU,
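Illustration (not part of the patch): with the 'fast_gelu' entry in _activation, the activation can be looked up by name. This sketch assumes nn.get_activation returns an instance of the registered Cell and that layers such as nn.Dense accept the activation name as a string.

import numpy as np
from mindspore import Tensor, nn

act = nn.get_activation('fast_gelu')            # expected to be an nn.FastGelu instance
dense = nn.Dense(3, 4, activation='fast_gelu')  # string lookup goes through the same map
x = Tensor(np.random.randn(2, 3).astype(np.float32))
print(act(x).shape, dense(x).shape)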
@@ -563,6 +563,18 @@ def get_bprop_gelu(self):
    return bprop


@bprop_getters.register(P.FastGelu)
def get_bprop_fast_gelu(self):
    """Grad definition for `FastGelu` operation."""
    input_grad = G.FastGeluGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop
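For intuition only: the documented forward formula reduces to x * sigmoid(1.702 * x), so its derivative is sigmoid(1.702 * x) + 1.702 * x * sigmoid(1.702 * x) * (1 - sigmoid(1.702 * x)). The sketch below is a hypothetical NumPy reference for what FastGeluGrad(dy, x) is expected to return; the actual kernel source is not shown in this patch and may use a different but equivalent formulation.

import numpy as np

def fast_gelu_grad_ref(dy, x, a=1.702):
    # dy * d/dx [x * sigmoid(a * x)], derived from the docstring formula.
    s = 1.0 / (1.0 + np.exp(-a * x))
    return dy * (s + a * x * s * (1.0 - s))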

@bprop_getters.register(P.FusedBatchNorm)
def get_bprop_fused_batch_norm(self):
    """Grad definition for `FusedBatchNorm` operation."""

@@ -58,6 +58,8 @@ from .confusion_mul_grad import _confusion_mul_grad_tbe
from .dropout_do_mask import _dropout_do_mask_tbe
from .gelu import _gelu_tbe
from .gelu_grad import _gelu_grad_tbe
from .fast_gelu import _fast_gelu_tbe
from .fast_gelu_grad import _fast_gelu_grad_tbe
from .max_pool import _max_pool_tbe
from .max_pool_grad import _max_pool_grad_tbe
from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_tbe
@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FastGelu op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fast_gelu_op_info = TBERegOp("FastGelu") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("fast_gelu.so") \
    .compute_cost(10) \
    .kernel_name("fast_gelu") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(fast_gelu_op_info)
def _fast_gelu_tbe():
    """FastGelu TBE register"""
    return

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FastGeluGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fast_gelu_grad_op_info = TBERegOp("FastGeluGrad") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("fast_gelu_grad.so") \
    .compute_cost(10) \
    .kernel_name("fast_gelu_grad") \
    .partial_flag(True) \
    .input(0, "dy", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .output(0, "z", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .get_op_info()


@op_info_register(fast_gelu_grad_op_info)
def _fast_gelu_grad_tbe():
    """FastGeluGrad TBE register"""
    return
@@ -84,6 +84,12 @@ class Gelu:
        pass


@op_selector
class FastGelu:
    def __call__(self, *args):
        pass


@op_selector
class LayerNorm:
    def __call__(self, *args):
@@ -63,7 +63,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                     DepthwiseConv2dNative,
                     DropoutDoMask, Dropout,
                     DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
                     Gelu, Elu,
                     Gelu, FastGelu, Elu,
                     GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
                     LogSoftmax,
                     MaxPool, DataFormatDimMap,

@@ -165,6 +165,7 @@ __all__ = [
    'Tile',
    'BiasAdd',
    'Gelu',
    'FastGelu',
    'Minimum',
    'Maximum',
    'StridedSlice',
@@ -637,6 +637,24 @@ class GeluGrad(PrimitiveWithInfer):
        return x_dtype


class FastGeluGrad(PrimitiveWithInfer):
    """Gradients of FastGelu operation."""

    @prim_attr_register
    def __init__(self):
        """init FastGeluGrad"""

    def infer_shape(self, y_backprop_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_backprop_dtype, x_dtype):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("y_backprop", "x"),
                  (y_backprop_dtype, x_dtype)))
        return x_dtype


class _PoolGrad(PrimitiveWithInfer):
    """Gradients of the max/avg pool operation."""
@@ -2957,6 +2957,45 @@ class Gelu(PrimitiveWithInfer):
        return input_x


class FastGelu(PrimitiveWithInfer):
    r"""
    Fast Gaussian Error Linear Units activation function.

    FastGelu is defined as follows:

    .. math::
        \text{output} = \frac {x} {1 + \exp(-1.702 * \left| x \right|)} * \exp(0.851 * (x - \left| x \right|)),

    where :math:`x` is the element of the input.

    Inputs:
        - **input_x** (Tensor) - Input to compute the FastGelu with data type of float16 or float32.

    Outputs:
        Tensor, with the same type and shape as input.

    Examples:
        >>> tensor = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> fast_gelu = P.FastGelu()
        >>> output = fast_gelu(tensor)
        >>> print(output)
        [[-1.5420423e-01 3.9955849e+00 -9.7664278e-06]
         [ 1.9356585e+00 -1.0070159e-03 8.9999981e+00]]
    """

    @prim_attr_register
    def __init__(self):
        """init FastGelu"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, input_x):
        return input_x

    def infer_dtype(self, input_x):
        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
        return input_x

class GetNext(PrimitiveWithInfer):
    """
    Returns the next element in the dataset queue.
@@ -103,7 +103,7 @@ if cfg.bert_network == 'large':
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_act="fast_gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
@@ -0,0 +1,122 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.nn import FastGelu
from mindspore.train.model import Model

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


def fast_gelu_forward_me_impl(input_):
    n = FastGelu()
    n.set_train()
    m = Model(n)
    out = m.predict(input_)
    return out.asnumpy()


def fast_gelu_forward_cmp(input_shape, data_type=np.float32):
    input_np = np.random.randn(*input_shape).astype(data_type)
    input_me = Tensor(input_np)
    fast_gelu_forward_me_impl(input_me)


def test_fast_gelu_input_dim_0():
    input_shape = [0]
    with pytest.raises(ValueError):
        fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_10240_1024():
    input_shape = [10240, 1024]
    fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_10240_768():
    input_shape = [10240, 768]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_1024_3072():
    input_shape = [1024, 3072]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_1024_4096():
    input_shape = [1024, 4096]
    fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_1280_1024():
    input_shape = [1280, 1024]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_1280_768():
    input_shape = [1280, 768]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_128_3072():
    input_shape = [128, 3072]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_128_4096():
    input_shape = [128, 4096]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_160_1024():
    input_shape = [160, 1024]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_160_768():
    input_shape = [160, 768]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_16384_3072():
    input_shape = [16384, 3072]
    fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_16384_4096():
    input_shape = [16384, 4096]
    fast_gelu_forward_cmp(input_shape)


@pytest.mark.lower_bs
def test_fast_gelu_input_dim_20_1024():
    input_shape = [20, 1024]
    fast_gelu_forward_cmp(input_shape)


def test_fast_gelu_input_dim_20480_1024():
    input_shape = [20480, 1024]
    fast_gelu_forward_cmp(input_shape)

@@ -0,0 +1,91 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell, FastGelu
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Grad(Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, input_, output_grad):
        return self.grad(self.network)(input_, output_grad)


def fast_gelu_backward_me_impl(input_, output_grad):
    n = FastGelu()
    grad_with_sense = Grad(n)
    grad_with_sense.set_train()
    input_grad = grad_with_sense(input_, output_grad)
    return input_grad


def fast_gelu_backward_cmp(input_shape):
    input_np = np.random.randn(*input_shape).astype(np.float32)
    input_me = Tensor(input_np)
    output_grad_shape = input_shape
    output_grad_np = np.random.randn(*output_grad_shape).astype(np.float32)
    output_grad_me = Tensor(output_grad_np)
    output_grad_me = fast_gelu_backward_me_impl(input_me, output_grad_me)
    logger.info("---------me--------")
    logger.info(output_grad_me)


# ---------- LARGE INPUT ---------------
class MEGeluLargeIn(Cell):
    def __init__(self):
        super(MEGeluLargeIn, self).__init__()
        self.matmul = P.MatMul()
        self.fast_gelu = P.FastGelu()

    def construct(self, x1, x2):
        x = self.matmul(x1, x2)
        return self.fast_gelu(x)


class GradLargeIn(Cell):
    def __init__(self, network):
        super(GradLargeIn, self).__init__()
        self.grad = GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, x1, x2, output_grad):
        return self.grad(self.network)(x1, x2, output_grad)


def fast_gelu_backward_me_large_in_impl(x1, x2, output_grad):
    # The large-input network takes two inputs (matmul then FastGelu).
    n = MEGeluLargeIn()
    grad_with_sense = GradLargeIn(n)
    grad_with_sense.set_train()
    input_grad = grad_with_sense(x1, x2, output_grad)
    return input_grad[0].asnumpy(), input_grad[1].asnumpy()


def test_grad_fast_gelu_input_10240_1024():
    input_shape = [10240, 1024]
    fast_gelu_backward_cmp(input_shape)