You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

add3_impl.py 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. from __future__ import absolute_import
  16. import te.lang.cce
  17. from te import tvm
  18. from te.platform.fusion_manager import fusion_manager
  19. from topi import generic
  20. from topi.cce import util
  21. from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
  @fusion_manager.register("add3")
  def add3_compute(input1, input2, const_bias):
      """Build the compute expression for the add3 operator.

      Applies ``te.lang.cce.vadd`` to the two inputs, then applies
      ``te.lang.cce.vadds`` to that result together with ``const_bias``
      wrapped as a ``tvm.const`` of ``input1``'s dtype.

      Args:
          input1: First operand; its ``dtype`` attribute selects the dtype
              of the bias constant.
          input2: Second operand passed to ``vadd``.
          const_bias: Value wrapped in ``tvm.const`` and passed to ``vadds``.

      Returns:
          The result of the ``vadds`` call.
      """
      pairwise_sum = te.lang.cce.vadd(input1, input2)
      bias_scalar = tvm.const(const_bias, dtype=input1.dtype)
      return te.lang.cce.vadds(pairwise_sum, bias_scalar)
  # Registration record for the custom operator "CusAdd3":
  #   * one required float attribute, "const_bias";
  #   * two required inputs ("input1", "input2") and one required output ("sum");
  #   * two dtype_format rows, one all-F32 and one all-F16
  #     (presumably ordered as inputs followed by outputs -- confirm against
  #     the TBERegOp documentation).
  # The kernel name given here matches the default ``kernel_name`` of the
  # implementation function registered with this record.
  cus_add3_op_info = (
      TBERegOp("CusAdd3")
      .fusion_type("OPAQUE")
      .async_flag(False)
      .binfile_name("add3.so")
      .compute_cost(10)
      .kernel_name("CusAdd3Impl")
      .partial_flag(True)
      .attr("const_bias", "required", "float", "all")
      .input(0, "input1", False, "required", "all")
      .input(1, "input2", False, "required", "all")
      .output(0, "sum", False, "required", "all")
      .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
      .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default)
      .get_op_info()
  )
  @op_info_register(cus_add3_op_info)
  def CusAdd3Impl(input1, inptu2, sum1, const_bias, kernel_name="CusAdd3Impl"):
      """Build the CusAdd3 kernel from the description of its first input.

      The shape (after ``util.shape_refine``) and the lower-cased dtype are
      read from ``input1`` and used for BOTH placeholders; ``inptu2`` and
      ``sum1`` are accepted but never read in this function.

      Args:
          input1: Mapping providing ``"shape"`` and ``"dtype"`` entries.
          inptu2: Unused here.
              NOTE(review): the name looks like a typo for ``input2``; it is
              kept as-is so keyword callers are unaffected -- confirm before
              renaming.
          sum1: Unused here.
          const_bias: Bias value forwarded to ``add3_compute``.
          kernel_name: Name stored under ``"name"`` in the build config.
      """
      refined_shape = util.shape_refine(input1.get("shape"))
      elem_dtype = input1.get("dtype").lower()

      # Both placeholders share the shape/dtype taken from the first input.
      lhs = tvm.placeholder(refined_shape, name="input1", dtype=elem_dtype)
      rhs = tvm.placeholder(refined_shape, name="input2", dtype=elem_dtype)

      # NOTE(review): the extent of this ``with`` block is assumed to cover
      # only the compute and scheduling steps -- confirm against the original.
      with tvm.target.cce():
          result = add3_compute(lhs, rhs, const_bias)
          schedule = generic.auto_schedule(result)

      build_config = {
          "print_ir": False,
          "name": kernel_name,
          "tensor_list": [lhs, rhs, result],
      }
      te.lang.cce.cce_build_code(schedule, build_config)