#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: gather_v2"""

import akg.tvm
from akg.utils import validation_check as vc_util
from akg.utils.format_transform import get_shape
from akg.utils import custom_tiling as ct_util

# Default build attributes for gather_v2. Treated as a read-only template:
# gather_v2 copies it on every call, so per-call entries ("dim",
# "custom_tiling", ...) never leak from one invocation into the next.
attrs = {
    "RewriteVarTensorIdx": True,
    "enable_double_buffer": False,
}

# Preset dim (tiling) settings keyed by the string built in
# gather_v2_set_dim_func. Currently empty, so no preset is ever applied.
gather_v2_set_dim_map = {
}


def gather_v2_set_dim_func(params, indices, axis):
    """
    Look up a preset dim (tiling) setting for this gather configuration.

    Args:
        params (tvm.tensor.Tensor): Data tensor being gathered from.
        indices (tvm.tensor.Tensor): Index tensor.
        axis (int): Gather axis.

    Returns:
        tuple: (dim_info, hash_key). dim_info is ct_util.set_dims applied to
        the matching gather_v2_set_dim_map entry, or "" when there is no
        match; hash_key is the lookup string built from both shapes, the
        axis and both dtypes.
    """
    key = (tuple(params.shape), tuple(indices.shape), axis,
           params.dtype, indices.dtype)
    hash_key = str(key)
    if hash_key in gather_v2_set_dim_map:
        return ct_util.set_dims(gather_v2_set_dim_map[hash_key]), hash_key
    return "", hash_key


def gather_tiling_strategy(data, axis):
    """
    Build the custom tiling strategy for the gather op.

    One SET_PRIORITY constraint is created on `data` for every dimension
    after `axis`, walking from the last dimension backwards: the last
    dimension gets priority value 0, the one before it 1, and so on.

    Args:
        data (tvm.tensor.Tensor): Tensor the constraints are attached to.
        axis (int): Non-negative gather axis; dimensions <= axis get no
            constraint.

    Returns:
        list: First element of each ct_util.create_constraint_on_tensor result.
    """
    strategy = []
    for priority_value, pos in enumerate(range(len(data.shape) - 1, axis, -1)):
        strategy.append(ct_util.create_constraint_on_tensor(
            tensor=data,
            values=priority_value,
            constraints=ct_util.TileConstraint.SET_PRIORITY,
            tensor_pos=pos)[0])
    return strategy


@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, (int, type(None)))
def gather_v2(params, indices, axis=0):
    """
    Select tensor of related dimensions.

    Note:
        Each entry in indices must be an index in [0, params.shape[axis]).

    Args:
        params (tvm.tensor.Tensor): Data to be gathered. Types: int8, int32, float16, float32.
        indices (tvm.tensor.Tensor): A 1-D tensor for index. Types: int32.
            Each entry in indices must be an index in [0, params.shape[axis]).
        axis (int): Axis along which index of params be applied. Default: 0.
            None is treated as 0; negative values count from the last axis.

    Returns:
        tvm.tensor.Tensor, which indexes the input tensor along dimension axis
        using the entries in indices.
        dict, build attributes for this call: a fresh copy of the module-level
        `attrs` plus the per-call "dim" (when a preset exists),
        "custom_tiling" and "enable_feature_library" entries.
    """
    if axis is None:
        # The type check above admits None; treat it as the default first
        # axis instead of failing in the integer range check below.
        axis = 0

    input_shape = get_shape(params)
    indices_shape = get_shape(indices)
    vc_util.check_shape(params.shape, tensor_name="params")
    vc_util.check_shape(indices.shape, length=1, tensor_name="indices")
    vc_util.ops_dtype_check(params.dtype, [vc_util.DtypeForDavinci.ALL_FLOAT, vc_util.DtypeForDavinci.ALL_INT])
    vc_util.ops_dtype_check(indices.dtype, vc_util.DtypeForDavinci.INT32)

    axis_num = len(input_shape)
    vc_util.check_value_on_integer("axis", axis, -axis_num, axis_num)
    if axis < 0:
        axis = axis_num + axis

    indices_len = len(indices_shape)

    def _get_output_shape():
        """Return params' shape with dimension `axis` replaced by indices' shape."""
        out_shape = []
        for i, in_shape in enumerate(input_shape):
            if i != axis:
                out_shape.append(in_shape)
            else:
                out_shape.extend(indices_shape)
        return out_shape

    def _get_input_index(output_index):
        """Map an output coordinate to the params coordinate it reads."""
        # The indices_len output coordinates starting at `axis` address the
        # index tensor; its value is the coordinate along `axis` in params.
        axis_input_index = indices[output_index[axis:axis + indices_len]]
        # Output coordinates after the index block map one-to-one onto the
        # params dimensions after `axis` (mirrors _get_output_shape).
        return (list(output_index[:axis])
                + [axis_input_index]
                + list(output_index[axis + indices_len:]))

    output_shape = _get_output_shape()
    output = akg.tvm.compute(
        output_shape,
        lambda *indices_output: params(*_get_input_index(indices_output)),
        name="gather_output")

    # Work on a per-call copy so that e.g. a "dim" preset set by an earlier
    # call is not silently reused by a later call with a different shape.
    op_attrs = dict(attrs)
    dim_info = gather_v2_set_dim_func(params, indices, axis)[0]
    if dim_info != "":
        op_attrs["dim"] = dim_info
    op_attrs["custom_tiling"] = gather_tiling_strategy(params, axis)
    op_attrs["enable_feature_library"] = True
    return output, op_attrs