Merge pull request !8084 from ZengZitao/expand_gelu
tags/v1.1.0
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -15,6 +15,7 @@
 """expanders init"""
 from .gelu import expand_gelu
+from .gelu_grad import expand_gelugrad
 from .layernorm import expand_layernorm
 from .softmax import expand_softmax
 from .square import expand_square
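The new import matters because expanders are resolved by name: when the graph kernel pass meets an expandable op, it asks the `expanders` package for the matching `expand_<op>` function, so `expand_gelugrad` must be visible at package level. A minimal sketch of that kind of lookup; `lookup_expander` is a hypothetical helper for illustration, not MindSpore API, and the real dispatch glue may differ:

```python
# Hypothetical illustration of name-based expander dispatch.
from mindspore._extends.graph_kernel import expanders

def lookup_expander(op_name):
    # e.g. 'GeluGrad' -> expanders.expand_gelugrad
    func = getattr(expanders, 'expand_' + op_name.lower(), None)
    if func is None:
        raise ValueError("no expander registered for op '{}'".format(op_name))
    return func
```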
--- a/mindspore/_extends/graph_kernel/expanders/gelu.py
+++ b/mindspore/_extends/graph_kernel/expanders/gelu.py
@@ -16,11 +16,16 @@
 from mindspore._extends.graph_kernel.model import model_builder as builder

 CSVALUE = 0.044715
-CSVALUE_A = 1.5957691  # 2*np.sqrt(2/np.pi)
+CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
+ONE = 1.0
+HALF = 0.5


 def expand_gelu(expand_info):
     """Gelu expander"""
+    # cal formula are:
+    # gelu(x) = 0.5 * x * (1.0 + tanh(y))
+    # y = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
     # get op info.
     input_desc = expand_info['input_desc'][0]
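A quick NumPy check of the constants: the new `CSVALUE_SQRT_TWO_DIV_PI` is sqrt(2/pi), and the removed `CSVALUE_A` was exactly twice that, because the old code evaluated `0.5 * (1 + tanh(y))` as `sigmoid(2 * y)`:

```python
import numpy as np

# sqrt(2/pi): matches CSVALUE_SQRT_TWO_DIV_PI to well below float32 precision.
assert abs(np.sqrt(2.0 / np.pi) - 0.7978845608028564) < 1e-12
# 2*sqrt(2/pi): the old CSVALUE_A, rounded to 1.5957691 in the removed code.
assert abs(2.0 * np.sqrt(2.0 / np.pi) - 1.5957691) < 1e-6
```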
@@ -30,35 +35,29 @@ def expand_gelu(expand_info):
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
+        graph_scope.set_input(input_x)
         dtype = input_x.dtype
         if dtype == 'float16':
             input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})

-        # cal tanh.
+        # cal y
         mul_0 = graph_builder.emit('Mul', [input_x, input_x])
         pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
         const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
         mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
         tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
-        const_csvalue_a = graph_builder.value(tanh_res.dtype, CSVALUE_A, input_desc['format'])
-        mul_0 = graph_builder.emit('Mul', [tanh_res, const_csvalue_a])
-        const_zero = graph_builder.value(mul_0.dtype, 0.0, input_desc['format'])
-        mul_0_min = graph_builder.emit('Minimum', [mul_0, const_zero])
-        right_mul = graph_builder.emit('Exp', [mul_0_min])
-        mul_0_abs = graph_builder.emit('Abs', [mul_0])
-        const_neg_one = graph_builder.value(mul_0_abs.dtype, -1.0, input_desc['format'])
-        mul_0_abs_neg = graph_builder.emit('Mul', [mul_0_abs, const_neg_one])
-        mul_0_abs_neg_exp = graph_builder.emit('Exp', [mul_0_abs_neg])
-        const_one = graph_builder.value(mul_0_abs_neg_exp.dtype, 1.0, input_desc['format'])
-        mul_0_abs_neg_exp_add = graph_builder.emit('TensorAdd', [mul_0_abs_neg_exp, const_one])
-        left_mul = graph_builder.emit('RealDiv', [input_x, mul_0_abs_neg_exp_add])
-        result = graph_builder.emit('Mul', [left_mul, right_mul])
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(
+            tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
+        y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
+        # cal gelu(x)
+        tanh_y = graph_builder.emit('Tanh', [y])
+        const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
+        const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
+        tanh_y_add_one = graph_builder.emit('TensorAdd', [tanh_y, const_one])
+        mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
+        result = graph_builder.emit('Mul', [const_half, mul_x])
         if dtype == 'float16':
             result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
         # set graph output.
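The rewrite trades the overflow-safe sigmoid evaluation (`Exp` of a clamped exponent, then `RealDiv` by `1 + exp(-|z|)`) for a single `Tanh`, which becomes fusible through the PrimLib and fusible-op changes below. A NumPy sketch of both expansions (the function names here are illustrative, not expander API) shows they compute the same thing, since `0.5 * (1 + tanh(y)) == sigmoid(2 * y)`:

```python
import numpy as np

CSVALUE = 0.044715
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)

def gelu_tanh(x):
    """New expansion: 0.5 * x * (1 + tanh(y))."""
    y = CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x * x * x)
    return 0.5 * x * (1.0 + np.tanh(y))

def gelu_exp(x):
    """Old expansion: x * sigmoid(z) with z = 2*y, the sigmoid evaluated as
    exp(min(z, 0)) / (1 + exp(-|z|)) so that Exp never overflows."""
    z = 2.0 * CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x * x * x)  # z = CSVALUE_A * (...)
    return x / (1.0 + np.exp(-np.abs(z))) * np.exp(np.minimum(z, 0.0))

x = np.linspace(-6.0, 6.0, 1001)
assert np.allclose(gelu_tanh(x), gelu_exp(x), rtol=1e-10, atol=1e-12)
```

Dropping the `Exp`-based guard is safe here because `Tanh` saturates instead of overflowing for large `|y|`.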
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/gelu_grad.py
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for gelugrad"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+CSVALUE = 0.044715
+CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
+CSVALUE_TRI = 0.134141  # CSVALUE * 3
+ONE = 1.0
+HALF = 0.5
+
+
+def expand_gelugrad(expand_info):
+    """GeluGrad expander"""
+    # cal formula are:
+    # gelu_grad(dy, x) = dy * y'
+    # y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
+    # tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
+    # mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
+
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    input_desc_2 = expand_info['input_desc'][2]
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
+        graph_scope.set_input(input_dy, input_x, input_y)
+        dtype = input_dy.dtype
+        if dtype == 'float16':
+            input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
+
+        # create some const var
+        const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(
+            input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc_0['format'])
+        const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI, input_desc_0['format'])
+        const_one = graph_builder.value(input_dy.dtype, ONE, input_desc_0['format'])
+        const_half = graph_builder.value(input_dy.dtype, HALF, input_desc_0['format'])
+
+        # cal mul_right
+        mul_double = graph_builder.emit('Mul', [input_x, input_x])
+        mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
+        mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri])
+        mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
+
+        # cal tanh_para
+        mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
+        mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
+        mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue])
+        tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])
+
+        # cal 0.5 * (1.0 + tanh(tanh_para))
+        tanh_res = graph_builder.emit('Tanh', [tanh_para])
+        tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res])
+        half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])
+
+        # cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
+        tan_res_double = graph_builder.emit('Mul', [tanh_res, tanh_res])
+        one_sub_tan_res_double = graph_builder.emit('Sub', [const_one, tan_res_double])
+        half_mul_x = graph_builder.emit('Mul', [const_half, input_x])
+        mul_tmp = graph_builder.emit('Mul', [half_mul_x, one_sub_tan_res_double])
+        mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])
+
+        # cal result
+        result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
+        result = graph_builder.emit('Mul', [input_dy, result_tmp])
+
+        if dtype == 'float16':
+            result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
+
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
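The ops above build `y'` exactly as the header comment states; the formula is the product-plus-chain-rule derivative of `0.5 * x * (1 + tanh(tanh_para))`, with `1 - tanh^2` standing in for `sech^2`. A NumPy sketch with a central finite-difference check (function names illustrative; note the source constant `0.134141` differs from `3 * 0.044715 = 0.134145` by about 4e-6, which the tolerance absorbs):

```python
import numpy as np

CSVALUE = 0.044715
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
CSVALUE_TRI = 0.134141  # source constant; exact 3 * CSVALUE is 0.134145

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x ** 3)))

def gelu_grad(dy, x):
    """Mirrors expand_gelugrad: dy * y'."""
    tanh_para = CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x ** 3)
    mul_right = CSVALUE_SQRT_TWO_DIV_PI * (1.0 + CSVALUE_TRI * x * x)
    t = np.tanh(tanh_para)
    return dy * (0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * mul_right)

x = np.linspace(-4.0, 4.0, 801)
eps = 1e-5
numeric = (gelu(x + eps) - gelu(x - eps)) / (2.0 * eps)
assert np.allclose(gelu_grad(np.ones_like(x), x), numeric, atol=1e-5)
```

Note that `input_y` (the forward output) is declared and registered via `set_input` but never used in the arithmetic; presumably it is kept so the expanded graph matches GeluGrad's three-input signature.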
@@ -153,6 +153,7 @@ class PrimLib:
         'make_tuple': Prim(CONTROL),
         'ControlDepend': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
+        'Tanh': Prim(ELEMWISE),
         '@ReduceInit': Prim(ELEMWISE),
     }
@@ -705,6 +705,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimSquare,
     prim::kPrimBiasAdd,
     prim::kPrimBiasAddGrad,
+    prim::kPrimGelu,
   };
   return expand_ops;
 }
@@ -731,7 +732,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
     prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
     prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
     prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum};
+    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh};
   return fusible_basic_ops;
 }