You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.py 7.2 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """util"""
  17. from decorator import decorator
  18. import akg.tvm
  19. from akg.utils import kernel_exec
# Output dtype of the op, saved by save_op_output_dtype() the first time a
# template api is called (i.e. while this value is still None).
# Before auto scheduling, get_op_output_dtype() is used to fetch the dtype so
# the res tensor can be converted to it; that call also resets this to None.
op_output_dtype = None

# Full dtype name -> short tag. is_cast_support() joins two tags with "2"
# (for example "f32" + "2" + "f16") to build the cast name it looks up.
dtype_map = {
    "float32": "f32",
    "float16": "f16",
    "int8": "s8",
    "uint8": "u8",
    "int32": "s32",
}
  31. def save_op_output_dtype(func, *args):
  32. """
  33. Save op's output dtype, when first call the template api,
  34. Note:
  35. we will save the dtype.
  36. Before auto scheduling,get the dtype and convert the res tensor
  37. to this dtype, and set the dtype to None.
  38. """
  39. global op_output_dtype
  40. if op_output_dtype is None:
  41. if func.__name__ == "broadcast":
  42. if isinstance(args[0], int):
  43. output_dtype = "int32"
  44. elif isinstance(args[0], float):
  45. output_dtype = "float16"
  46. else:
  47. output_dtype = args[0].dtype
  48. elif func.__name__ == "concat":
  49. output_dtype = args[0][0].dtype
  50. else:
  51. output_dtype = args[0].dtype
  52. op_output_dtype = output_dtype
  53. def get_op_output_dtype():
  54. """get saved op's output dtype and set saved dtype to None."""
  55. global op_output_dtype
  56. res = op_output_dtype
  57. op_output_dtype = None
  58. return res
  59. @decorator
  60. def dtype_check_decorator(func, *args, **kwargs):
  61. """check type decorator"""
  62. intput_check_list = {
  63. "broadcast": ["int8", "uint8", "float16", "float32", "int32", "bool"],
  64. "concat": ["int8", "uint8", "float16", "float32", "int32"],
  65. "unsorted_segment_sum": ["float16", "float32", "int32", "uint8", "int8"],
  66. "unsorted_segment_mean": ["float16", "float32", "uint8", "int8"],
  67. "unsorted_segment_prod": ["float16", "float32", "int32", "uint8", "int8"],
  68. "unsorted_segment_min": ["float16", "float32", "int32", "uint8", "int8"],
  69. "unsorted_segment_max": ["float16", "float32", "int32", "uint8", "int8"],
  70. }
  71. if func.__name__ == "cast":
  72. input_dtype = args[0].dtype
  73. output_dtype = args[1]
  74. judge_dtype = (input_dtype + "2" + output_dtype) if (input_dtype != output_dtype) else ""
  75. elif func.__name__ == "broadcast":
  76. if isinstance(args[0], int):
  77. judge_dtype = "int32"
  78. elif isinstance(args[0], float):
  79. judge_dtype = "float16"
  80. else:
  81. judge_dtype = args[0].dtype
  82. elif func.__name__ == "concat":
  83. judge_dtype = args[0][0].dtype
  84. else:
  85. judge_dtype = args[0].dtype
  86. if isinstance(intput_check_list[func.__name__], list):
  87. check_bool(judge_dtype in intput_check_list[func.__name__],
  88. "%s input_dtype just support %s, while input dtype is %s" % (
  89. func.__name__, str(intput_check_list[func.__name__]), judge_dtype))
  90. else:
  91. intri_name = "Intrinsic_" + func.__name__
  92. s_dtypes = intput_check_list[func.__name__](intri_name)
  93. check_bool(judge_dtype in s_dtypes,
  94. "%s input_dtype just support %s, while input dtype is %s" % (
  95. func.__name__, str(s_dtypes), judge_dtype))
  96. return func(*args, **kwargs)
  97. def get_value(key):
  98. """
  99. call global func to get product value.
  100. Args:
  101. key (str): key.
  102. """
  103. mode = kernel_exec.get_runtime_mode()
  104. if "cloud" in mode:
  105. product = "1.6"
  106. else:
  107. product = "1.1"
  108. if "Buffer" in key:
  109. f = akg.tvm.get_global_func("cce.product_conf_buffer")
  110. value = f(product, key)
  111. if value == 0:
  112. raise RuntimeError("Get the cce product value is 0")
  113. return value
  114. if "Compiler" in key:
  115. f = akg.tvm.get_global_func("cce.product_conf_compiler")
  116. value = f(product, key)
  117. if value == "":
  118. raise RuntimeError("Get the cce product value is None")
  119. return value
  120. if "Intrinsic" in key:
  121. f = akg.tvm.get_global_func("cce.product_conf_intrinsic")
  122. value = f(product, key)
  123. if value == "":
  124. raise RuntimeError("Get the cce product value is None")
  125. return value
  126. if "Core" in key:
  127. f = akg.tvm.get_global_func("cce.product_conf_core")
  128. value = f(product, key)
  129. if value == 0:
  130. raise RuntimeError("Get the cce product value is None")
  131. return value
  132. return None
  133. def get_intr_types(intr):
  134. """get intrinsic types"""
  135. return str_to_tuple(get_value(intr))
  136. def str_to_tuple(string_):
  137. """string to tuple"""
  138. if string_:
  139. return string_.split(",")
  140. return []
  141. def is_cast_support(src_type, dst_type):
  142. """check cast support"""
  143. if src_type not in dtype_map:
  144. raise RuntimeError("%s is unsupported dtype!" % src_type)
  145. if dst_type not in dtype_map:
  146. raise RuntimeError("%s is unsupported dtype!" % dst_type)
  147. if src_type == dst_type:
  148. return True
  149. cast_type = dtype_map[src_type] + "2" + dtype_map[dst_type]
  150. if cast_type == "s322f16":
  151. cast_type = "deq"
  152. conv_list = get_intr_types("Intrinsic_vconv")
  153. if cast_type in conv_list:
  154. return True
  155. return False
  156. def judge_var(num):
  157. """judge var if a akg.tvm.var, akg.tvm.const or python data type"""
  158. var_dict = {"python_const": [int, float],
  159. "tvm_const": [akg.tvm.expr.IntImm, akg.tvm.expr.UIntImm, akg.tvm.expr.FloatImm],
  160. "tvm_var": [akg.tvm.expr.Var]}
  161. num_type = type(num)
  162. for i in var_dict:
  163. if num_type in var_dict[i]:
  164. return i
  165. raise RuntimeError("Input var Error")
  166. def shape_to_list(shape):
  167. """translate akg.tvm.shape to list type in python"""
  168. tmp = []
  169. for i in shape:
  170. if isinstance(i, akg.tvm.expr.Var):
  171. tmp.append(i)
  172. else:
  173. tmp.append(i.value)
  174. return tmp
  175. def refine_axis(axis, shape):
  176. """refine axis"""
  177. if isinstance(axis, (tuple, list)):
  178. local_axis = axis
  179. else:
  180. local_axis = [axis]
  181. res_axis = []
  182. shape_len = len(shape)
  183. for i in local_axis:
  184. if i < 0:
  185. laxis = shape_len + i
  186. else:
  187. laxis = i
  188. if (laxis >= shape_len) or (laxis < 0):
  189. raise RuntimeError("wrong axis.")
  190. res_axis.append(laxis)
  191. return sorted(res_axis)
  192. def check_bool(bool_res, append_str):
  193. """check boolean"""
  194. if not bool_res:
  195. raise RuntimeError(append_str)

AKG(Auto Kernel Generator)对深度神经网络中的算子进行优化,并提供特定模式下的算子自动融合功能。AKG与MindSpore的图算融合功能协同工作,可提升在不同硬件后端上运行网络的性能。