You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

format_transform.py 6.2 kB

5 years ago
(line numbers 1–209)
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """format transform function"""
  17. import akg
  18. supported_bits = {
  19. "8": 1, "16": 2, "32": 4, "64": 8, "bool": 1
  20. }
  21. def to_tvm_const(x):
  22. """Convert integer to TVM expression"""
  23. if isinstance(x, int):
  24. return akg.tvm.const(x)
  25. return x
  26. def get_const(expr):
  27. """
  28. get const value from TVM expression.
  29. Args:
  30. expr (tvm.expr.Expr): tvm expression.
  31. Returns:
  32. value (int): expr value.
  33. """
  34. if isinstance(expr, int):
  35. return expr
  36. if not isinstance(expr, (akg.tvm.expr.IntImm, akg.tvm.expr.UIntImm)):
  37. expr = akg.tvm.ir_pass.Simplify(expr)
  38. if not isinstance(expr, (akg.tvm.expr.IntImm, akg.tvm.expr.UIntImm)):
  39. raise TypeError("Expr is not a const. Get const fail, please use get shape.")
  40. return expr.value
  41. def get_bytes(dtype, allow_none=False):
  42. """get number of bytes for supported dtype."""
  43. dtype = str(dtype)
  44. for bits in supported_bits:
  45. if bits in dtype:
  46. return supported_bits[bits]
  47. if allow_none:
  48. return None
  49. raise RuntimeError("Invalid dtype, supported bits are {0}".format(supported_bits.keys()))
  50. def refine_shape(shape, reduce_axis=None):
  51. """
  52. Refine shape to drop 1 in shape according to reduce axis.
  53. Note:
  54. if input is just shape, result is shape, and if inputs are shape and axis, result is a tuple of (shape, axis).
  55. Args:
  56. shape : shape of data
  57. reduce_axis : list, tuple or int
  58. axis want to reduce
  59. keepdims: if keepdims = True, we should not refine the shape
  60. Returns:
  61. shape (list): refined shape.
  62. reduce_axis (list): if input parameters send reduce axis, this will be the output.
  63. if all the reduce axis is illegal like the length of reduce axis is 1, a empty list([]) will be returned.
  64. """
  65. def _refine_shape_no_reduce():
  66. refined = [shp for _, shp in enumerate(shape) if shp > 1]
  67. if not refined:
  68. refined = [1]
  69. return refined
  70. if reduce_axis is not None:
  71. res_reduce_axis = sorted(refine_reduce_axis(shape, reduce_axis))
  72. if not res_reduce_axis:
  73. return _refine_shape_no_reduce(), []
  74. res_shape = shape[:]
  75. refined_shape = []
  76. count = 0
  77. for i in res_shape:
  78. if i > 1:
  79. refined_shape.append(i)
  80. count += 1
  81. else:
  82. for j, axs in enumerate(res_reduce_axis):
  83. if axs > count:
  84. res_reduce_axis[j] -= 1
  85. return refined_shape, res_reduce_axis
  86. return _refine_shape_no_reduce()
  87. def refine_reduce_axis(input, axis):
  88. """make reduce axis legal."""
  89. shape = get_shape(input)
  90. if axis is None:
  91. axis = [i for i in range(len(shape))]
  92. elif isinstance(axis, int):
  93. axis = [axis]
  94. elif not isinstance(axis, (tuple, list)):
  95. raise TypeError("axis must be one of the type int,tuple,list or None")
  96. if len(axis) > len(shape):
  97. raise ValueError("axis size must not larger than shape size")
  98. axis = list(axis)
  99. for i, _ in enumerate(axis):
  100. if axis[i] < 0:
  101. axis[i] += len(shape)
  102. if axis[i] >= len(shape):
  103. raise ValueError(("axis value-{} exceeds len(axis) which is invalid".format(axis[i])))
  104. axis.sort(reverse=True)
  105. return axis
  106. def get_shape_from_tensor(data):
  107. """translate akg.tvm.shape to list type in python."""
  108. tvm_shape = data.shape
  109. py_shape = []
  110. for i in tvm_shape:
  111. if isinstance(i, akg.tvm.expr.IntImm):
  112. py_shape.append(i.value)
  113. else:
  114. py_shape.append(i)
  115. return py_shape
  116. def tvm_shape_to_list(tvm_shape):
  117. """translate akg.tvm.shape to list type in python."""
  118. py_shape = []
  119. for i in tvm_shape:
  120. if isinstance(i, akg.tvm.expr.Var):
  121. py_shape.append(i)
  122. else:
  123. py_shape.append(i.value)
  124. return py_shape
  125. def tvm_array_to_list(tvm_array):
  126. """translate akg.tvm.array to list type in python."""
  127. tensor_list = []
  128. for i in tvm_array:
  129. if isinstance(i, akg.tvm.tensor.Tensor):
  130. tensor_list.append(i)
  131. else:
  132. raise ValueError("Only surpport akg.tvm.tensor.Tensor.")
  133. return tensor_list
  134. def get_shape(data):
  135. """get shape and save it as list."""
  136. if isinstance(data, akg.tvm.tensor.Tensor):
  137. shape = get_shape_from_tensor(data)
  138. elif isinstance(data, akg.tvm.container.Array):
  139. shape = tvm_shape_to_list(data)
  140. elif isinstance(data, int):
  141. shape = [data]
  142. elif isinstance(data, (tuple, list)):
  143. shape = list(data)
  144. elif isinstance(data, akg.tvm.expr.Var):
  145. shape = [data]
  146. else:
  147. raise TypeError("Refine axis does not support type {} for now.".format(type(data)))
  148. return shape
  149. def convert_to_list(something, convert_all=True):
  150. """convert other types to string."""
  151. out = []
  152. if isinstance(something, (list, tuple)):
  153. for x in something:
  154. out.append(convert_to_list(x, convert_all=False))
  155. else:
  156. if convert_all:
  157. out.append(something)
  158. else:
  159. out = something
  160. return out
  161. def to_tvm_nd_array(data, ctx=None):
  162. """convert other types to tvm nd array with specified context"""
  163. if ctx is None:
  164. ctx = akg.tvm.context("cuda", 0)
  165. if isinstance(data, list):
  166. return [akg.tvm.nd.array(d, ctx) for d in data]
  167. if isinstance(data, tuple):
  168. return (akg.tvm.nd.array(d, ctx) for d in data)
  169. return akg.tvm.nd.array(data, ctx)

AKG（Auto Kernel Generator）对深度神经网络中的算子进行优化，并提供特定模式下的算子自动融合功能。AKG 与 MindSpore 的图算融合功能协同工作，可提升网络在不同硬件后端上的运行性能。