You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

maxpool_ad.py 23 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """operator dsl function: maxpool_ad"""
  17. import akg.tvm
  18. import akg.topi
  19. import akg
  20. from akg.ops.nn import maxpool
  21. from akg.utils.format_transform import get_shape
  22. from akg.utils.dsl_create import cal_pad_shapes_by_strategy
  23. from akg.utils import kernel_exec as utils
  24. from akg.utils import validation_check as vc_util
  25. @vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, (list, tuple), (list, tuple),
  26. (str, list, tuple))
  27. def maxpool_ad_no_custom_diff_poly_all_max(head, data, kernel, stride, pad):
  28. """automatic differentiate of maxpool with polyhedral"""
  29. attrs = {"enable_post_poly_loop_partition": False, "enable_pre_poly_loop_partition": False}
  30. maxpool_fwd = maxpool.old_maxpool(data, kernel, stride, pad)
  31. [dl_ddata] = akg.differentiate(maxpool_fwd, [data], head, None, None)
  32. return dl_ddata, attrs
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor,
                          (list, tuple), (list, tuple), (str, list, tuple))
def maxpool_ad_no_custom_diff_manual_schedule_all_max(head, data, kernel, stride, pad):
    """automatic differentiate of maxpool with manual schedule.

    Args:
        head: adjoint tensor (gradient w.r.t. the maxpool output).
        data: forward input tensor of the maxpool.
        kernel (list/tuple): pooling window sizes.
        stride (list/tuple): pooling strides.
        pad: padding strategy (str) or explicit pad values (list/tuple).

    Returns:
        Tuple (dl_ddata, comp_func, attrs): the gradient tensor w.r.t.
        `data`, a callback that applies the manual schedule to a schedule
        object containing it, and build attributes.
    """
    attrs = {"enable_post_poly_loop_partition": False, "enable_pre_poly_loop_partition": False}
    maxpool_fwd = maxpool.old_maxpool(data, kernel, stride, pad)
    # Autodiff of the forward output w.r.t. `data`, seeded with `head`.
    [dl_ddata] = akg.differentiate(maxpool_fwd, [data], head, None, None)

    # schedule for differentiation operation
    # NOTE(review): this schedule object is never used below — comp_func
    # operates on the schedule passed in by the caller. Confirm whether this
    # call can be removed.
    s = akg.tvm.create_schedule([dl_ddata.op])

    # Walk the differentiated graph by positional input_tensors indices to
    # recover the intermediate stages (rebinding `data`/`head` to the graph's
    # own tensors for use inside comp_func).
    new_tensor_red = dl_ddata
    new_tensor = new_tensor_red.op.input_tensors[0]
    data = new_tensor.op.input_tensors[0]
    broadcast = new_tensor.op.input_tensors[1]
    head = new_tensor.op.input_tensors[2]
    forward = broadcast.op.input_tensors[0]

    def comp_func(s):
        # Manual schedule: stage inputs into UB (on-chip unified buffer),
        # put reduction axes first, and nest intermediates under the
        # result's batch axis.
        data_ub = s.cache_read(data, "local.UB", [forward, new_tensor])
        head_ub = s.cache_read(head, "local.UB", [new_tensor])
        result_ub = s.cache_write(new_tensor_red, "local.UB")

        s[broadcast].set_scope("local.UB")
        s[forward].set_scope("local.UB")

        b, c1, h, w, c0 = forward.op.axis
        oh, ow = forward.op.reduce_axis
        # Reduction axes outermost for the forward stage.
        s[forward].reorder(oh, ow, b, c1, h, w, c0)

        s[new_tensor].set_scope("local.UB")

        b, c1, h, w, c0 = result_ub.op.axis
        # Reduction axes outermost for the result as well.
        s[result_ub].reorder(*result_ub.op.reduce_axis, b, c1, h, w, c0)
        s[broadcast].compute_at(s[result_ub], b)
        s[new_tensor].compute_at(s[result_ub], b)

    return dl_ddata, comp_func, attrs
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor,
                          akg.tvm.tensor.Tensor, (list, tuple), (list, tuple), (str, list, tuple))
def maxpool_ad(head, data, forward, mask, kernel, stride, pad):
    """automatic differentiate of maxpool with manual schedule.

    Args:
        head: adjoint tensor (gradient w.r.t. the maxpool output).
        data: forward input tensor, 5D layout (batch, C1, H, W, C0).
        forward: forward maxpool output tensor (the node whose autodiff rule
            is overridden below).
        mask: tensor indexed as mask(b, c1, kh, kw, h_out, w_out, c0);
            presumably a 0/1 max-position mask from the forward op — verify
            against the caller.
        kernel (list/tuple): pooling window (kernel_h, kernel_w).
        stride (list/tuple): pooling strides (stride_h, stride_w).
        pad: padding strategy (str) or explicit pad values (list/tuple).

    Returns:
        Tuple (dl_ddata, comp_func): the gradient tensor w.r.t. `data` and a
        callback that applies the manual schedule to a schedule containing it.
    """
    shape = get_shape(data)
    dtype = data.dtype
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    # ph_h / pw_h are the upper pads on H and W; lower pads are unused here.
    [ph_h, _, pw_h, _], [out_size_h, out_size_w] = \
        cal_pad_shapes_by_strategy(shape, kernel, stride, pad)
    batch_size, input_c1, input_h, input_w, input_c0 = shape

    # tile size one is proved to be the most efficient one
    tile_scale_h = 1
    tile_scale_w = 1

    tile_h = stride_h * tile_scale_h

    # Extra rows needed above the tile so every pooling window touching the
    # tile is covered (piecewise by kernel/stride relationship).
    if kernel_h == stride_h:  # non-overlapping case
        tile_h_pad_u = ph_h % stride_h
    elif kernel_h % stride_h == 0:
        tile_h_pad_u = kernel_h - stride_h - ph_h
    else:
        tile_h_pad_u = kernel_h - kernel_h % stride_h - ph_h
    # Extra rows needed below the tile.
    tile_h_pad_l = kernel_h - stride_h + ph_h
    tile_input_h = tile_h + tile_h_pad_u + tile_h_pad_l
    # Number of H tiles = ceil(input_h / tile_h).
    tile_h_out = (input_h - 1) // tile_h + 1
    # pad_output_h: offset (in output rows) of a tile's first window,
    # ceil(ph_h / stride_h) minus ceil(tile_h_pad_u / stride_h).
    if ph_h % stride_h == 0:
        pad_output_h = ph_h // stride_h
    else:
        pad_output_h = ph_h // stride_h + 1
    if tile_h_pad_u % stride_h == 0:
        pad_output_h -= tile_h_pad_u // stride_h
    else:
        pad_output_h -= tile_h_pad_u // stride_h + 1
    # Output rows produced per padded tile.
    tile_output_h = (tile_input_h - kernel_h) // stride_h + 1

    # Same tiling computation for the W dimension.
    tile_w = stride_w * tile_scale_w
    if kernel_w == stride_w:  # non-overlapping case
        tile_w_pad_u = pw_h % stride_w
    elif kernel_w % stride_w == 0:
        tile_w_pad_u = kernel_w - stride_w - pw_h
    else:
        tile_w_pad_u = kernel_w - kernel_w % stride_w - pw_h
    tile_w_pad_l = kernel_w - stride_w + pw_h
    tile_input_w = tile_w + tile_w_pad_u + tile_w_pad_l
    tile_w_out = (input_w - 1) // tile_w + 1
    if pw_h % stride_w == 0:
        pad_output_w = pw_h // stride_w
    else:
        pad_output_w = pw_h // stride_w + 1
    if tile_w_pad_u % stride_w == 0:
        pad_output_w -= tile_w_pad_u // stride_w
    else:
        pad_output_w -= tile_w_pad_u // stride_w + 1
    tile_output_w = (tile_input_w - kernel_w) // stride_w + 1

    def custom_maxpool_fdiff(out, inputs, head_, ad_attrs, new_pld_array):
        # Custom backward rule: scatter head through the max-position mask,
        # tile by tile, then reduce overlapping windows back onto the input.
        # Head values gathered into the tiled layout; out-of-range output
        # positions read as zero.
        head_reshaped = akg.tvm.compute((batch_size, input_c1, tile_h_out, tile_w_out,
                                         tile_output_h, tile_output_w, input_c0),
                                        lambda b, c1, h_out, w_out, oh, ow, c0:
                                        akg.tvm.expr.Select(
                                            akg.tvm.any(h_out * tile_scale_h + pad_output_h + oh < 0,
                                                        h_out * tile_scale_h + pad_output_h + oh > out_size_h - 1,
                                                        w_out * tile_scale_w + pad_output_w + ow < 0,
                                                        w_out * tile_scale_w + pad_output_w + ow > out_size_w - 1),
                                            akg.tvm.const(0.0, dtype=dtype),
                                            head_(b, c1,
                                                  h_out * tile_scale_h + pad_output_h + oh,
                                                  w_out * tile_scale_w + pad_output_w + ow,
                                                  c0)),
                                        name="head_reshaped")

        # Mask gathered into the same tiled layout, expanded over the kernel
        # window axes (kh, kw).
        mask_reshaped = akg.tvm.compute((batch_size, input_c1, tile_h_out, tile_w_out,
                                         tile_output_h, tile_output_w, kernel_h, kernel_w, input_c0),
                                        lambda b, c1, h_out, w_out, oh, ow, kh, kw, c0:
                                        akg.tvm.expr.Select(
                                            akg.tvm.any(h_out * tile_scale_h + pad_output_h + oh < 0,
                                                        h_out * tile_scale_h + pad_output_h + oh > out_size_h - 1,
                                                        w_out * tile_scale_w + pad_output_w + ow < 0,
                                                        w_out * tile_scale_w + pad_output_w + ow > out_size_w - 1),
                                            akg.tvm.const(0.0, dtype=dtype),
                                            mask(b, c1, kh, kw,
                                                 h_out * tile_scale_h + pad_output_h + oh,
                                                 w_out * tile_scale_w + pad_output_w + ow,
                                                 c0)),
                                        name="mask_reshaped")

        # Per-window gradient contribution: mask * head.
        d_data = akg.tvm.compute((batch_size, input_c1, tile_h_out, tile_w_out,
                                  tile_output_h, tile_output_w, kernel_h, kernel_w, input_c0),
                                 lambda b, c1, h_out, w_out, oh, ow, kh, kw, c0:
                                 mask_reshaped(b, c1, h_out, w_out, oh, ow, kh, kw, c0)
                                 * head_reshaped(b, c1, h_out, w_out, oh, ow, c0),
                                 name="d_data")

        # Re-index window-relative (kh, kw) contributions back to tile-local
        # input coordinates (h, w); positions outside a window read zero.
        data_reorg = akg.tvm.compute((batch_size, input_c1, tile_h_out, tile_w_out,
                                      tile_output_h, tile_output_w, tile_h, tile_w, input_c0),
                                     lambda b, c1, h_out, w_out, oh, ow, h, w, c0:
                                     akg.tvm.expr.Select(
                                         akg.tvm.any(h + tile_h_pad_u < oh * stride_h,
                                                     h + tile_h_pad_u > oh * stride_h + kernel_h - 1,
                                                     w + tile_w_pad_u < ow * stride_w,
                                                     w + tile_w_pad_u > ow * stride_w + kernel_w - 1),
                                         akg.tvm.const(0, dtype=dtype),
                                         d_data(b, c1, h_out, w_out, oh, ow,
                                                h + tile_h_pad_u - oh * stride_h,
                                                w + tile_w_pad_u - ow * stride_w,
                                                c0)),
                                     name="data_reorg")

        # Sum over the per-tile output-window axes (oh, ow).
        result_tile = akg.topi.sum(data_reorg, [4, 5])
        # Un-tile back to the original input shape.
        result = akg.tvm.compute(shape,
                                 lambda b, c1, h, w, c0:
                                 result_tile(b, c1, h // tile_h, w // tile_w, h % tile_h, w % tile_w, c0),
                                 name="result")
        return [result]

    # override differentiation computation with custom function
    [dl_ddata] = akg.differentiate(forward, [data], head, None, None,
                                   override={forward: ([data], custom_maxpool_fdiff)})

    # schedule for differentiation operation
    # NOTE(review): this schedule object is never used below — comp_func
    # operates on the schedule passed in by the caller. Confirm whether this
    # call can be removed.
    s = akg.tvm.create_schedule([dl_ddata.op])

    # get computations (walk the graph by positional input_tensors indices,
    # mirroring the structure built in custom_maxpool_fdiff)
    result = dl_ddata
    result_tile = result.op.input_tensors[0]
    data_reorg = result_tile.op.input_tensors[0]
    d_data = data_reorg.op.input_tensors[0]
    mask_reshaped = d_data.op.input_tensors[0]
    head_reshaped = d_data.op.input_tensors[1]

    def comp_func(s):
        # Manual schedule: stage inputs into UB, inline the pure data
        # movements, and tile the final result by (tile_h, tile_w).
        data_ub = s.cache_read(mask, "local.UB", [mask_reshaped])
        head_ub = s.cache_read(head, "local.UB", [head_reshaped])
        result_ub = s.cache_write(result, "local.UB")

        s[d_data].set_scope("local.UB")
        s[data_reorg].set_scope("local.UB")
        s[mask_reshaped].set_scope("local.UB")
        s[head_reshaped].set_scope("local.UB")
        s[result_tile].set_scope("local.UB")

        s[result_ub].compute_inline()

        # inline inputs
        s[head_ub].compute_inline()
        s[data_ub].compute_inline()

        # result_tile dependencies
        s[data_reorg].compute_inline()
        b, c1, h_out, w_out, h, w, c0 = result_tile.op.axis
        oh, ow = result_tile.op.reduce_axis
        s[result_tile].reorder(b, c1, h_out, w_out, h, w, oh, ow, c0)
        s[d_data].compute_at(s[result_tile], w_out)
        s[mask_reshaped].compute_at(s[result_tile], w_out)
        s[head_reshaped].compute_at(s[result_tile], w_out)

        # tile result
        b, c1, h, w, c0 = result.op.axis
        h_out, h_in = s[result].split(h, tile_h)
        w_out, w_in = s[result].split(w, tile_w)
        s[result].reorder(b, c1, h_out, w_out, h_in, w_in, c0)
        s[result_tile].compute_at(s[result], w_out)

    return dl_ddata, comp_func
@vc_util.check_input_type((list, tuple), (list, tuple), (list, tuple), (str, list, tuple),
                          str, (bool, type(None)), (dict, type(None)))
def maxpool_ad_manual_schedule_all_max(shape, kernel, stride, pad, dtype, polyhedral=True, attrs=None):
    """automatic differentiate of maxpool with manual schedule for all maximum.

    Builds placeholders for input/forward/head, overrides the autodiff rule
    of maxpool with a custom backward that routes head to every position
    equal to the window maximum, applies a manual schedule and compiles a
    CCE kernel.

    Args:
        shape (list/tuple): input shape (batch, C1, H, W, C0).
        kernel (list/tuple): pooling window (kernel_h, kernel_w).
        stride (list/tuple): pooling strides (stride_h, stride_w).
        pad: pad values; only the first two entries (pad_h, pad_w) are used,
            applied symmetrically on both sides.
        dtype (str): tensor element type.
        polyhedral (bool): forwarded to akg.build.
        attrs (dict): extra build attributes forwarded to akg.build.

    Returns:
        The built module.
    """
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    pad_h, pad_w, _, _ = pad
    batch_size, input_c1, input_h, input_w, input_c0 = shape
    pad_shape = (batch_size, input_c1, input_h + 2 * pad_h, input_w + 2 * pad_w, input_c0)
    out_size_h = (input_h + 2 * pad_h - kernel_h) // stride_h + 1
    out_size_w = (input_w + 2 * pad_w - kernel_w) // stride_w + 1
    out_shape = (batch_size, input_c1, out_size_h, out_size_w, input_c0)

    def custom_maxpool_fdiff(out, inputs, head_, ad_attrs, new_pld_array):
        # Custom backward rule: every input position equal to its window's
        # maximum receives that window's head value; overlapping windows are
        # summed.
        in_data = inputs[0]

        # Layout with the kernel-window offsets (wh, ww) leading.
        data_separated_by_windows = (kernel_h, kernel_w, batch_size, input_c1, out_size_h, out_size_w, input_c0)

        # Zero-pad the input by (pad_h, pad_w) on each side.
        pad_data = akg.tvm.compute(pad_shape,
                                   lambda b, c1, h, w, c0:
                                   akg.tvm.expr.Select(
                                       akg.tvm.all(h >= pad_h,
                                                   h < input_h + pad_h,
                                                   w >= pad_w,
                                                   w < input_w + pad_w),
                                       in_data(b, c1, h - pad_h, w - pad_w, c0),
                                       akg.tvm.const(0.0, dtype=dtype)),
                                   name="pad_data")

        # Gather each pooling window's elements into the windowed layout.
        data_reshaped = akg.tvm.compute(data_separated_by_windows,
                                        lambda wh, ww, b, c1, oh, ow, c0:
                                        pad_data(b, c1, oh * stride_h + wh, ow * stride_w + ww, c0),
                                        name="data_reshaped")

        # Broadcast the forward maximum over the window offsets.
        max_broadcast = akg.tvm.compute(data_separated_by_windows,
                                        lambda wh, ww, b, c1, oh, ow, c0:
                                        out(b, c1, oh, ow, c0),
                                        name="max_broadcast")

        # head where the element equals the window max, else zero.
        equal = akg.tvm.compute(data_separated_by_windows,
                                lambda wh, ww, b, c1, oh, ow, c0:
                                akg.tvm.expr.Select(
                                    max_broadcast(wh, ww, b, c1, oh, ow, c0) ==
                                    data_reshaped(wh, ww, b, c1, oh, ow, c0),
                                    head_(b, c1, oh, ow, c0),
                                    akg.tvm.const(0.0, dtype=dtype)),
                                name="equal")

        # Re-index window contributions back to padded-input coordinates;
        # positions outside a given window read zero.
        data_reorg = akg.tvm.compute((out_size_h, out_size_w, batch_size, input_c1, input_h + 2 * pad_h,
                                      input_w + 2 * pad_w, input_c0),
                                     lambda oh, ow, b, c1, h, w, c0:
                                     akg.tvm.expr.Select(
                                         akg.tvm.any(h < oh * stride_h,
                                                     h > oh * stride_h + kernel_h - 1,
                                                     w < ow * stride_w,
                                                     w > ow * stride_w + kernel_w - 1),
                                         akg.tvm.const(0, dtype=dtype),
                                         equal(h - oh * stride_h, w - ow * stride_w, b, c1, oh, ow, c0)),
                                     name="data_reorg")

        # Sum over all windows (oh, ow), then crop the padding off.
        result_pad = akg.topi.sum(data_reorg, [0, 1])
        result = akg.tvm.compute(shape,
                                 lambda b, c1, h, w, c0:
                                 result_pad(b, c1, h + pad_h, w + pad_w, c0),
                                 name="result")
        return [result]

    # tensor for the input data
    data = akg.tvm.placeholder(shape, dtype, name="input_data")

    # maxpool output
    forward = akg.tvm.placeholder(out_shape, name="forward", dtype=dtype)

    # adjoint tensor for the differentiation
    head = akg.tvm.placeholder(out_shape, name="head", dtype=dtype)

    # override differentiation computation with custom function
    [dl_ddata] = akg.differentiate(forward, [data], head, None, None,
                                   override={forward: ([data], custom_maxpool_fdiff)})

    # schedule for differentiation operation
    s = akg.tvm.create_schedule([dl_ddata.op])

    # get computations (walk the graph by positional input_tensors indices,
    # mirroring the structure built in custom_maxpool_fdiff)
    result = dl_ddata
    result_pad = result.op.input_tensors[0]
    data_reorg = result_pad.op.input_tensors[0]
    equal = data_reorg.op.input_tensors[0]
    max_broadcast = equal.op.input_tensors[0]
    data_reshaped = equal.op.input_tensors[1]
    pad_data = data_reshaped.op.input_tensors[0]

    # Stage inputs into UB and keep all intermediates on-chip.
    data_ub = s.cache_read(data, "local.UB", [pad_data])
    head_ub = s.cache_read(head, "local.UB", [equal])
    forward_ub = s.cache_read(forward, "local.UB", [max_broadcast])
    result_ub = s.cache_write(result, "local.UB")

    s[max_broadcast].set_scope("local.UB")
    s[data_reshaped].set_scope("local.UB")
    s[pad_data].set_scope("local.UB")
    s[equal].set_scope("local.UB")
    s[data_reorg].set_scope("local.UB")
    s[result_pad].set_scope("local.UB")

    s[data_ub].compute_inline()
    s[result_ub].compute_inline()
    s[pad_data].compute_inline()

    # equal dependencies
    s[forward_ub].compute_at(s[equal], equal.op.axis[0])
    s[max_broadcast].compute_at(s[equal], equal.op.axis[0])
    s[data_reshaped].compute_at(s[equal], equal.op.axis[0])
    s[head_ub].compute_at(s[equal], equal.op.axis[0])
    s[equal].compute_at(s[result_pad], result_pad.op.axis[0])

    # result dependencies
    s[data_reorg].compute_inline()
    b, c1, h, w, c0 = result_pad.op.axis
    oh, ow = result_pad.op.reduce_axis
    # Reduction axes outermost for the summation stage.
    s[result_pad].reorder(oh, ow, b, c1, h, w, c0)

    # s[result_pad].compute_at(s[result], result.op.axis[1])
    b, c1, h, w, c0 = result.op.axis
    h_out, _ = s[result].split(h, stride_h)
    s[result_pad].compute_at(s[result], h_out)

    with akg.build_config(add_lower_pass=utils.debug_mode(0), dump_pass_ir=True):
        mod = akg.build(s, [head, data, forward, dl_ddata], "cce", name="maxpool_ad_manual_schedule_all_max",
                        attrs=attrs, polyhedral=polyhedral)
        source_code = mod.imported_modules[0].get_source()
        kernel_name = "maxpool_ad_manual_schedule_all_max"
        utils.create_code(kernel_name, './', source_code)
    return mod
  322. def maxpool_ad_manual_schedule_no_overlap_all_max(shape, kernel, stride, pad, dtype, attrs=None, polyhedral=False):
  323. """automatic differentiate of maxpool with manual schedule for no overlap case."""
  324. kernel_h, kernel_w = kernel
  325. stride_h, stride_w = stride
  326. pad_h, pad_w, _, _ = pad
  327. batch_size, input_c1, input_h, input_w, input_c0 = shape
  328. pad_shape = (batch_size, input_c1, input_h + 2 * pad_h, input_w + 2 * pad_w, input_c0)
  329. def custom_maxpool_fdiff(out, inputs, head_, ad_attrs, new_pld_array):
  330. in_data = inputs[0]
  331. if stride_w != kernel_w:
  332. raise RuntimeError("Only supports kernels with same dimensions as stride size!")
  333. if stride_h != kernel_h:
  334. raise RuntimeError("Only supports kernels with same dimensions as stride size!")
  335. out_broadcast = akg.tvm.compute(pad_shape,
  336. lambda b, c1, h, w, c0:
  337. out(b, c1, akg.tvm.floordiv(h, stride_h), akg.tvm.floordiv(w, stride_w), c0),
  338. name="out_broadcast")
  339. # copy output to the shape of the padded input, copying the same value for the entire kernel size
  340. out_broadcast = akg.tvm.compute(pad_shape,
  341. lambda b, c1, h, w, c0:
  342. out(b, c1, akg.tvm.floordiv(h, stride_h), akg.tvm.floordiv(w, stride_w), c0),
  343. name="out_broadcast")
  344. # copy head to the shape of the padded input, copying the same value for the entire kernel size
  345. head_broadcast = akg.tvm.compute(pad_shape,
  346. lambda b, c1, h, w, c0:
  347. head_(b, c1, akg.tvm.floordiv(h, stride_h), akg.tvm.floordiv(w, stride_w), c0),
  348. name="head_broadcast")
  349. # check if value was a maximum and assign head of that position if it was
  350. # this is done for all the maximum values within one kernel
  351. result = akg.tvm.compute(in_data.shape,
  352. lambda b, c1, h, w, c0:
  353. akg.tvm.expr.Select(
  354. in_data(b, c1, h, w, c0) == out_broadcast(b, c1, h + pad_h, w + pad_w, c0),
  355. head_broadcast(b, c1, h + pad_h, w + pad_w, c0),
  356. akg.tvm.const(0, dtype=in_data.dtype)),
  357. name="result")
  358. return [result]
  359. out_size_h = (input_h + 2 * pad_h - kernel_h) // stride_h + 1
  360. out_size_w = (input_w + 2 * pad_w - kernel_w) // stride_w + 1
  361. out_shape = (batch_size, input_c1, out_size_h, out_size_w, input_c0)
  362. # tensor for the input data
  363. data = akg.tvm.placeholder(shape, dtype, name="input_data")
  364. # maxpool output
  365. forward = akg.tvm.placeholder(out_shape, name="forward", dtype=dtype)
  366. # adjoint tensor for the differentiation
  367. head = akg.tvm.placeholder(out_shape, name="head", dtype=dtype)
  368. # override differentiation computation with custom function
  369. [dl_ddata] = akg.differentiate(forward, [data], head, None, None,
  370. override={forward: ([data], custom_maxpool_fdiff)})
  371. # schedule for differetiation operation
  372. s = akg.tvm.create_schedule([dl_ddata.op])
  373. # get computations
  374. result = dl_ddata
  375. forward_broadcast = result.op.input_tensors[1]
  376. head_broadcast = result.op.input_tensors[2]
  377. # cache reads and writes
  378. result_ub = s.cache_write(result, "local.UB")
  379. data_ub = s.cache_read(data, "local.UB", [result_ub])
  380. head_ub = s.cache_read(head, "local.UB", [head_broadcast])
  381. forward_ub = s.cache_read(forward, "local.UB", [forward_broadcast])
  382. s[head_broadcast].set_scope("local.UB")
  383. s[forward_broadcast].set_scope("local.UB")
  384. s[head_ub].compute_at(s[head_broadcast], head_broadcast.op.axis[0])
  385. s[forward_ub].compute_at(s[forward_broadcast], forward_broadcast.op.axis[0])
  386. s[data_ub].compute_at(s[result_ub], result_ub.op.axis[0])
  387. s[forward_broadcast].compute_at(s[result_ub], result_ub.op.axis[0])
  388. s[head_broadcast].compute_at(s[result_ub], result_ub.op.axis[0])
  389. _, c1, h, _, _ = result.op.axis
  390. if input_h + 2 * pad_h > 32 or input_w + 2 * pad_w > 32:
  391. h_outer, _ = s[result].split(h, 4)
  392. s[result_ub].compute_at(s[result], h_outer)
  393. else:
  394. s[result_ub].compute_at(s[result], c1)
  395. with akg.build_config(add_lower_pass=utils.debug_mode(0), dump_pass_ir=True):
  396. mod = akg.build(s, [head, data, forward, dl_ddata], "cce",
  397. name="maxpool_ad_manual_schedule_no_overlap_all_max", attrs=attrs, polyhedral=polyhedral)
  398. source_code = mod.imported_modules[0].get_source()
  399. kernel_name = "maxpool_ad_manual_schedule_no_overlap_all_max"
  400. utils.create_code(kernel_name, './', source_code)
  401. return mod