You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

five2four.py 8.2 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: five2four"""
  17. import akg.topi
  18. from akg.tvm.hybrid import script
  19. from akg.utils import custom_tiling as ct_util
  20. from akg.utils import validation_check as vc_util
  21. from akg.utils.format_transform import get_shape, get_bytes, to_tvm_const
  22. from akg.utils.dynamic_shape import shape_is_dynamic
  # Channel-size threshold used by five2four: when the input dtype differs from
  # dst_type (a cast is appended), the 4-D channel extent c must be strictly
  # below this value for static shapes.
  C_LIMIT_FOR_CAST = 3600


  def get_attrs():
      """Return a newly built dict of the default build attributes for five2four."""
      return dict(pragma_sink_last_axis=False)
  def five2four_tiling_strategy(tensor, c_value=None, expansion=None):
      """Build the custom tiling constraints for the five2four op.

      Args:
          tensor: the NC1HWC0 input tensor the constraints are attached to.
          c_value: optional channel factor. When None, the generic NC1HWC0
              tiling template is used. When given and the shape is static, it
              is raised to at least 16 and used as the FACTOR on axis 4.
          expansion: optional value; when truthy, a SET_EXPANSION constraint
              is appended.

      Returns:
          The list of tiling constraints.
      """
      if c_value is None:
          # No explicit channel factor: fall back to the generic NC1HWC0 template.
          tiling = ct_util.create_template(tensor=tensor,
                                           template=ct_util.TileTemplate.NC1HWC0)
      elif shape_is_dynamic(tensor):
          # A channel factor was given but the shape is dynamic: start empty.
          tiling = []
      else:
          c0_factor = max(c_value, 16)

          def _axis_constraint(pos, value, kind):
              # One constraint on axis `pos` of `tensor`.
              return ct_util.create_constraint_on_tensor(tensor=tensor,
                                                         values=value,
                                                         constraints=kind,
                                                         tensor_pos=pos)

          # N tiled by 1, C1 capped at FULL, C0 tiled by the channel factor.
          tiling = (_axis_constraint(0, 1, ct_util.TileConstraint.FACTOR)
                    + _axis_constraint(1, "FULL", ct_util.TileConstraint.MAX)
                    + _axis_constraint(4, c0_factor, ct_util.TileConstraint.FACTOR))

      if expansion:
          expansion_constraints = ct_util.create_constraint_on_tensor(
              tensor=tensor,
              values=expansion,
              constraints=ct_util.TileConstraint.SET_EXPANSION)
          tiling.append(expansion_constraints[0])

      if shape_is_dynamic(tensor):
          # axis should be full tiled due to cast operator; set a 0.85 memory ratio
          tiling.append(ct_util.modify_common_constraints(
              value=0.85, constraint=ct_util.TileConstraint.SET_MEM_RATIO))
      return tiling
  @vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple), str, str)
  def five2four(data, shape4d, dst_type, format_):
      """
      Convert a 5-D NC1HWC0 tensor "data" into a 4-D tensor laid out as "format_".

      Args:
          data (tvm.tensor.Tensor): 5-D NC1HWC0 input of type float16 or float32.
              For static shapes it must be 5-D with a last (C0) dim of 16.
          shape4d (Union[list, tuple]): the 4 output dims, given in "format_" order
              (N, C, H, W for NCHW; N, H, W, C for NHWC). When "data" has a dynamic
              shape, this argument is ignored and rebuilt from "data".
          dst_type (str): data type of the output tensor. A cast is appended when
              it differs from data.dtype.
          format_ (str): output layout, either "NCHW" or "NHWC".

      Returns:
          tuple: (output, attrs).
              output is the 4-D tvm.tensor.Tensor.
              attrs is the dict from get_attrs(). It also gets a "custom_tiling"
              entry when the conversion is not a plain reshape.

      Raises:
          ValueError: in any of these cases:
              - a static "data" is not 5-D or its last dim is not 16;
              - "format_" is not NCHW/NHWC;
              - shape4d's h/w/c do not match "data" (checked only when every dim
                of "data" is a constant);
              - a static-shape cast is needed and c >= C_LIMIT_FOR_CAST.
      """
      vc_util.ops_dtype_check([data.dtype, dst_type], vc_util.DtypeForDavinci.ALL_FLOAT)
      shape5d = get_shape(data)
      if not shape_is_dynamic(data):
          if len(shape5d) != 5 or shape5d[-1] != 16:
              raise ValueError("five2four_cce only support 5-dim data and last dim should be 16")
      bs, c1, h, w, c0 = shape5d
      if not shape_is_dynamic(data):
          vc_util.davinci_format_check(shape5d, "NC1HWC0", dim=5)
      # Check format
      if format_ not in ['NCHW', 'NHWC']:
          raise ValueError("{} format is not support, five2four only support NCHW and NHWC format input"
                           .format(format_))
      if format_ == "NCHW":
          if shape_is_dynamic(data):
              # Dynamic input: derive the 4-D shape from data instead of using shape4d.
              shape4d = [bs, c1 * c0, h, w]
          _, c, h_4d, w_4d = shape4d
      else:
          if shape_is_dynamic(data):
              shape4d = [bs, h, w, c1 * c0]
          _, h_4d, w_4d, c = shape4d
      vc_util.davinci_format_check(shape4d, format_, dim=4)
      # Check that shape4d and shape5d match (only when every 5-D dim is a constant).
      if False not in [isinstance(s, (int, akg.tvm.expr.IntImm)) for s in shape5d]:
          if h_4d != h or w_4d != w:
              raise ValueError("five2four_cce's shape4d h and w should equal to data shape's h and w")
          # c must satisfy (c1 - 1) * c0 < c <= c1 * c0, i.e. only the last C1 block may be partial.
          if c > c1 * c0 or c <= (c1 - 1) * c0:
              raise ValueError("five2four_cce's shape4d c should in set ((c1 - 1) * c0, c1 * c0]")
      # Check size c when casting happens
      # NOTE(review): the check uses `>=` (c == C_LIMIT_FOR_CAST is rejected) while the
      # message says "should not exceed"; confirm which bound is intended.
      if not shape_is_dynamic(data):
          if data.dtype != dst_type and c >= C_LIMIT_FOR_CAST:
              raise ValueError("When input and output data type is not matched, shape of 'c' axis should not exceed {}, "
                               "while currently set is {}".format(C_LIMIT_FOR_CAST, c))

      # Hybrid-script kernel: scatter each (c_i, c_i0) element of NC1HWC0 into channel
      # c_i * c0 + c_i0 of an NHWC buffer of shape (bs, h, w, c).
      # NOTE(review): the loops cover c1 * c0 channels while the output holds only c.
      # When c % c0 != 0, indices up to c1 * c0 - 1 are written past the last channel.
      # Confirm the backend clips or tolerates these writes.
      @script(capture=locals())
      def nc1hwc0_to_nhwc(inputs, bs, h, w, c, c1, c0):
          output = allocate((bs, h, w, c), inputs.dtype, "local")
          for n_i in range(bs):
              for h_i in range(h):
                  for w_i in range(w):
                      for c_i in range(c1):
                          for c_i0 in range(c0):
                              output[n_i, h_i, w_i, c_i * c0 + c_i0] = inputs[n_i, c_i, h_i, w_i, c_i0]
          return output

      # Hybrid-script kernel: same scatter as above, but into an NCHW buffer of shape
      # (bs, c, h, w). The same c1 * c0 vs. c caveat applies.
      @script(capture=locals())
      def nc1hwc0_to_nchw(inputs, bs, h, w, c, c1, c0):
          output = allocate((bs, c, h, w), inputs.dtype, "local")
          for n_i in range(bs):
              for c_i in range(c1):
                  for h_i in range(h):
                      for w_i in range(w):
                          for c_i0 in range(c0):
                              output[n_i, c_i * c0 + c_i0, h_i, w_i] = inputs[n_i, c_i, h_i, w_i, c_i0]
          return output

      # When h == w == 1 (and, for static shapes, c % 16 == 0), five2four is a pure reshape.
      if shape_is_dynamic(data):
          call_reshape = isinstance(h, int) and isinstance(w, int) and h == 1 and w == 1
      else:
          call_reshape = h == 1 and w == 1 and c % 16 == 0
      # Tiling hints passed to five2four_tiling_strategy; refined below for static shapes.
      c_value = None
      expansion = None
      if format_ == "NHWC":
          if call_reshape:
              output = akg.topi.reshape(data, (bs, h, w, c))
              if shape_is_dynamic(data):
                  # NOTE: `output` is rebound by this same statement. This relies on
                  # tvm.compute calling the lambda before the assignment completes;
                  # keep that in mind if refactoring.
                  output = akg.tvm.compute((bs, h, w, c), lambda *indice: output(*indice), name="reshape")
          elif c < c0:
              # When the c-range check above ran, c < c0 implies c1 == 1. Reshape to a
              # single C0 block, then keep only the first c channels.
              reshape_output = akg.topi.reshape(data, (bs, h, w, c0))
              output = akg.tvm.compute((bs, h, w, c), lambda *i: reshape_output(*i), name='slice_output')
          else:
              output = nc1hwc0_to_nhwc(
                  data,
                  to_tvm_const(bs),
                  to_tvm_const(h),
                  to_tvm_const(w),
                  to_tvm_const(c),
                  to_tvm_const(c1),
                  to_tvm_const(c0))
      else:
          if call_reshape:
              output = akg.topi.reshape(data, (bs, c, h, w))
              if shape_is_dynamic(data):
                  output = akg.tvm.compute((bs, c, h, w), lambda *indice: output(*indice), name="reshape")
          else:
              output = nc1hwc0_to_nchw(
                  data,
                  to_tvm_const(bs),
                  to_tvm_const(h),
                  to_tvm_const(w),
                  to_tvm_const(c),
                  to_tvm_const(c1),
                  to_tvm_const(c0))
      # two special cases for tiling strategy
      if not shape_is_dynamic(data):
          # A partial channel block or a pending cast: tile C0 by the real channel count.
          if c < c0 or output.dtype != dst_type:
              c_value = c
          # A partial last C1 block plus a cast: request an expansion of
          # BLOCK_SIZE / sizeof(data.dtype) elements.
          if c % c0 != 0 and output.dtype != dst_type:
              expansion = int(ct_util.BLOCK_SIZE / get_bytes(data.dtype))
      attrs = get_attrs()
      if not call_reshape:
          attrs["custom_tiling"] = five2four_tiling_strategy(data, c_value, expansion)
      # Every path above keeps data.dtype; cast last so the layout kernels stay in the input type.
      if output.dtype != dst_type:
          output = akg.topi.cast(output, dst_type)
      return output, attrs