You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gelu.py 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """generate json desc for gelu"""
  16. from mindspore._extends.graph_kernel.model import model_builder as builder
  17. CSVALUE = 0.044715
  18. CSVALUE_A = 1.5957691 # 2*np.sqrt(2/np.pi)
  19. def expand_gelu(expand_info):
  20. """Gelu expander"""
  21. # get op info.
  22. input_desc = expand_info['input_desc'][0]
  23. graph_builder = builder.GraphBuilder()
  24. # generate a graph.
  25. with graph_builder.graph_scope('main') as graph_scope:
  26. # create tensor input.
  27. input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
  28. dtype = input_x.dtype
  29. if dtype == 'float16':
  30. input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
  31. # cal tanh.
  32. mul_0 = graph_builder.emit('Mul', [input_x, input_x])
  33. pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
  34. const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
  35. mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
  36. tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
  37. const_csvalue_a = graph_builder.value(tanh_res.dtype, CSVALUE_A, input_desc['format'])
  38. mul_0 = graph_builder.emit('Mul', [tanh_res, const_csvalue_a])
  39. const_zero = graph_builder.value(mul_0.dtype, 0.0, input_desc['format'])
  40. mul_0_min = graph_builder.emit('Minimum', [mul_0, const_zero])
  41. right_mul = graph_builder.emit('Exp', [mul_0_min])
  42. mul_0_abs = graph_builder.emit('Abs', [mul_0])
  43. const_neg_one = graph_builder.value(mul_0_abs.dtype, -1.0, input_desc['format'])
  44. mul_0_abs_neg = graph_builder.emit('Mul', [mul_0_abs, const_neg_one])
  45. mul_0_abs_neg_exp = graph_builder.emit('Exp', [mul_0_abs_neg])
  46. const_one = graph_builder.value(mul_0_abs_neg_exp.dtype, 1.0, input_desc['format'])
  47. mul_0_abs_neg_exp_add = graph_builder.emit('TensorAdd', [mul_0_abs_neg_exp, const_one])
  48. left_mul = graph_builder.emit('RealDiv', [input_x, mul_0_abs_neg_exp_add])
  49. result = graph_builder.emit('Mul', [left_mul, right_mul])
  50. if dtype == 'float16':
  51. result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
  52. # set graph output.
  53. graph_scope.set_output(result)
  54. graph = graph_builder.get()[0]
  55. return graph