#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function:argmin_argmax_common"""
import akg.tvm
import akg.topi
from akg.lang import cce as dav
from akg.utils import custom_tiling as ct_util, validation_check as vc_util
from akg.utils.dsl_create import get_reduce_out_shape
from akg.utils.format_transform import refine_reduce_axis, get_shape
from akg.utils.dynamic_shape import shape_is_dynamic, set_dynamic_shape_limit_for_tensor


def argminmax_tiling_strategy(out_shape, axis):
    """
    Custom tiling strategy for argminmax op.

    Args:
        out_shape (list or tuple): Shape used to derive the tiling axes.
            It is not modified.
        axis (int): Index of the reduce axis inside out_shape. Negative
            values count from the end.

    Returns:
        list, one tiling constraint per axis whose extent is not 1: the
        reduce axis gets a "FULL" / MAX constraint, every other axis gets a
        FACTOR constraint of 1. Empty when the reduce axis has extent 1.
    """
    # when reduce axis is one, we do not need any strategy
    if out_shape[axis] == 1:
        return []

    # Work on a private list so tuples are accepted and the caller's
    # sequence is never touched.
    dims = list(out_shape)
    if axis < 0:
        axis += len(dims)

    # if reduce first axis, it will transpose to last axis
    # so here we adapt to this change
    if axis == 0:
        dims = dims[1:] + dims[:1]
        axis = len(dims) - 1

    # eliminate single axis, which will automatically disappear in halide ir
    # and adjust axis if it is influenced.
    # The shift is the number of unit axes located before the reduce axis,
    # counted against the original position; decrementing axis while
    # scanning would under-count when several unit axes precede it.
    is_unit = [bool(dim == 1) for dim in dims]
    axis -= sum(is_unit[:axis])
    remaining_axes = len(dims) - sum(is_unit)

    strategy = []
    for i in range(remaining_axes):
        if i == axis:
            constraint = ct_util.create_constraint_on_axis(
                values="FULL", constraints=ct_util.TileConstraint.MAX, axis=i)
        else:
            constraint = ct_util.create_constraint_on_axis(
                values=1, constraints=ct_util.TileConstraint.FACTOR, axis=i)
        strategy.append(constraint[0])
    return strategy

@vc_util.check_input_type(akg.tvm.tensor.Tensor, int, (str, type(None)))
def common(data, axis, method="min"):
    """
    Returns the index with the max or min value across axes of a tensor.

    Note:
        method can be "max" or "min" to get argmax or argmin.
        Only the first axis and the last axis are supported as reduce axis.
        Inputs that are not float16 are cast to float16 before computing.

    Args:
        data (tvm.tensor.Tensor): Input tensor; its dtype must pass the
            ALL_FLOAT / ALL_INT dtype check.
        axis (int): Describe the axis of input tensor.
        method (str): Can be "max" or "min".

    Returns:
        tuple of (tvm.tensor.Tensor, dict): the int32 index tensor and the
        attribute map, which may hold "dynamic_shape" and "custom_tiling".

    Raises:
        ValueError: If method is neither "min" nor "max", or if axis is
            neither the first nor the last axis.
    """
    shape = get_shape(data)
    dtype = data.dtype

    vc_util.ops_dtype_check(data.dtype, [vc_util.DtypeForDavinci.ALL_FLOAT, vc_util.DtypeForDavinci.ALL_INT])
    vc_util.reduce_axis_check(shape, axis)
    # Validate method once, before anything is built: the decorator also
    # admits None, and the first-axis loop below may run zero times.
    if method not in ("min", "max"):
        raise ValueError("not support " + str(method))

    real_axis = refine_reduce_axis(shape, axis)[0]
    out_shape = get_reduce_out_shape(shape, axis=axis)
    attr_map = {}
    if shape_is_dynamic(data):
        attr_map["dynamic_shape"] = set_dynamic_shape_limit_for_tensor(data, 4096, real_axis)
    if dtype != "float16":
        data = akg.topi.cast(data, "float16")
    k = akg.tvm.reduce_axis((0, data.shape[real_axis]), "k")

    if axis in (len(shape) - 1, -1):
        # Last-axis reduction: a commutative reducer built on fargmin/fargmax.
        if method == "min":
            reducer = akg.tvm.comm_reducer(
                lambda x, y: dav.fargmin(x, y), lambda t: akg.tvm.max_value(t))
        else:
            reducer = akg.tvm.comm_reducer(
                lambda x, y: dav.fargmax(x, y), lambda t: akg.tvm.min_value(t))

        if len(data.shape) == 1:
            res = akg.tvm.compute((1,), lambda i: reducer(data[k], axis=k))
        else:
            res = akg.tvm.compute(out_shape, lambda *indice: reducer(data(*indice, k), axis=k))

        res = akg.tvm.compute(out_shape, lambda *indice: res(*indice).astype("int32"), "argred_output")
    elif axis in (0, -len(shape)):
        # First-axis reduction: unrolled scan over the leading dimension that
        # keeps a running best value (local_data) and its index (tmp_idx).
        rest_shape = shape[1:]

        def _should_replace(current, candidate):
            # "min" replaces when the running value is greater than the
            # candidate, "max" when it is smaller.
            if method == "min":
                return current > candidate
            return current < candidate

        tmp_idx = akg.tvm.compute(rest_shape, lambda *indice: akg.tvm.const(0.0, "float16"), name='tmp_index')
        local_data = akg.tvm.compute(rest_shape, lambda *indice: data(0, *indice), name="tmp_data")
        for idx in range(shape[axis] - 1):
            # The index must be updated before the value: both comparisons
            # have to see the running value from the previous step.
            tmp_idx = akg.tvm.compute(
                rest_shape,
                lambda *indice, ite_idx=idx: akg.tvm.expr.Select(
                    _should_replace(local_data(*indice), data(ite_idx + 1, *indice)),
                    akg.tvm.const(ite_idx + 1, "float16"),
                    tmp_idx(*indice)
                ))
            local_data = akg.tvm.compute(
                rest_shape,
                lambda *indice, ite_idx=idx: akg.tvm.expr.Select(
                    _should_replace(local_data(*indice), data(ite_idx + 1, *indice)),
                    data(ite_idx + 1, *indice),
                    local_data(*indice)
                ))
        res = akg.tvm.compute(out_shape, lambda *indice: tmp_idx(*indice).astype("int32"), "cast1")
    else:
        raise ValueError("Argmax only support first axis and is last axis now!")

    # Tile on whichever of the input / output shapes has more dimensions.
    larger = out_shape if len(out_shape) > len(shape) else shape
    strategy = argminmax_tiling_strategy(larger, real_axis)
    if strategy:
        attr_map["custom_tiling"] = strategy
    return res, attr_map