You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

conv2d.py 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """generate json desc for Conv2D"""
  16. from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad
  17. from mindspore._extends.graph_kernel.model.model import DataFormat as DF
  18. from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
  19. from ._utils import Expander, ExpanderInfoValidator as VLD
# Alignment limits for the "optimize to MatMul" path (see _optimize_to_matmul):
# the conv is flattened to a (m, k) x (n, k)^T MatMul with
# m = N*H*W of input 0, n = out-channels of input 1, k = C channels.
M_ALIGN = 32        # m must be a multiple of 32
N_ALIGN = 32        # n must be a multiple of 32
K_ALIGN = 16        # k must be a multiple of 16
K_LIMIT = 800       # upper bound on k when transformed to MatMul
MNK_LIMIT = 3 * (10 ** 10)  # upper bound on m*n*k when transformed to MatMul
# Channel/output alignment requirements enforced in _check for the expansion:
N0_CHANNEL_ALIGN = 32   # N (batch) of first input: required/padded multiple
N1_CHANNEL_ALIGN = 32   # O (out-channel) of second input: required/padded multiple
C_CHANNEL_ALIGN = 16    # C channel of both inputs: required/padded multiple
OUT_NHW_ALIGN = 128     # output N*H*W must be a multiple of 128 (conv path only)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
class Conv2D(Expander):
    """
    Conv2D expander.

    Currently, only Conv2D that meets several conditions can be expanded; other
    cases will be skipped (a GraphKernelUnsupportedException is raised from
    ``_check`` and the op is left unexpanded).

    Conditions to expand (enforced in ``_check``):
        inputs are NHWC format and float16.
        attrs ``groups`` and ``group`` are both 1.
        attr ``dilation`` is all 1.
        N channel of the first input is a multiple of 32.
        O channel of the second input is a multiple of 32.
        C channel of both inputs is equal and a multiple of 16.
        for the conv path: output N*H*W is a multiple of 128 and stride is [1, 1, 2, 2].

    When the weight is 1x1 with unit stride/dilation and m/n/k meet the
    alignment limits, the conv is rewritten as a MatMul instead.
    """

    def __init__(self, expand_info):
        super().__init__(expand_info)
        # NOTE: self.inputs / self.outputs entries are dict-like with
        # 'shape', 'data_type' and 'format' keys (provided by the Expander base).
        self.dst_type = self.outputs[0]['data_type']
        self.dst_format = self.outputs[0]['format']
        self.has_pad = False
        self.can_optimize_to_matmul = False
        # Padded shapes; recomputed in _check, consumed by _expand.
        self.shape_0_pad = self.inputs[0]['shape']
        self.shape_1_pad = self.inputs[1]['shape']
        # MatMul dimensions (m, n, k); filled in by _check.
        self.m = 0
        self.n = 0
        self.k = 0

    def _optimize_to_matmul(self):
        """Return True if this conv can be rewritten as a MatMul.

        Requires a 1x1 kernel, unit stride and dilation, and m/n/k aligned
        to M_ALIGN/N_ALIGN/K_ALIGN. Assumes self.m/n/k were already set
        by _check before this is called.
        """
        stride = self.attrs['stride']
        dilation = self.attrs['dilation']
        # inputs[1] is the weight; NHWC layout, so h/w are the kernel size.
        _, h, w, _ = self.inputs[1]['shape']
        if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
                self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0:
            return True
        return False

    def _check(self):
        """Validate the op against the expansion conditions.

        Computes the padded input shapes (self.shape_0_pad / self.shape_1_pad),
        the MatMul dimensions (self.m/n/k) and the flags self.has_pad /
        self.can_optimize_to_matmul.

        Raises:
            GKException: if any expansion condition is not met.
        """
        type_0 = self.inputs[0]['data_type']
        type_1 = self.inputs[1]['data_type']
        if type_0 != "float16" or type_1 != "float16":
            raise GKException(
                "inputs type should be float16, but got {} and {}".format(type_0, type_1))
        formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
        check_format_any(formats, DF.NHWC)

        groups = self.attrs['groups']
        group = self.attrs['group']
        if groups != 1 or group != 1:
            raise GKException(
                "groups and group should be both 1, but got {} and {}.".format(groups, group))

        dilation = self.attrs['dilation']
        check_nd(dilation, 4)
        if dilation != [1, 1, 1, 1]:
            raise GKException(
                "dilation should be all 1, but got {}".format(dilation))

        pad_list = self.attrs['pad_list']
        pad_mode = self.attrs['pad_mode']
        check_nd(pad_list, 4)
        # conv_had_pad decides whether explicit padding is needed
        # from pad_list and pad_mode.
        self.has_pad = conv_had_pad(pad_list, pad_mode)

        shape_0 = self.inputs[0]['shape']
        shape_1 = self.inputs[1]['shape']
        stride = self.attrs['stride']
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)
        check_nd(stride, 4)
        n0, h0, w0, c0 = shape_0
        n1, h1, w1, c1 = shape_1
        if (n0 % N0_CHANNEL_ALIGN) != 0:
            raise GKException("N({}) channel of first input should be multiples of {}".format(n0, N0_CHANNEL_ALIGN))
        if (n1 % N1_CHANNEL_ALIGN) != 0:
            raise GKException("O({}) channel of second input should be multiples of {}".format(n1, N1_CHANNEL_ALIGN))
        if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0:
            raise GKException("C channel of inputs({}, {}) should be same and also be multiples of {}".format(
                c0, c1, C_CHANNEL_ALIGN))
        # n0 pad: round up to a multiple of N0_CHANNEL_ALIGN.
        # (The alignment was already enforced above, so this is a no-op today;
        # it keeps the padded-shape computation self-contained.)
        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) //
              N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
        # h0, w0 pad: grow the spatial dims by the explicit padding amounts.
        # pad_list layout here is [top, bottom, left, right].
        if self.has_pad:
            h0 = h0 + pad_list[0] + pad_list[1]
            w0 = w0 + pad_list[2] + pad_list[3]
        # c0, c1 pad: round C up to a multiple of C_CHANNEL_ALIGN, keep equal.
        c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN
        c1 = c0
        # n1 pad: round the out-channel count up to a multiple of N1_CHANNEL_ALIGN.
        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) //
              N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN

        # check if can optimize to matmul:
        # m = flattened input rows, n = out-channels, k = contraction dim.
        self.m, self.n, self.k = n0 * h0 * w0, n1, c1
        self.can_optimize_to_matmul = self._optimize_to_matmul()

        # requirements for each expansion path
        if self.can_optimize_to_matmul:
            if self.k > K_LIMIT:
                raise GKException(
                    "If transformed to MatMul, C0({}) should not be larger than {}".format(self.k, K_LIMIT))
            if self.m * self.n * self.k >= MNK_LIMIT:
                raise GKException("If transformed to MatMul, The total size({}) should not be larger than {}".format(
                    self.m * self.n * self.k, MNK_LIMIT))
        else:
            # Valid-style output size with the (already applied) padding folded
            # into h0/w0; stride[-2]/stride[-1] are the H/W strides.
            out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
            if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0:
                raise GKException("N({}) * H({}) * W({}) of output should be multiplies of {}".format(
                    n0, out_h, out_w, OUT_NHW_ALIGN))
            # NOTE(review): the full stride is compared against [1, 1, 2, 2]
            # while the message only mentions H/W — presumably the conv path
            # only supports exactly this stride; confirm against the kernel.
            if stride != [1, 1, 2, 2]:
                raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3]))
        self.shape_0_pad = [n0, h0, w0, c0]
        self.shape_1_pad = [n1, h1, w1, c1]

    def _expand(self, graph_builder):
        """Build the expanded graph: pad inputs, run MatMul or Conv2D, unpad.

        Returns the output op emitted by graph_builder.
        """
        input_0 = self.inputs[0]
        input_1 = self.inputs[1]
        n0, _, _, c0 = input_0.shape
        n1, _, _, c1 = input_1.shape
        n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
        n1_p, _, _, c1_p = self.shape_1_pad
        pad_value = 0
        # input0 pad: explicit spatial padding (head/tail of H and W) plus
        # tail padding of N and C up to the aligned sizes.
        input_0_pad_before = [0, 0, 0, 0]
        input_0_pad_after = [0, 0, 0, 0]
        if self.has_pad:
            pad_list = self.attrs['pad_list']
            input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
            input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
        input_0_pad_after[0] = n0_p - n0
        input_0_pad_after[3] = c0_p - c0
        if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
            input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
                                                                     'tail': input_0_pad_after,
                                                                     'pad_val': pad_value})
        # input1 pad: tail padding of O and C up to the aligned sizes.
        input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
        if input_1_pad_after != [0, 0, 0, 0]:
            input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
                                                                     'tail': input_1_pad_after,
                                                                     'pad_val': pad_value})
        if self.can_optimize_to_matmul:
            # Flatten to (m, k) x (n, k)^T and reshape the product back to NHWC.
            a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
            b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
            c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
                                                            'transpose_b': True,
                                                            'dst_type': self.dst_type})
            result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
                                                               'format': self.dst_format})
        else:
            # Padding was applied explicitly above, so the inner conv runs
            # with a zero pad_list.
            attrs = self.attrs
            attrs['pad_list'] = [0, 0, 0, 0]
            attrs['dst_type'] = self.dst_type
            result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
        # unpad: strip the N/C alignment padding from the result.
        unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
        if unpad_after != [0, 0, 0, 0]:
            result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})
        return result