You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; can include dashes ('-'); and can be up to 35 characters long.

dsl_create.py 13 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """dsl create helping function"""
  17. import logging
  18. import math
  19. import akg
  20. from akg.utils import format_transform as ft_util
  21. from akg.utils import validation_check as vc_util
  22. class TensorUtils:
  23. """Class for creating tensor."""
  24. CREATE_SCH_ONLY = 'create_sch_only'
  25. @classmethod
  26. def get_tensor_attrs(cls, tensor):
  27. """get tensor attrs."""
  28. tensor_attrs = dict()
  29. if "attrs" in dir(tensor.op):
  30. tensor_attrs = dict(tensor.op.attrs.items())
  31. return tensor_attrs
  32. @classmethod
  33. def update_tensor_attrs(cls, tensor, attrs):
  34. """update tensor attrs."""
  35. tensor_attrs = cls.get_tensor_attrs(tensor)
  36. tensor_attrs.update(attrs)
  37. tensor = akg.tvm.compute(tensor.shape,
  38. lambda *indice: tensor[indice],
  39. name=tensor.op.name,
  40. tag=tensor.op.tag,
  41. attrs=tensor_attrs)
  42. return tensor
  43. @classmethod
  44. def is_create_sch_only(cls, tensor):
  45. tensor_attrs = cls.get_tensor_attrs(tensor)
  46. if cls.CREATE_SCH_ONLY in tensor_attrs.keys():
  47. return True
  48. return False
  49. @classmethod
  50. def is_output_value(cls, tensor):
  51. """check output value."""
  52. return not cls.is_create_sch_only(tensor)
  53. @classmethod
  54. def inplace_set(cls, input_tensor, output_tensor, buffer_name="data_buf"):
  55. """inplace set."""
  56. input_tensor_shape = ft_util.get_shape(input_tensor)
  57. output_tensor_shape = ft_util.get_shape(output_tensor)
  58. if not input_tensor_shape == output_tensor_shape:
  59. raise RuntimeError("Shape of the input_tensor and the output_tensor should be equal, "
  60. "but got %s and %s" % (input_tensor_shape, output_tensor_shape))
  61. output_tensor = cls.update_tensor_attrs(output_tensor, {cls.CREATE_SCH_ONLY: 1})
  62. data_buf = akg.tvm.decl_buffer(input_tensor.shape, input_tensor.dtype, name=buffer_name)
  63. binds_info = {input_tensor: data_buf, output_tensor: data_buf}
  64. return output_tensor, binds_info
  65. @classmethod
  66. def inplace_set_tensors(cls, input_tensors, output_tensors, buffer_names=None):
  67. """
  68. inplace set for tensors
  69. Args:
  70. in_tensors (Union[list, tuple]): Origin input tensors.
  71. out_tensors (Union[list, tuple]): Origin output tensors.
  72. buffer_names (Union[list, tuple] or None): Buffer names used to bind.
  73. Return:
  74. inplace_tensors (list): Output tensors with the inplace info.
  75. binds_infos (dict): Dictionary that maps the input tensor and the output
  76. tensor to buffer.
  77. """
  78. if not buffer_names:
  79. buffer_names = ["data_buf_%s" % i for i in range(len(input_tensors))]
  80. for arg in (input_tensors, output_tensors, buffer_names):
  81. if not isinstance(arg, (tuple, list)):
  82. raise RuntimeError("arg must be tuple or list!")
  83. if len(input_tensors) != len(output_tensors) or len(input_tensors) != len(buffer_names):
  84. raise RuntimeError("length of the input_tensors, output_tensors and buffer_names must be equal!")
  85. inplace_tensors = []
  86. binds_infos = dict()
  87. for input_tensor, output_tensor, buffer_name in zip(input_tensors, output_tensors, buffer_names):
  88. inplace_tensor, binds_info = cls.inplace_set(input_tensor, output_tensor, buffer_name)
  89. inplace_tensors.append(inplace_tensor)
  90. binds_infos.update(binds_info)
  91. return inplace_tensors, binds_infos
  92. def produce_shapes(shape1, shape2):
  93. """two input shapes produce three output shape."""
  94. shape1 = list(shape1)
  95. shape2 = list(shape2)
  96. flag = 0
  97. if len(shape1) < len(shape2):
  98. shape1, shape2 = shape2, shape1
  99. flag = 1
  100. output_shape_len = len(shape1)
  101. dec = output_shape_len - len(shape2)
  102. for i in range(dec):
  103. shape2 = [1] + shape2
  104. out_shape = []
  105. for i in range(output_shape_len):
  106. if (shape1[i] != shape2[i]) and (shape1[i] != 1) and (shape2[i] != 1):
  107. raise RuntimeError("input shapes not match!")
  108. if isinstance(shape1[i], int) and isinstance(shape2[i], int) and shape1[i] > shape2[i]:
  109. out_shape.append(shape1[i])
  110. else:
  111. out_shape.append(shape2[i])
  112. if flag == 1:
  113. shape1, shape2 = shape2, shape1
  114. return shape1, shape2, out_shape
  115. def get_reduce_out_shape(in_shape, axis=None, keepdims=False):
  116. """
  117. Computes ouput shape in reduction operators.
  118. Args:
  119. in_shape : input shape
  120. axis (Union[int, list, tuple]): The reduction axis. Default value is None, in this case,
  121. all dimensions will be reduced.
  122. keepdims (bool): If True, retains reduced dimensions with length 1, default value is False.
  123. Returns:
  124. output shape.
  125. """
  126. dims = len(in_shape)
  127. if axis is None:
  128. axis = list(range(dims))
  129. if not isinstance(axis, (int, list, tuple)):
  130. raise ValueError("axis must be of the following type: int, list, tuple.")
  131. if isinstance(axis, int):
  132. axis = [axis]
  133. axis = list(axis)
  134. for i, axis_val in enumerate(axis):
  135. if axis_val < 0:
  136. axis[i] = axis_val + dims
  137. if axis_val >= dims:
  138. raise ValueError("axis[{}] is {}, which exceeds max dimension {}".format(i, axis[i], dims))
  139. remaining_axis = []
  140. for i in range(dims):
  141. if i not in axis:
  142. remaining_axis.append(i)
  143. out_shape = []
  144. for i in range(dims):
  145. if i in remaining_axis:
  146. out_shape.append(in_shape[i])
  147. else:
  148. if keepdims:
  149. out_shape.append(1)
  150. if not out_shape:
  151. out_shape.append(1)
  152. return out_shape
  153. def get_input_pad_shape(shape, dtype):
  154. """Function for getting input pad shape."""
  155. pad_unit = ft_util.get_bytes(dtype, allow_none=True)
  156. if pad_unit is None:
  157. logging.warning("%s is not support in TensorAddPad, the result is not undefined.", dtype)
  158. return shape
  159. lastdim = int(math.ceil(shape[-1] / pad_unit) * pad_unit)
  160. pad_shape = [*shape[:-1], '{},{}'.format(shape[-1], lastdim)] if lastdim != shape[-1] else shape
  161. return pad_shape
  162. def mul_axis_sum(data, axes, keepdims, name=None, attrs=None):
  163. """calculate sum one by one."""
  164. if name is None and attrs is None:
  165. for axis in axes:
  166. data = akg.topi.sum(data, axis=axis, keepdims=keepdims)
  167. else:
  168. shape = [x.value for x in data.shape]
  169. for axis in axes[:-1]:
  170. data = akg.topi.sum(data, axis=axis, keepdims=keepdims)
  171. l_axis = shape[axes[-1]]
  172. k = akg.tvm.reduce_axis((0, l_axis), name="k")
  173. res_shape = [1 if i in axes else shape[i] for i in range(len(shape))]
  174. def sumfunc(*i):
  175. new_i = list(i)
  176. new_i[axes[-1]] = k
  177. return akg.tvm.sum(data(*tuple(new_i)), axis=k)
  178. if name is None:
  179. data = akg.tvm.compute(res_shape, sumfunc, attrs=attrs)
  180. elif attrs is None:
  181. data = akg.tvm.compute(res_shape, sumfunc, name=name)
  182. else:
  183. data = akg.tvm.compute(res_shape, sumfunc, name=name, attrs=attrs)
  184. return data
  185. def update_by_moving_average(hat_z, z, momentum):
  186. r"""
  187. Update value with moving average.
  188. Note:
  189. :math:`\hat{z_{new}} = momentum * \hat{z} + (1-momentum) * z`
  190. where \f$ \hat{z} \f$ is the estimated statistic and \f$ z \f$ is the new observed value.
  191. Args:
  192. hat_z (tvm.tensor.Tensor): Tensor of type float16, float32.
  193. z (tvm.tensor.Tensor): Tensor of type float16, float32.
  194. momentum (float): must meet '0.0 < momentum < 1.0'.
  195. Returns:
  196. tvm.tensor.Tensor, updated value.
  197. """
  198. run = akg.lang.cce.vmuls(hat_z, momentum)
  199. now = akg.lang.cce.vmuls(z, (1 - momentum))
  200. return akg.lang.cce.vadd(run, now)
  201. def cal_pad_shapes_by_strategy(shape, kernel, stride, strategy):
  202. """
  203. Calculate the pad size and output shape by padding strategy.
  204. Args:
  205. shape (Union[list, tuple]): Input shape, a list or tuple of 5 int numbers.
  206. kernel (Union[list, tuple]): List or tuple of two int numbers for pooling window's size.
  207. stride (Union[list, tuple]): List or tuple of two int numbers for window's stride.
  208. strategy (Union[str, list]): A string or list for padding strategy, should be 'VALID',
  209. 'SAME' or instance of list(including four int numbers, as 'CONSTANTS' strategy).
  210. Returns:
  211. pad_sizes: Padding sizes(a list of four int numbers: [H_head_pad, H_tail_pad, W_head_pad, W_tail_pad]).
  212. out_shape: Output tensor's shape(a list of two int numbers: [output_H, output_W]).
  213. """
  214. pool_shapes = [shape[2], shape[3]]
  215. out_shape = []
  216. pad_sizes = []
  217. contrain_var = False
  218. for sh in [shape, kernel, stride]:
  219. for s in sh:
  220. if not isinstance(s, (int, akg.tvm.expr.IntImm)):
  221. contrain_var = True
  222. if isinstance(strategy, str) and strategy.upper() == "VALID":
  223. for i in range(2):
  224. out_shape.append(math.ceil((pool_shapes[i] - (kernel[i] - 1)) / stride[i]))
  225. if out_shape[i] <= 0:
  226. raise ValueError("With pad mode {0}, the value of the kernel "
  227. "(or window) size should be less than or "
  228. "equal to that of the corresponding input "
  229. "shape!".format(strategy))
  230. pad_sizes += [0, 0] # for h
  231. pad_sizes += [0, 0] # for w
  232. elif isinstance(strategy, str) and strategy.upper() == "SAME":
  233. for i in range(2):
  234. out_shape.append(math.ceil(pool_shapes[i] / stride[i]))
  235. diff_shape = ((out_shape[i] - 1) * stride[i] + kernel[i]) - pool_shapes[i]
  236. diff_shape = diff_shape if diff_shape > 0 else 0
  237. pad_shape = [math.floor(diff_shape / 2), math.ceil(diff_shape / 2)]
  238. pad_sizes += pad_shape
  239. elif isinstance(strategy, (list, tuple)):
  240. if len(strategy) != 4:
  241. raise RuntimeError(
  242. "When with strateg 'CONSTANTS', strategy should be list or tuple of 4 int numbers but get {}".
  243. format(strategy))
  244. vc_util.check_pad('pad', strategy, 4)
  245. for i in range(2):
  246. pad_shape = [strategy[i * 2], strategy[i * 2 + 1]]
  247. if contrain_var:
  248. out_shape.append(akg.tvm.floordiv((pool_shapes[i] +
  249. (pad_shape[0] + pad_shape[1]) - kernel[i]), (stride[i])) + 1)
  250. else:
  251. out_shape.append(math.floor((pool_shapes[i] +
  252. (pad_shape[0] + pad_shape[1]) - kernel[i]) / float(stride[i])) + 1)
  253. pad_sizes += pad_shape
  254. height, width = out_shape
  255. if (isinstance(height, int) and height <= 0) or (isinstance(width, int) and width <= 0):
  256. raise ValueError("The height and witdth of calculated output"
  257. " shape [{}, {}] are invalid. Please check the "
  258. "input parameters!".format(height, width))
  259. else:
  260. raise RuntimeError("Padding strategies only support 'VALID', 'CONSTANTS' or 'SAME', but get {}".
  261. format(strategy))
  262. return pad_sizes, out_shape
  263. def broadcast_gradient_args(x, y):
  264. """
  265. Return the reduction indices for computing gradients of x op y with broadcast.
  266. Args:
  267. x (Union[list, tuple]): the shape of data input
  268. y (Union[list, tuple]): the shape of data input
  269. Returns:
  270. rx (list): the reduction indices for computing gradients of x
  271. ry (list): the reduction indices for computing gradients of y
  272. """
  273. rx = []
  274. ry = []
  275. for i, item in enumerate(x):
  276. if item < y[i]:
  277. rx.append(i)
  278. elif item > y[i]:
  279. ry.append(i)
  280. return rx, ry
def zero_const(dtype):
    """Return a tvm scalar constant 0 of the given dtype."""
    return akg.tvm.const(0, dtype)
def one_const(dtype):
    """Return a tvm scalar constant 1 of the given dtype."""
    return akg.tvm.const(1, dtype)
def neg_one_const(dtype):
    """Return a tvm scalar constant -1 of the given dtype."""
    return akg.tvm.const(-1, dtype)
def half_const(dtype):
    """Return a tvm scalar constant 0.5 of the given dtype."""
    return akg.tvm.const(0.5, dtype)
def pi_const(dtype):
    """Return a tvm scalar constant pi of the given dtype."""
    return akg.tvm.const(3.1415926535897932384626433832795, dtype)
  291. def get_value(val, type):
  292. if isinstance(val, type) and type in [akg.tvm.expr.IntImm, akg.tvm.expr.FloatImm]:
  293. return val.value
  294. return val