#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: maxpool_ad"""
import akg.tvm
import akg.topi
import akg
from akg.ops.nn import maxpool
from akg.utils.format_transform import get_shape
from akg.utils.dsl_create import cal_pad_shapes_by_strategy
from akg.utils import kernel_exec as utils
from akg.utils import validation_check as vc_util


@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, (list, tuple), (list, tuple),
                          (str, list, tuple))
def maxpool_ad_no_custom_diff_poly_all_max(head, data, kernel, stride, pad):
    """Automatic differentiation of maxpool using the polyhedral scheduler."""
    attrs = {"enable_post_poly_loop_partition": False, "enable_pre_poly_loop_partition": False}
    maxpool_fwd = maxpool.old_maxpool(data, kernel, stride, pad)
    [dl_ddata] = akg.differentiate(maxpool_fwd, [data], head, None, None)
    return dl_ddata, attrs


@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, (list, tuple), (list, tuple),
                          (str, list, tuple))
def maxpool_ad_no_custom_diff_manual_schedule_all_max(head, data, kernel, stride, pad):
    """Automatic differentiation of maxpool with a manual schedule."""
    attrs = {"enable_post_poly_loop_partition": False, "enable_pre_poly_loop_partition": False}
    maxpool_fwd = maxpool.old_maxpool(data, kernel, stride, pad)
    [dl_ddata] = akg.differentiate(maxpool_fwd, [data], head, None, None)

    # schedule for the differentiation operation
    s = akg.tvm.create_schedule([dl_ddata.op])

    # get computations
    new_tensor_red = dl_ddata
    new_tensor = new_tensor_red.op.input_tensors[0]
    data = new_tensor.op.input_tensors[0]
    broadcast = new_tensor.op.input_tensors[1]
    head = new_tensor.op.input_tensors[2]
    forward = broadcast.op.input_tensors[0]

    def comp_func(s):
        # stage the inputs in UB and write the reduction result back from UB
        data_ub = s.cache_read(data, "local.UB", [forward, new_tensor])
        head_ub = s.cache_read(head, "local.UB", [new_tensor])
        result_ub = s.cache_write(new_tensor_red, "local.UB")
        s[broadcast].set_scope("local.UB")
        s[forward].set_scope("local.UB")

        b, c1, h, w, c0 = forward.op.axis
        oh, ow = forward.op.reduce_axis
        s[forward].reorder(oh, ow, b, c1, h, w, c0)

        s[new_tensor].set_scope("local.UB")

        b, c1, h, w, c0 = result_ub.op.axis
        s[result_ub].reorder(*result_ub.op.reduce_axis, b, c1, h, w, c0)
        s[broadcast].compute_at(s[result_ub], b)
        s[new_tensor].compute_at(s[result_ub], b)

    return dl_ddata, comp_func, attrs
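
# Illustrative usage sketch (not part of the operator definition).  The 5D
# NC1HWC0 shape, float16 dtype, 2x2 window and the "VALID" pad strategy below
# are assumptions chosen only for demonstration; real callers supply their own
# tensors and strategies.
def _example_no_custom_diff_manual_schedule():
    shape = (1, 1, 16, 16, 16)  # assumed (N, C1, H, W, C0) input
    kernel, stride, pad = (2, 2), (2, 2), "VALID"
    # derive the forward-output size the same way the operators below do
    _, [out_h, out_w] = cal_pad_shapes_by_strategy(shape, kernel, stride, pad)
    data = akg.tvm.placeholder(shape, "float16", name="data")
    head = akg.tvm.placeholder((shape[0], shape[1], out_h, out_w, shape[4]), "float16", name="head")
    return maxpool_ad_no_custom_diff_manual_schedule_all_max(head, data, kernel, stride, pad)
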
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor,
                          (list, tuple), (list, tuple), (str, list, tuple))
def maxpool_ad(head, data, forward, mask, kernel, stride, pad):
    """Automatic differentiation of maxpool (mask-based) with a manual schedule."""
    shape = get_shape(data)
    dtype = data.dtype
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    [ph_h, _, pw_h, _], [out_size_h, out_size_w] = \
        cal_pad_shapes_by_strategy(shape, kernel, stride, pad)
    batch_size, input_c1, input_h, input_w, input_c0 = shape

    # a tile scale of one has proven to be the most efficient choice
    tile_scale_h = 1
    tile_scale_w = 1

    tile_h = stride_h * tile_scale_h
    if kernel_h == stride_h:  # non-overlapping case
        tile_h_pad_u = ph_h % stride_h
    elif kernel_h % stride_h == 0:
        tile_h_pad_u = kernel_h - stride_h - ph_h
    else:
        tile_h_pad_u = kernel_h - kernel_h % stride_h - ph_h
    tile_h_pad_l = kernel_h - stride_h + ph_h
    tile_input_h = tile_h + tile_h_pad_u + tile_h_pad_l
    tile_h_out = (input_h - 1) // tile_h + 1
    if ph_h % stride_h == 0:
        pad_output_h = ph_h // stride_h
    else:
        pad_output_h = ph_h // stride_h + 1
    if tile_h_pad_u % stride_h == 0:
        pad_output_h -= tile_h_pad_u // stride_h
    else:
        pad_output_h -= tile_h_pad_u // stride_h + 1
    tile_output_h = (tile_input_h - kernel_h) // stride_h + 1

    tile_w = stride_w * tile_scale_w
    if kernel_w == stride_w:  # non-overlapping case
        tile_w_pad_u = pw_h % stride_w
    elif kernel_w % stride_w == 0:
        tile_w_pad_u = kernel_w - stride_w - pw_h
    else:
        tile_w_pad_u = kernel_w - kernel_w % stride_w - pw_h
    tile_w_pad_l = kernel_w - stride_w + pw_h
    tile_input_w = tile_w + tile_w_pad_u + tile_w_pad_l
    tile_w_out = (input_w - 1) // tile_w + 1
    if pw_h % stride_w == 0:
        pad_output_w = pw_h // stride_w
    else:
        pad_output_w = pw_h // stride_w + 1
    if tile_w_pad_u % stride_w == 0:
        pad_output_w -= tile_w_pad_u // stride_w
    else:
        pad_output_w -= tile_w_pad_u // stride_w + 1
    tile_output_w = (tile_input_w - kernel_w) // stride_w + 1
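    # Worked example with illustrative numbers only (not a supported-shape claim):
    # kernel_h=3, stride_h=2, ph_h=1, input_h=16 give
    #   tile_h = 2, tile_h_pad_u = 1, tile_h_pad_l = 2, tile_input_h = 5,
    #   tile_h_out = 8, pad_output_h = 0, tile_output_h = 2,
    # i.e. each 2-row input tile is read together with a 1-row upper and a 2-row
    # lower halo (a 5-row window) and touches 2 output rows.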
    def custom_maxpool_fdiff(out, inputs, head_, ad_attrs, new_pld_array):
        head_reshaped = akg.tvm.compute((batch_size, input_c1, tile_h_out, tile_w_out,
                                         tile_output_h, tile_output_w, input_c0),
                                        lambda b, c1, h_out, w_out, oh, ow, c0:
                                        akg.tvm.expr.Select(
                                            akg.tvm.any(h_out * tile_scale_h + pad_output_h + oh < 0,
                                                        h_out * tile_scale_h + pad_output_h + oh > out_size_h - 1,
                                                        w_out * tile_scale_w + pad_output_w + ow < 0,
                                                        w_out * tile_scale_w + pad_output_w + ow > out_size_w - 1),
                                            akg.tvm.const(0.0, dtype=dtype),
                                            head_(b, c1,
                                                  h_out * tile_scale_h + pad_output_h + oh,
                                                  w_out * tile_scale_w + pad_output_w + ow,
                                                  c0)),
                                        name="head_reshaped")

        mask_reshaped = akg.tvm.compute((batch_size, input_c1, tile_h_out, tile_w_out,
                                         tile_output_h, tile_output_w, kernel_h, kernel_w, input_c0),
                                        lambda b, c1, h_out, w_out, oh, ow, kh, kw, c0:
                                        akg.tvm.expr.Select(
                                            akg.tvm.any(h_out * tile_scale_h + pad_output_h + oh < 0,
                                                        h_out * tile_scale_h + pad_output_h + oh > out_size_h - 1,
                                                        w_out * tile_scale_w + pad_output_w + ow < 0,
                                                        w_out * tile_scale_w + pad_output_w + ow > out_size_w - 1),
                                            akg.tvm.const(0.0, dtype=dtype),
                                            mask(b, c1, kh, kw,
                                                 h_out * tile_scale_h + pad_output_h + oh,
                                                 w_out * tile_scale_w + pad_output_w + ow,
                                                 c0)),
                                        name="mask_reshaped")

        d_data = akg.tvm.compute((batch_size, input_c1, tile_h_out, tile_w_out,
                                  tile_output_h, tile_output_w, kernel_h, kernel_w, input_c0),
                                 lambda b, c1, h_out, w_out, oh, ow, kh, kw, c0:
                                 mask_reshaped(b, c1, h_out, w_out, oh, ow, kh, kw, c0)
                                 * head_reshaped(b, c1, h_out, w_out, oh, ow, c0),
                                 name="d_data")

        data_reorg = akg.tvm.compute((batch_size, input_c1, tile_h_out, tile_w_out,
                                      tile_output_h, tile_output_w, tile_h, tile_w, input_c0),
                                     lambda b, c1, h_out, w_out, oh, ow, h, w, c0:
                                     akg.tvm.expr.Select(
                                         akg.tvm.any(h + tile_h_pad_u < oh * stride_h,
                                                     h + tile_h_pad_u > oh * stride_h + kernel_h - 1,
                                                     w + tile_w_pad_u < ow * stride_w,
                                                     w + tile_w_pad_u > ow * stride_w + kernel_w - 1),
                                         akg.tvm.const(0, dtype=dtype),
                                         d_data(b, c1, h_out, w_out, oh, ow,
                                                h + tile_h_pad_u - oh * stride_h,
                                                w + tile_w_pad_u - ow * stride_w,
                                                c0)),
                                     name="data_reorg")

        result_tile = akg.topi.sum(data_reorg, [4, 5])
        result = akg.tvm.compute(shape,
                                 lambda b, c1, h, w, c0:
                                 result_tile(b, c1, h // tile_h, w // tile_w, h % tile_h, w % tile_w, c0),
                                 name="result")
        return [result]

    # override differentiation computation with custom function
    [dl_ddata] = akg.differentiate(forward, [data], head, None, None,
                                   override={forward: ([data], custom_maxpool_fdiff)})

    # schedule for the differentiation operation
    s = akg.tvm.create_schedule([dl_ddata.op])

    # get computations
    result = dl_ddata
    result_tile = result.op.input_tensors[0]
    data_reorg = result_tile.op.input_tensors[0]
    d_data = data_reorg.op.input_tensors[0]
    mask_reshaped = d_data.op.input_tensors[0]
    head_reshaped = d_data.op.input_tensors[1]

    def comp_func(s):
        data_ub = s.cache_read(mask, "local.UB", [mask_reshaped])
        head_ub = s.cache_read(head, "local.UB", [head_reshaped])
        result_ub = s.cache_write(result, "local.UB")

        s[d_data].set_scope("local.UB")
        s[data_reorg].set_scope("local.UB")
        s[mask_reshaped].set_scope("local.UB")
        s[head_reshaped].set_scope("local.UB")
        s[result_tile].set_scope("local.UB")

        s[result_ub].compute_inline()

        # inline inputs
        s[head_ub].compute_inline()
        s[data_ub].compute_inline()

        # result_tile dependencies
        s[data_reorg].compute_inline()
        b, c1, h_out, w_out, h, w, c0 = result_tile.op.axis
        oh, ow = result_tile.op.reduce_axis
        s[result_tile].reorder(b, c1, h_out, w_out, h, w, oh, ow, c0)
        s[d_data].compute_at(s[result_tile], w_out)
        s[mask_reshaped].compute_at(s[result_tile], w_out)
        s[head_reshaped].compute_at(s[result_tile], w_out)

        # tile result
        b, c1, h, w, c0 = result.op.axis
        h_out, h_in = s[result].split(h, tile_h)
        w_out, w_in = s[result].split(w, tile_w)
        s[result].reorder(b, c1, h_out, w_out, h_in, w_in, c0)
        s[result_tile].compute_at(s[result], w_out)

    return dl_ddata, comp_func
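
# Illustrative usage sketch for maxpool_ad above (not part of the operator).
# All concrete values are assumptions: the mask layout (N, C1, KH, KW, OH, OW, C0)
# is inferred from the indexing in custom_maxpool_fdiff, and the shape, dtype and
# "VALID" pad strategy are examples only.
def _example_maxpool_ad():
    shape = (1, 1, 16, 16, 16)
    kernel, stride, pad = (2, 2), (2, 2), "VALID"
    _, [out_h, out_w] = cal_pad_shapes_by_strategy(shape, kernel, stride, pad)
    data = akg.tvm.placeholder(shape, "float16", name="data")
    forward = akg.tvm.placeholder((shape[0], shape[1], out_h, out_w, shape[4]), "float16", name="forward")
    head = akg.tvm.placeholder((shape[0], shape[1], out_h, out_w, shape[4]), "float16", name="head")
    mask = akg.tvm.placeholder((shape[0], shape[1], kernel[0], kernel[1], out_h, out_w, shape[4]),
                               "float16", name="mask")
    return maxpool_ad(head, data, forward, mask, kernel, stride, pad)
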
@vc_util.check_input_type((list, tuple), (list, tuple), (list, tuple), (str, list, tuple), str,
                          (bool, type(None)), (dict, type(None)))
def maxpool_ad_manual_schedule_all_max(shape, kernel, stride, pad, dtype, polyhedral=True, attrs=None):
    """Automatic differentiation of maxpool with a manual schedule for the all-max case."""
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    pad_h, pad_w, _, _ = pad
    batch_size, input_c1, input_h, input_w, input_c0 = shape
    pad_shape = (batch_size, input_c1, input_h + 2 * pad_h, input_w + 2 * pad_w, input_c0)

    out_size_h = (input_h + 2 * pad_h - kernel_h) // stride_h + 1
    out_size_w = (input_w + 2 * pad_w - kernel_w) // stride_w + 1
    out_shape = (batch_size, input_c1, out_size_h, out_size_w, input_c0)

    def custom_maxpool_fdiff(out, inputs, head_, ad_attrs, new_pld_array):
        in_data = inputs[0]

        data_separated_by_windows = (kernel_h, kernel_w, batch_size, input_c1, out_size_h, out_size_w, input_c0)

        pad_data = akg.tvm.compute(pad_shape,
                                   lambda b, c1, h, w, c0:
                                   akg.tvm.expr.Select(
                                       akg.tvm.all(h >= pad_h, h < input_h + pad_h,
                                                   w >= pad_w, w < input_w + pad_w),
                                       in_data(b, c1, h - pad_h, w - pad_w, c0),
                                       akg.tvm.const(0.0, dtype=dtype)),
                                   name="pad_data")

        data_reshaped = akg.tvm.compute(data_separated_by_windows,
                                        lambda wh, ww, b, c1, oh, ow, c0:
                                        pad_data(b, c1, oh * stride_h + wh, ow * stride_w + ww, c0),
                                        name="data_reshaped")

        max_broadcast = akg.tvm.compute(data_separated_by_windows,
                                        lambda wh, ww, b, c1, oh, ow, c0:
                                        out(b, c1, oh, ow, c0),
                                        name="max_broadcast")

        equal = akg.tvm.compute(data_separated_by_windows,
                                lambda wh, ww, b, c1, oh, ow, c0:
                                akg.tvm.expr.Select(
                                    max_broadcast(wh, ww, b, c1, oh, ow, c0)
                                    == data_reshaped(wh, ww, b, c1, oh, ow, c0),
                                    head_(b, c1, oh, ow, c0),
                                    akg.tvm.const(0.0, dtype=dtype)),
                                name="equal")

        data_reorg = akg.tvm.compute((out_size_h, out_size_w, batch_size, input_c1,
                                      input_h + 2 * pad_h, input_w + 2 * pad_w, input_c0),
                                     lambda oh, ow, b, c1, h, w, c0:
                                     akg.tvm.expr.Select(
                                         akg.tvm.any(h < oh * stride_h,
                                                     h > oh * stride_h + kernel_h - 1,
                                                     w < ow * stride_w,
                                                     w > ow * stride_w + kernel_w - 1),
                                         akg.tvm.const(0, dtype=dtype),
                                         equal(h - oh * stride_h, w - ow * stride_w, b, c1, oh, ow, c0)),
                                     name="data_reorg")

        result_pad = akg.topi.sum(data_reorg, [0, 1])
        result = akg.tvm.compute(shape,
                                 lambda b, c1, h, w, c0:
                                 result_pad(b, c1, h + pad_h, w + pad_w, c0),
                                 name="result")
        return [result]

    # tensor for the input data
    data = akg.tvm.placeholder(shape, dtype, name="input_data")
    # maxpool output
    forward = akg.tvm.placeholder(out_shape, name="forward", dtype=dtype)
    # adjoint tensor for the differentiation
    head = akg.tvm.placeholder(out_shape, name="head", dtype=dtype)

    # override differentiation computation with custom function
    [dl_ddata] = akg.differentiate(forward, [data], head, None, None,
                                   override={forward: ([data], custom_maxpool_fdiff)})

    # schedule for the differentiation operation
    s = akg.tvm.create_schedule([dl_ddata.op])

    # get computations
    result = dl_ddata
    result_pad = result.op.input_tensors[0]
    data_reorg = result_pad.op.input_tensors[0]
    equal = data_reorg.op.input_tensors[0]
    max_broadcast = equal.op.input_tensors[0]
    data_reshaped = equal.op.input_tensors[1]
    pad_data = data_reshaped.op.input_tensors[0]

    data_ub = s.cache_read(data, "local.UB", [pad_data])
    head_ub = s.cache_read(head, "local.UB", [equal])
    forward_ub = s.cache_read(forward, "local.UB", [max_broadcast])
    result_ub = s.cache_write(result, "local.UB")

    s[max_broadcast].set_scope("local.UB")
    s[data_reshaped].set_scope("local.UB")
    s[pad_data].set_scope("local.UB")
    s[equal].set_scope("local.UB")
    s[data_reorg].set_scope("local.UB")
    s[result_pad].set_scope("local.UB")

    s[data_ub].compute_inline()
    s[result_ub].compute_inline()
    s[pad_data].compute_inline()

    # equal dependencies
    s[forward_ub].compute_at(s[equal], equal.op.axis[0])
    s[max_broadcast].compute_at(s[equal], equal.op.axis[0])
    s[data_reshaped].compute_at(s[equal], equal.op.axis[0])
    s[head_ub].compute_at(s[equal], equal.op.axis[0])
    s[equal].compute_at(s[result_pad], result_pad.op.axis[0])

    # result dependencies
    s[data_reorg].compute_inline()
    b, c1, h, w, c0 = result_pad.op.axis
    oh, ow = result_pad.op.reduce_axis
    s[result_pad].reorder(oh, ow, b, c1, h, w, c0)
    # s[result_pad].compute_at(s[result], result.op.axis[1])

    b, c1, h, w, c0 = result.op.axis
    h_out, _ = s[result].split(h, stride_h)
    s[result_pad].compute_at(s[result], h_out)

    with akg.build_config(add_lower_pass=utils.debug_mode(0), dump_pass_ir=True):
        mod = akg.build(s, [head, data, forward, dl_ddata], "cce",
                        name="maxpool_ad_manual_schedule_all_max",
                        attrs=attrs, polyhedral=polyhedral)
        source_code = mod.imported_modules[0].get_source()
        kernel_name = "maxpool_ad_manual_schedule_all_max"
        utils.create_code(kernel_name, './', source_code)
    return mod
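
# Illustrative sketch (assumed sizes): building the all-max backward kernel for a
# 16x16 float16 feature map with a non-overlapping 2x2 window and zero padding.
# The concrete shape, window and dtype are examples only; calling this runs
# akg.build and writes the generated CCE source into the current directory.
def _example_build_all_max_kernel():
    return maxpool_ad_manual_schedule_all_max((1, 1, 16, 16, 16), (2, 2), (2, 2),
                                              (0, 0, 0, 0), "float16")
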
def maxpool_ad_manual_schedule_no_overlap_all_max(shape, kernel, stride, pad, dtype, attrs=None, polyhedral=False):
    """Automatic differentiation of maxpool with a manual schedule for the non-overlapping case."""
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    pad_h, pad_w, _, _ = pad
    batch_size, input_c1, input_h, input_w, input_c0 = shape
    pad_shape = (batch_size, input_c1, input_h + 2 * pad_h, input_w + 2 * pad_w, input_c0)

    def custom_maxpool_fdiff(out, inputs, head_, ad_attrs, new_pld_array):
        in_data = inputs[0]

        if stride_w != kernel_w:
            raise RuntimeError("Only kernels with the same dimensions as the stride are supported!")
        if stride_h != kernel_h:
            raise RuntimeError("Only kernels with the same dimensions as the stride are supported!")

        # copy output to the shape of the padded input, copying the same value for the entire kernel size
        out_broadcast = akg.tvm.compute(pad_shape,
                                        lambda b, c1, h, w, c0:
                                        out(b, c1, akg.tvm.floordiv(h, stride_h), akg.tvm.floordiv(w, stride_w), c0),
                                        name="out_broadcast")

        # copy head to the shape of the padded input, copying the same value for the entire kernel size
        head_broadcast = akg.tvm.compute(pad_shape,
                                         lambda b, c1, h, w, c0:
                                         head_(b, c1, akg.tvm.floordiv(h, stride_h), akg.tvm.floordiv(w, stride_w), c0),
                                         name="head_broadcast")

        # check if a value was a maximum and assign the head of that position if it was;
        # this is done for all the maximum values within one kernel
        result = akg.tvm.compute(in_data.shape,
                                 lambda b, c1, h, w, c0:
                                 akg.tvm.expr.Select(
                                     in_data(b, c1, h, w, c0)
                                     == out_broadcast(b, c1, h + pad_h, w + pad_w, c0),
                                     head_broadcast(b, c1, h + pad_h, w + pad_w, c0),
                                     akg.tvm.const(0, dtype=in_data.dtype)),
                                 name="result")
        return [result]

    out_size_h = (input_h + 2 * pad_h - kernel_h) // stride_h + 1
    out_size_w = (input_w + 2 * pad_w - kernel_w) // stride_w + 1
    out_shape = (batch_size, input_c1, out_size_h, out_size_w, input_c0)

    # tensor for the input data
    data = akg.tvm.placeholder(shape, dtype, name="input_data")
    # maxpool output
    forward = akg.tvm.placeholder(out_shape, name="forward", dtype=dtype)
    # adjoint tensor for the differentiation
    head = akg.tvm.placeholder(out_shape, name="head", dtype=dtype)

    # override differentiation computation with custom function
    [dl_ddata] = akg.differentiate(forward, [data], head, None, None,
                                   override={forward: ([data], custom_maxpool_fdiff)})

    # schedule for the differentiation operation
    s = akg.tvm.create_schedule([dl_ddata.op])

    # get computations
    result = dl_ddata
    forward_broadcast = result.op.input_tensors[1]
    head_broadcast = result.op.input_tensors[2]

    # cache reads and writes
    result_ub = s.cache_write(result, "local.UB")
    data_ub = s.cache_read(data, "local.UB", [result_ub])
    head_ub = s.cache_read(head, "local.UB", [head_broadcast])
    forward_ub = s.cache_read(forward, "local.UB", [forward_broadcast])

    s[head_broadcast].set_scope("local.UB")
    s[forward_broadcast].set_scope("local.UB")

    s[head_ub].compute_at(s[head_broadcast], head_broadcast.op.axis[0])
    s[forward_ub].compute_at(s[forward_broadcast], forward_broadcast.op.axis[0])
    s[data_ub].compute_at(s[result_ub], result_ub.op.axis[0])
    s[forward_broadcast].compute_at(s[result_ub], result_ub.op.axis[0])
    s[head_broadcast].compute_at(s[result_ub], result_ub.op.axis[0])

    _, c1, h, _, _ = result.op.axis
    if input_h + 2 * pad_h > 32 or input_w + 2 * pad_w > 32:
        h_outer, _ = s[result].split(h, 4)
        s[result_ub].compute_at(s[result], h_outer)
    else:
        s[result_ub].compute_at(s[result], c1)

    with akg.build_config(add_lower_pass=utils.debug_mode(0), dump_pass_ir=True):
        mod = akg.build(s, [head, data, forward, dl_ddata], "cce",
                        name="maxpool_ad_manual_schedule_no_overlap_all_max",
                        attrs=attrs, polyhedral=polyhedral)
        source_code = mod.imported_modules[0].get_source()
        kernel_name = "maxpool_ad_manual_schedule_no_overlap_all_max"
        utils.create_code(kernel_name, './', source_code)
    return mod
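
# Illustrative sketch (assumed values): building the non-overlapping backward kernel.
# The shape, window, padding and dtype below are examples only; running this module
# directly would emit the generated CCE source into the current directory.
if __name__ == "__main__":
    maxpool_ad_manual_schedule_no_overlap_all_max((1, 1, 16, 16, 16), (2, 2), (2, 2),
                                                  (0, 0, 0, 0), "float16")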