#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: five2four"""
import akg.topi
from akg.tvm.hybrid import script
from akg.utils import custom_tiling as ct_util
from akg.utils import validation_check as vc_util
from akg.utils.format_transform import get_shape, get_bytes, to_tvm_const
from akg.utils.dynamic_shape import shape_is_dynamic

# Upper bound applied to the 4-D channel size when the op also has to cast (see five2four).
C_LIMIT_FOR_CAST = 3600


def get_attrs():
    """
    Get the default build attributes of the five2four op.

    Returns:
        dict, a newly created mapping containing "pragma_sink_last_axis": False.
    """
    return {"pragma_sink_last_axis": False}


def five2four_tiling_strategy(tensor, c_value=None, expansion=None):
    """
    Custom tiling strategy for five2four op.

    Args:
        tensor (tvm.tensor.Tensor): tensor the tiling constraints are created on
            (five2four passes its 5-dims input).
        c_value (Union[int, None]): channel size used as the tiling factor of the axis at
            position 4; values below 16 are raised to 16. When None, the predefined
            NC1HWC0 template is used instead of explicit constraints.
        expansion (Union[int, None]): when truthy, a SET_EXPANSION constraint with this
            value is appended.

    Returns:
        list of tiling constraints built with akg.utils.custom_tiling.
    """
    # Evaluated once: `tensor` is not modified anywhere in this function.
    is_dynamic = shape_is_dynamic(tensor)

    strategy = []
    if c_value is None:
        strategy = ct_util.create_template(tensor=tensor,
                                           template=ct_util.TileTemplate.NC1HWC0)
    elif not is_dynamic:
        c_value = max(c_value, 16)
        node_n = ct_util.create_constraint_on_tensor(tensor=tensor,
                                                     values=1,
                                                     constraints=ct_util.TileConstraint.FACTOR,
                                                     tensor_pos=0)
        node_c1 = ct_util.create_constraint_on_tensor(tensor=tensor,
                                                      values="FULL",
                                                      constraints=ct_util.TileConstraint.MAX,
                                                      tensor_pos=1)
        node_c0 = ct_util.create_constraint_on_tensor(tensor=tensor,
                                                      values=c_value,
                                                      constraints=ct_util.TileConstraint.FACTOR,
                                                      tensor_pos=4)
        strategy = node_n + node_c1 + node_c0
    # NOTE: a dynamic tensor with an explicit c_value gets no per-axis constraint;
    # only the entries appended below apply to it.

    if expansion:
        strategy.append(ct_util.create_constraint_on_tensor(tensor=tensor,
                                                            values=expansion,
                                                            constraints=ct_util.TileConstraint.SET_EXPANSION)[0])
    if is_dynamic:
        # axis should be full tiled due to cast operator
        strategy.append(ct_util.modify_common_constraints(
            value=0.85, constraint=ct_util.TileConstraint.SET_MEM_RATIO))
    return strategy


@vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple), str, str)
def five2four(data, shape4d, dst_type, format_):
    """
    Convert 5-dims "data" to 4-dims, the format of the result is defined in "format_".

    Args:
        data (tvm.tensor.Tensor): 5-dims tensor of type float16, float32
        shape4d (Union[list, tuple]): a list has 4 nums, shape of output Tensor; it is
            replaced by a shape derived from "data" when the shape of "data" is dynamic
        dst_type (str): data type of output Tensor
        format_ (str): a str defined the format of returns, support NCHW and NHWC

    Returns:
        tuple (output, attrs): "output" is the 4-dims tvm.tensor.Tensor (cast to "dst_type"
        when it differs from the input dtype) and "attrs" is the dict of build attributes,
        which carries a "custom_tiling" entry unless the conversion is a plain reshape.

    Raises:
        ValueError: if a static "data" is not 5-dims with a last dim of 16, if "format_" is
            not supported, if "shape4d" does not match the shape of "data", or if a cast is
            requested while the channel size reaches C_LIMIT_FOR_CAST.
    """
    vc_util.ops_dtype_check([data.dtype, dst_type], vc_util.DtypeForDavinci.ALL_FLOAT)
    shape5d = get_shape(data)
    # Evaluated once: `data` is never rebound below.
    is_dynamic = shape_is_dynamic(data)

    if not is_dynamic:
        if len(shape5d) != 5 or shape5d[-1] != 16:
            raise ValueError("five2four_cce only support 5-dim data and last dim should be 16")
        vc_util.davinci_format_check(shape5d, "NC1HWC0", dim=5)
    bs, c1, h, w, c0 = shape5d

    # Check format
    if format_ not in ('NCHW', 'NHWC'):
        raise ValueError("{} format is not support, five2four only support NCHW and NHWC format input"
                         .format(format_))
    if format_ == "NCHW":
        if is_dynamic:
            # The caller supplied shape4d is ignored for dynamic shapes.
            shape4d = [bs, c1 * c0, h, w]
        _, c, h_4d, w_4d = shape4d
    else:
        if is_dynamic:
            shape4d = [bs, h, w, c1 * c0]
        _, h_4d, w_4d, c = shape4d
    vc_util.davinci_format_check(shape4d, format_, dim=4)

    # Check is shape4d and shape5d match (only possible when every dim is a constant)
    if all(isinstance(s, (int, akg.tvm.expr.IntImm)) for s in shape5d):
        if h_4d != h or w_4d != w:
            raise ValueError("five2four_cce's shape4d h and w should equal to data shape's h and w")
        if c > c1 * c0 or c <= (c1 - 1) * c0:
            raise ValueError("five2four_cce's shape4d c should in set ((c1 - 1) * c0, c1 * c0]")

    # Check size c when casting happens
    # NOTE(review): the message says "should not exceed" but c == C_LIMIT_FOR_CAST is
    # rejected as well -- confirm which of the two is intended.
    if not is_dynamic and data.dtype != dst_type and c >= C_LIMIT_FOR_CAST:
        raise ValueError("When input and output data type is not matched, shape of 'c' axis should not exceed {}, "
                         "while currently set is {}".format(C_LIMIT_FOR_CAST, c))

    # The two hybrid-script kernels below are kept exactly as they were: they are built with
    # capture=locals(), so the names of the locals defined above (bs, h, w, c, c1, c0) are
    # part of their contract and must not be renamed.
    # NOTE(review): the check above admits c in ((c1 - 1) * c0, c1 * c0], yet both kernels
    # write channel indices up to c1 * c0 - 1 into an output allocated with c channels --
    # confirm that the lowering drops the tail when c % c0 != 0.
    @script(capture=locals())
    def nc1hwc0_to_nhwc(inputs, bs, h, w, c, c1, c0):
        output = allocate((bs, h, w, c), inputs.dtype, "local")
        for n_i in range(bs):
            for h_i in range(h):
                for w_i in range(w):
                    for c_i in range(c1):
                        for c_i0 in range(c0):
                            output[n_i, h_i, w_i, c_i * c0 + c_i0] = inputs[n_i, c_i, h_i, w_i, c_i0]
        return output

    @script(capture=locals())
    def nc1hwc0_to_nchw(inputs, bs, h, w, c, c1, c0):
        output = allocate((bs, c, h, w), inputs.dtype, "local")
        for n_i in range(bs):
            for c_i in range(c1):
                for h_i in range(h):
                    for w_i in range(w):
                        for c_i0 in range(c0):
                            output[n_i, c_i * c0 + c_i0, h_i, w_i] = inputs[n_i, c_i, h_i, w_i, c_i0]
        return output

    # if c % 16 == 0, h and w == 1, five2four is a reshape operation
    if is_dynamic:
        call_reshape = isinstance(h, int) and isinstance(w, int) and h == 1 and w == 1
    else:
        call_reshape = h == 1 and w == 1 and c % 16 == 0

    c_value = None
    expansion = None
    if format_ == "NHWC":
        if call_reshape:
            # The reshape result gets its own name so that the lambda below does not close
            # over `output`, the very name that the compute statement rebinds.
            reshape_output = akg.topi.reshape(data, (bs, h, w, c))
            if is_dynamic:
                output = akg.tvm.compute((bs, h, w, c), lambda *indice: reshape_output(*indice), name="reshape")
            else:
                output = reshape_output
        elif c < c0:
            # Fewer channels than one c0 block: reshape to the full block, then read only
            # the first c channels.
            reshape_output = akg.topi.reshape(data, (bs, h, w, c0))
            output = akg.tvm.compute((bs, h, w, c), lambda *i: reshape_output(*i), name='slice_output')
        else:
            output = nc1hwc0_to_nhwc(
                data,
                to_tvm_const(bs),
                to_tvm_const(h),
                to_tvm_const(w),
                to_tvm_const(c),
                to_tvm_const(c1),
                to_tvm_const(c0))
    else:
        if call_reshape:
            reshape_output = akg.topi.reshape(data, (bs, c, h, w))
            if is_dynamic:
                output = akg.tvm.compute((bs, c, h, w), lambda *indice: reshape_output(*indice), name="reshape")
            else:
                output = reshape_output
        else:
            output = nc1hwc0_to_nchw(
                data,
                to_tvm_const(bs),
                to_tvm_const(h),
                to_tvm_const(w),
                to_tvm_const(c),
                to_tvm_const(c1),
                to_tvm_const(c0))

    # two special cases for tiling strategy
    if not is_dynamic:
        if c < c0 or output.dtype != dst_type:
            c_value = c
        if c % c0 != 0 and output.dtype != dst_type:
            expansion = int(ct_util.BLOCK_SIZE / get_bytes(data.dtype))

    attrs = get_attrs()
    if not call_reshape:
        attrs["custom_tiling"] = five2four_tiling_strategy(data, c_value, expansion)

    if output.dtype != dst_type:
        output = akg.topi.cast(output, dst_type)
    return output, attrs