You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gelu.py 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """generate json desc for gelu"""
  16. from mindspore._extends.graph_kernel.model import model_builder as builder
  17. CSVALUE = 0.044715
  18. CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
  19. ONE = 1.0
  20. HALF = 0.5
  21. def expand_gelu(expand_info):
  22. """Gelu expander"""
  23. # cal formula are:
  24. # gelu(x) = 0.5 * x * (1.0 + tanh(y))
  25. # y = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
  26. # get op info.
  27. input_desc = expand_info['input_desc'][0]
  28. graph_builder = builder.GraphBuilder()
  29. # generate a graph.
  30. with graph_builder.graph_scope('main') as graph_scope:
  31. # create tensor input.
  32. input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
  33. graph_scope.set_input(input_x)
  34. # cal y
  35. mul_0 = graph_builder.emit('Mul', [input_x, input_x])
  36. pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
  37. const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
  38. mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
  39. tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
  40. const_csvalue_sqrt_two_div_pi = graph_builder.value(
  41. tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
  42. y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
  43. # cal gelu(x)
  44. tanh_y = graph_builder.emit('Tanh', [y])
  45. const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
  46. const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
  47. tanh_y_add_one = graph_builder.emit('TensorAdd', [tanh_y, const_one])
  48. mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
  49. result = graph_builder.emit('Mul', [const_half, mul_x])
  50. # set graph output.
  51. graph_scope.set_output(result)
  52. graph = graph_builder.get()[0]
  53. return graph