#!/usr/bin/env python3
# coding: utf-8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""composite topi"""
from akg import tvm
from akg.utils.format_transform import get_const
from akg.utils import validation_check as vc_util


@tvm.register_func("ElemAny")
def elem_any(inputs, attrs):
    """Build the "ElemAny" extern op over ``inputs[0]``.

    The generated IR visits every element of the input and stores ``1`` into
    ``dst[0]`` whenever an element is greater than zero.

    Args:
        inputs: sequence whose first entry is the input tensor.
        attrs: unused; kept for the registered-function signature.

    Returns:
        The ``tvm.extern`` result with output shape ``(1,)`` and the dtype of
        the input tensor.
    """

    def kernel_ir(dst, data):
        ib = tvm.ir_builder.create()
        zero = tvm.const(0, data.dtype)
        one = tvm.const(1, data.dtype)
        with ib.for_range_n(data.shape, "ax") as i:
            # NOTE(review): dst[0] is only written when a positive element is
            # found; presumably the output buffer is zero-initialised
            # elsewhere -- confirm.
            with ib.if_scope(ib.load(data, i) > zero):
                ib.store(dst, 0, one)
        return ib.get()

    in_tensor = inputs[0]
    return tvm.extern((1,), [in_tensor],
                      lambda ins, outs: kernel_ir(outs[0], ins[0]),
                      name="elemany", dtype=in_tensor.dtype)


@tvm.register_func("ElemAll")
def elem_all(inputs, attrs):
    """Build the "ElemAll" extern op over ``inputs[0]``.

    The generated IR visits every element of the input and stores ``0`` into
    ``dst[0]`` whenever an element equals zero.

    Args:
        inputs: sequence whose first entry is the input tensor.
        attrs: unused; kept for the registered-function signature.

    Returns:
        The ``tvm.extern`` result with output shape ``(1,)`` and the dtype of
        the input tensor.
    """

    def kernel_ir(dst, data):
        ib = tvm.ir_builder.create()
        zero = tvm.const(0, data.dtype)
        with ib.for_range_n(data.shape, "ax") as i:
            # NOTE(review): dst[0] is only written when a zero element is
            # found; presumably the output buffer is initialised to a
            # non-zero value elsewhere -- confirm.
            with ib.if_scope(ib.load(data, i) == zero):
                ib.store(dst, 0, zero)
        return ib.get()

    in_tensor = inputs[0]
    return tvm.extern((1,), [in_tensor],
                      lambda ins, outs: kernel_ir(outs[0], ins[0]),
                      name="elemall", dtype=in_tensor.dtype)


@tvm.register_func("TransData")
def trans_data(inputs, attrs):
    """Convert a tensor between DefaultFormat/NCHW and FRACTAL_NZ.

    Args:
        inputs: sequence containing exactly one input tensor
            (float16 or float32).
        attrs: mapping with "src_format" and "dst_format"; additionally
            "output_shape" (the original, un-padded shape) when converting
            from FRACTAL_NZ.

    Returns:
        The tensor produced by the selected conversion.

    Raises:
        ValueError: if the number of inputs is not 1, a required attribute is
            missing, a shape has too few dimensions, or the format pair is
            not supported.
    """
    # Copy into a plain dict so the membership tests and lookups below work
    # on a regular Python mapping.
    attrs = dict(attrs.items())
    if len(inputs) != 1:
        raise ValueError("length of inputs should be 1, but got %d."
                         % len(inputs))
    if "src_format" not in attrs or "dst_format" not in attrs:
        raise ValueError("src_format or dst_format not be found in the attrs")

    input_data = inputs[0]
    output_name = "T_transdata_" + input_data.op.name
    src_format = attrs["src_format"]
    dst_format = attrs["dst_format"]
    input_dtype = input_data.dtype
    vc_util.ops_dtype_check(input_dtype, [vc_util.DtypeForDavinci.FLOAT16,
                                          vc_util.DtypeForDavinci.FLOAT32])

    # cube size is 16
    cs = 16

    def _zn2default(data, original_shape):
        """Unpack (..., n1, m1, m0, n0) data into ``original_shape`` (..., m, n)."""
        if len(data.shape) < 4:
            raise ValueError("length of shape of input_data should be greater than or equal to 4, but got %d"
                             % len(data.shape))
        if len(original_shape) < 2:
            raise ValueError("length of original_shape(output_shape) should be greater than or equal to 2, but got %d"
                             % len(original_shape))

        def kernel_ir(input_, output):
            ib = tvm.ir_builder.create()
            shape = [get_const(x) for x in input_.shape]
            n1, m1, m0, n0 = shape[-4:]
            original_shape_ = [get_const(x) for x in original_shape]
            m, n = original_shape_[-2:]
            batch_dims = shape[:-4]

            with ib.for_range_n(batch_dims, "bs") as i:
                with ib.for_range(0, n1) as i_n1:
                    with ib.for_range(0, m1) as i_m1:
                        with ib.for_range(0, m0) as i_m0:
                            with ib.for_range(0, n0) as i_n0:
                                # Only copy positions that fall inside the
                                # original (m, n) extent.
                                with ib.if_scope(tvm.all((i_m1*cs + i_m0) < m, (i_n1*cs + i_n0) < n)):
                                    output_args = i + [i_m1*cs + i_m0, i_n1*cs + i_n0]
                                    input_args = i + [i_n1, i_m1, i_m0, i_n0]
                                    ib.store(output, output_args, ib.load(input_, input_args))
            return ib.get()

        # If it is implemented with tvm.compute,
        # the generated stmt is difficult to process for poly in the fusion scene
        return tvm.extern(original_shape, [data],
                          lambda ins, outs: kernel_ir(ins[0], outs[0]),
                          name=output_name, dtype=data.dtype)

    def _default2zn(data):
        """Pack (..., m, n) data into (..., n1, m1, cs, cs), zero-filling the padding."""
        shape = [get_const(x) for x in data.shape]
        dtype = data.dtype
        if len(shape) < 2:
            raise ValueError("length of shape of input_data should be greater than or equal to 2, but got %d"
                             % len(shape))
        m, n = shape[-2:]
        # Number of cs-sized blocks needed along each of the last two axes
        # (rounded up).
        m1 = (m + cs - 1) // cs
        n1 = (n + cs - 1) // cs
        output_shape = shape[:-2] + [n1, m1, cs, cs]

        def fcompute(*output_indices):
            batch_len = len(output_indices) - 4
            n1_index, m1_index, m0_index, n0_index = output_indices[batch_len:]
            m_index = m1_index * cs + m0_index
            n_index = n1_index * cs + n0_index
            input_indices = list(output_indices[:batch_len])
            input_indices.extend([m_index, n_index])
            # Positions beyond the original (m, n) extent are filled with 0.
            return tvm.if_then_else(tvm.any(m_index >= m, n_index >= n),
                                    tvm.const(0, dtype),
                                    data(*input_indices))

        return tvm.compute(output_shape, fcompute, name=output_name)

    # FRACTAL_NZ: zN fractal format
    if (src_format == "DefaultFormat" or src_format == "NCHW") and dst_format == "FRACTAL_NZ":
        return _default2zn(input_data)
    elif src_format == "FRACTAL_NZ" and (dst_format == "DefaultFormat" or dst_format == "NCHW"):
        if "output_shape" not in attrs:
            raise ValueError("output_shape(original_shape) not be found in the attrs")
        original_shape = attrs["output_shape"]
        return _zn2default(input_data, original_shape)
    else:
        raise ValueError("TransData for src_format %s and dst_format %s is not supported"
                         % (src_format, dst_format))