You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

conv2d.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. from __future__ import absolute_import
  16. import te.lang.cce
  17. from te.platform.fusion_manager import fusion_manager
  18. from .conv_layer import conv_layer_cce
  19. from .conv_layer_fast import conv_layer_fast_cce
  20. from topi.cce import util
  21. from te import platform as cce
  22. Nonetype = type(None)
  23. # pylint: disable=unused-argument, no-value-for-parameter, too-many-branches
  24. @fusion_manager.register("conv2d")
  25. def conv2d_compute(inputs, weights, bias, outputs, strides, pad_list, dilations,
  26. kernel_name="conv2d"):
  27. """
  28. conv2d compute
  29. Notice
  30. ------
  31. only used by framework combine with IR
  32. Parameters
  33. ----------
  34. inputs: tvm placeholder
  35. input 5hd feature map tensor
  36. weights: tvm placeholder
  37. input frac_z weight tensor
  38. outputs: tvm placeholder
  39. output tensor, dtype must be assigned
  40. bias: tvm placeholder or None
  41. input 1d bias tensor
  42. strides: integers
  43. stride on H/W, format sensitive
  44. pads: tuple/list of 4 integers
  45. [pad_top, pad_bottom, pad_left, pad_right]
  46. dilations: integers
  47. dilation on H/W, format sensitive
  48. kernel_name: string
  49. kernel name, default value is "conv2d"
  50. Returns
  51. -------
  52. tvm compute
  53. """
  54. shape_w = []
  55. for i in weights.op.attrs['ori_shape']:
  56. shape_w.append(i.value)
  57. format_w = weights.op.attrs['ori_format']
  58. if format_w == "NCHW":
  59. weight_h = shape_w[2]
  60. weight_w = shape_w[3]
  61. elif format_w == "NHWC":
  62. weight_h = shape_w[1]
  63. weight_w = shape_w[2]
  64. elif format_w == "HWCN":
  65. weight_h = shape_w[0]
  66. weight_w = shape_w[1]
  67. else:
  68. raise RuntimeError("weights ori_format should be NCHW, NHWC or HWCN")
  69. format_x = inputs.op.attrs['ori_format']
  70. if format_x == "NCHW":
  71. strideh = strides[0]
  72. stridew = strides[0]
  73. dlt_h = dilations[0]
  74. dlt_w = dilations[0]
  75. elif format_x == "NHWC":
  76. strideh = strides[0]
  77. stridew = strides[0]
  78. dlt_h = dilations[0]
  79. dlt_w = dilations[0]
  80. else:
  81. raise RuntimeError("inputs ori_format should be NCHW or NHWC")
  82. if len(pad_list) == 4:
  83. padh = [pad_list[0], pad_list[1]]
  84. padw = [pad_list[2], pad_list[3]]
  85. else:
  86. raise RuntimeError("pads shape should be 4d.")
  87. para_dict = {"pad_h": padh, "pad_w": padw, "stride_h": strideh, "stride_w": stridew,
  88. "filter_h": weight_h, "filter_w": weight_w, "bias_tensor": bias}
  89. if cce.CceProductParams().cce_product == "5.10":
  90. para_dict["mad_dtype"] = "float16"
  91. res = te.lang.cce.conv(inputs, weights, para_dict)
  92. else:
  93. res = te.lang.cce.conv(inputs, weights, para_dict)
  94. return res
  95. @util.check_input_type(dict, dict, (dict, Nonetype), dict, (tuple, list), (tuple, list), (tuple, list),
  96. str)
  97. def conv2d(inputs, weights, bias, outputs, strides, pad_list, dilations,
  98. kernel_name="conv2d"):
  99. """
  100. algorithm: conv2d
  101. Notice
  102. ------
  103. only used by framework combine with IR
  104. Parameters
  105. ----------
  106. inputs: dict with keys(shape and dtype)
  107. input 4d feature map tensor
  108. weights: dict with keys(shape and dtype)
  109. input 4d weight tensor
  110. outputs: dict with keys(shape and dtype)
  111. output tensor, dtype must be assigned
  112. bias: dict with keys(shape and dtype) or None
  113. input bias tensor
  114. strides: integers
  115. stride on H/W, format sensitive
  116. pads: integers
  117. [pad_top, pad_bottom, pad_left, pad_right]
  118. dilations: tuple/list of 4 integers
  119. dilation on H/W, format sensitive
  120. kernel_name: str
  121. kernel name, default value is "conv2d"
  122. Returns
  123. -------
  124. None
  125. """
  126. shape_x = inputs.get("ori_shape")
  127. in_dtype = inputs.get("dtype")
  128. shape_w = weights.get("ori_shape")
  129. w_dtype = weights.get("dtype")
  130. res_dtype = outputs.get("dtype")
  131. if len(pad_list) == 4:
  132. padh = [pad_list[0], pad_list[1]]
  133. padw = [pad_list[2], pad_list[3]]
  134. else:
  135. raise RuntimeError("pads shape should be 4d.")
  136. if (not isinstance(shape_x, (tuple, list))) or len(shape_x) != 4:
  137. raise RuntimeError("inputs should be 4d list.")
  138. if (not isinstance(shape_w, (tuple, list))) or len(shape_w) != 4:
  139. raise RuntimeError("weights should be 4d list.")
  140. format_x = inputs.get("ori_format")
  141. if format_x == "NCHW":
  142. shape_fm = shape_x
  143. strideh = strides[0]
  144. stridew = strides[0]
  145. dlt_h = dilations[0]
  146. dlt_w = dilations[0]
  147. elif format_x == "NHWC":
  148. shape_fm = [shape_x[0], shape_x[3], shape_x[1], shape_x[2]]
  149. strideh = strides[0]
  150. stridew = strides[0]
  151. dlt_h = dilations[0]
  152. dlt_w = dilations[0]
  153. else:
  154. raise RuntimeError("inputs ori_format should be NCHW or NHWC.")
  155. format_w = weights.get("ori_format")
  156. if format_w == "NCHW":
  157. shape_filter = shape_w
  158. elif format_w == "NHWC":
  159. shape_filter = [shape_w[0], shape_w[3], shape_w[1], shape_w[2]]
  160. elif format_w == "HWCN":
  161. shape_filter = [shape_w[3], shape_w[2], shape_w[0], shape_w[1]]
  162. else:
  163. raise RuntimeError("weights ori_format should be NCHW, NHWC or HWCN.")
  164. if bias is None:
  165. use_bias = False
  166. else:
  167. use_bias = True
  168. if cce.CceProductParams().cce_product == "5.10":
  169. conv_layer_fast_cce(shape_fm, shape_filter, in_dtype, w_dtype, res_dtype,
  170. padh, padw, strideh, stridew, bias=use_bias,
  171. kernel_name=kernel_name, need_build=True, need_print=False)
  172. else:
  173. conv_layer_cce(shape_fm, shape_filter, in_dtype, w_dtype, res_dtype,
  174. padh, padw, strideh, stridew,
  175. quantize_config=[0, 0, 0], scale_sqrt=[0, 0, 0],
  176. bias=use_bias, kernel_name=kernel_name,
  177. need_build=True, need_print=False)