|
- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """common"""
- import akg.tvm
- from .elewise_compute import vmuls, vadds, vmax, vmin, vabs, vrec, vmul, set_is_need_save_dtype
- from .cast_compute import floor, round, cast
-
-
def fargmax(x, y):
    """
    Build a ``fargmax`` intrinsic call selecting the index of the larger of x and y.

    Typically used as the combiner of an argmax ``comm_reducer``.

    Args:
        x (tvm.expr.Expr): First operand.
        y (tvm.expr.Expr): Second operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``x.dtype``.

    Examples:
        >>> n = akg.tvm.var('n')
        >>> m = akg.tvm.var('m')
        >>> data = akg.tvm.placeholder((n, m), name='data')
        >>> k = akg.tvm.reduce_axis((0, m), "k")
        >>> reducer = akg.tvm.comm_reducer(lambda x,y: akg.fargmax(x, y), lambda t: akg.tvm.min_value(t), name="argmax")
        >>> res = akg.tvm.compute((n,), lambda *indice: reducer(data(*indice, k), axis=k), name="res")
    """
    intrin_name = "fargmax"
    out_dtype = x.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, x, y)
-
-
def fargmin(x, y):
    """
    Build a ``fargmin`` intrinsic call selecting the index of the smaller of x and y.

    Counterpart of ``fargmax``; intended as the combiner of an argmin reducer.

    Args:
        x (tvm.expr.Expr): First operand.
        y (tvm.expr.Expr): Second operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``x.dtype``.
    """
    intrin_name = "fargmin"
    out_dtype = x.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, x, y)
-
-
def mad(x, y):
    """
    Build a ``mad`` (matrix multiply-and-add) intrinsic call on x and y.

    Used as the combiner of the ``mmad`` reducer for matrix multiplication.

    Args:
        x (tvm.expr.Expr): First operand.
        y (tvm.expr.Expr): Second operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``x.dtype``.

    Examples:
        >>> n = akg.tvm.var('n')
        >>> m = akg.tvm.var('m')
        >>> k = akg.tvm.var('k')
        >>> A = akg.tvm.placeholder((m, k), name='A')
        >>> B = akg.tvm.placeholder((k, n), name='B')
        >>> kk = akg.tvm.reduce_axis((0, k), name='kk')
        >>> mmad = akg.tvm.comm_reducer(lambda x, y: akg.mad(x, y), lambda t: akg.tvm.const(0, dtype=t), name="mmad")
        >>> C = akg.tvm.compute((m, n), lambda i, j: mmad(A[i, kk] * B[kk, j], axis=kk), name="C")
    """
    intrin_name = "mad"
    out_dtype = x.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, x, y)
-
-
# Module-level commutative reducer for matrix multiply-accumulate: partial results
# are combined with ``mad`` and the identity element is a zero constant of the
# reduction dtype ``t``. Usage: ``mmad(A[i, kk] * B[kk, j], axis=kk)``.
# NOTE(review): the lambda wrapper (instead of passing ``mad`` directly) and its
# parameter names are left untouched; tvm's comm_reducer may derive variable names
# from the combiner's parameters -- confirm before simplifying.
mmad = akg.tvm.comm_reducer(lambda x, y: mad(x, y), lambda t: akg.tvm.const(0, dtype=t), name="mmad")
-
-
def dropout(x, y):
    """
    Build a ``dropout`` intrinsic call on x and y.

    Note that, unlike most helpers in this module, the result dtype is taken
    from the second operand ``y``.

    Args:
        x (tvm.expr.Expr): First operand.
        y (tvm.expr.Expr): Second operand; determines the result dtype.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``y.dtype``.
    """
    intrin_name = "dropout"
    out_dtype = y.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, x, y)
-
-
def iou(x, y):
    """
    Build an ``iou`` intrinsic call computing the intersection over union of boxes x and y.

    Args:
        x (tvm.expr.Expr): First box operand.
        y (tvm.expr.Expr): Second box operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``x.dtype``.
    """
    intrin_name = "iou"
    out_dtype = x.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, x, y)
-
-
def nms(x, y, scalar):
    """
    Build an ``nms`` intrinsic call yielding the non-maximum suppression result for boxes x and y.

    Args:
        x (tvm.expr.Expr): Input argument of the reduced tensor.
        y (tvm.expr.Expr): Input argument.
        scalar (Union[tvm.expr.Expr, float]): Score threshold of nms.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression, with dtype ``x.dtype``.
        The result is stored in fp16; each fp16 value is a hex number indicating suppression.
    """
    intrin_name = "nms"
    operands = (x, y, scalar)
    return akg.tvm.call_pure_intrin(x.dtype, intrin_name, *operands)
-
-
def topk_sort(dst, src, topk):
    """
    Sort proposal boxes and return the top-k result, for sorts whose loop must be partitioned.

    Args:
        dst (tvm.expr.Expr): Destination of the sort, generated by a common reducer.
        src (tvm.expr.Expr): Source proposals. The box count must be divisible by 16
            and each box must have exactly 8 items.
        topk (tvm.expr.Expr): Constant expression giving the required top-k count.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``src.dtype``.
    """
    intrin_name = "topk_sort"
    operands = (dst, src, topk)
    return akg.tvm.call_pure_intrin(src.dtype, intrin_name, *operands)
-
-
def proposal_sort(dst, src, topk):
    """
    Sort proposal boxes and return the top-k result.

    Args:
        dst (tvm.expr.Expr): Destination of the sort, generated by a common reducer.
        src (tvm.expr.Expr): Source proposals. The box count must be divisible by 16
            and each box must have exactly 8 items.
        topk (tvm.expr.Expr): Constant expression giving the required top-k count.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``src.dtype``.
    """
    intrin_name = "proposal_sort"
    operands = (dst, src, topk)
    return akg.tvm.call_pure_intrin(src.dtype, intrin_name, *operands)
-
-
def fnot(x):
    """
    Build a call to the ``not`` pure intrinsic on x.

    Args:
        x (tvm.expr.Expr): Input expression.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``x.dtype``.
    """
    intrin_name = "not"
    out_dtype = x.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, x)
-
-
def round_to(data, max_, min_):
    """
    Clamp the elements of ``data`` into the range [min_, max_].

    Despite its name, this function does not round. It takes the element-wise
    ``vmax`` with a ``min_``-filled tensor and then the ``vmin`` with a
    ``max_``-filled tensor.

    Args:
        data (Tensor): input akg.tvm tensor.
        max_ (float): upper bound of the result range.
        min_ (float): lower bound of the result range.

    Returns:
        tensor : akg.tvm.tensor whose elements lie in [min_, max_].
    """
    # data * 0 gives a tensor shaped like data, onto which the scalar bounds are added.
    zero_like = vmuls(data, 0)
    lower = vadds(zero_like, min_)
    upper = vadds(zero_like, max_)
    # vmax runs before vmin, matching the original operation order.
    return vmin(vmax(data, lower), upper)
-
-
def cast_to(data, dtype, f1628_int_flag=False):
    """
    A wrapped cast operation: cast ``data`` to the type ``dtype``.

    Paths taken:
      * float16 -> int32: handled by an explicit sign/floor sequence built from
        element-wise ops instead of a direct ``cast``.
      * float16 -> int8/uint8 with ``f1628_int_flag`` False: 0.5 is subtracted
        from the data (and ``set_is_need_save_dtype()`` is called) before casting.
      * same source and destination dtype: ``data`` is returned unchanged.
      * otherwise: a float16 source is cast directly to ``dtype``; any other
        source is first cast to float16 and then to ``dtype``.

    Args:
        data (Tensor): akg.tvm.tensor whose dtype is to be changed.
        dtype (String): destination dtype.
        f1628_int_flag (bool): whether the data is already all integers before a
            float16 -> int8/uint8 cast; when True the -0.5 adjustment is skipped.
            Default is False.

    Returns:
        tensor : akg.tvm.tensor.

    Raises:
        RuntimeError: if ``data`` is not an ``akg.tvm.tensor.Tensor``.
    """
    if isinstance(data, akg.tvm.tensor.Tensor):
        data_dtype = getattr(data, 'dtype')
    else:
        raise RuntimeError("The cast input type must be akg.tvm.tensor")

    if (data_dtype == "float16") and (dtype == "int32"):
        # float16 -> int32 computed as floor(|data|) * sign(data), i.e. truncation
        # toward zero (assuming `round` rounds to nearest -- confirm in cast_compute).
        fp16_max = akg.tvm.const(32768, dtype="float16")
        # Small epsilon so the division below never divides by zero.
        fp16_min = akg.tvm.const(2 ** (-15), dtype="float16")

        # Clamp to [-0.5, 0.5] so that scaling by 32768 stays within fp16 range.
        data1 = round_to(data, 0.5, -0.5)

        # new_data / (|new_data| + eps) is close to -1, 0 or +1; rounding it
        # presumably yields the sign of data as an integer.
        new_data = vmuls(data1, fp16_max)
        tmp2 = vabs(new_data)
        tmp3 = vadds(tmp2, fp16_min)
        fp16_res = vmul(new_data, vrec(tmp3))
        sign_res = round(fp16_res)

        # Integer magnitude, then re-apply the sign.
        floor_data = floor(vabs(data))
        res = vmul(floor_data, sign_res)
        return res
    if data_dtype == "float16" and dtype in ("int8", "uint8") and not f1628_int_flag:
        # Shift by -0.5 before the final cast below.
        # NOTE(review): presumably compensates for the rounding mode of the
        # float16 -> int8/uint8 cast -- confirm against cast_compute.cast.
        fp16_half = akg.tvm.const(-0.5, dtype="float16")
        # Global side effect defined in elewise_compute; called before vadds.
        set_is_need_save_dtype()
        data = vadds(data, fp16_half)

    if data_dtype == dtype:
        return data
    # float16 is used as the intermediate type for non-float16 sources.
    if data_dtype == "float16":
        tmp = data
    else:
        tmp = cast(data, dst_dtype="float16")
    return cast(tmp, dst_dtype=dtype)
-
-
def four2five_nchw(data):
    """
    Build a call to the ``four2five_nchw`` pure intrinsic on ``data``.

    Args:
        data (tvm.expr.Expr): Input operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``data.dtype``.
    """
    intrin_name = "four2five_nchw"
    out_dtype = data.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, data)
-
-
def load_im2col_c1_buf(data, pad_h, pad_t, pad_l, pad_r,
                       fm_h, fm_w, stride_h, stride_w,
                       filter_h, filter_w, dilation_h, dilation_w, repeat_mode, jmp_offset):
    """
    Build a call to the ``load_im2col_c1_buf`` pure intrinsic.

    All arguments are forwarded to the intrinsic positionally, in exactly the
    order of this function's signature (data first, jmp_offset last).

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``data.dtype``.
    """
    intrin_args = (
        data,
        pad_h, pad_t, pad_l, pad_r,
        fm_h, fm_w,
        stride_h, stride_w,
        filter_h, filter_w,
        dilation_h, dilation_w,
        repeat_mode, jmp_offset,
    )
    return akg.tvm.call_pure_intrin(data.dtype, "load_im2col_c1_buf", *intrin_args)
-
-
def sin(data):
    """
    Build a call to the ``sin`` pure intrinsic on ``data``.

    Args:
        data (tvm.expr.Expr): Input operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``data.dtype``.
    """
    intrin_name = "sin"
    out_dtype = data.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, data)
-
-
def cos(data):
    """
    Build a call to the ``cos`` pure intrinsic on ``data``.

    Args:
        data (tvm.expr.Expr): Input operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``data.dtype``.
    """
    intrin_name = "cos"
    out_dtype = data.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, data)
-
-
def sinh(data):
    """
    Build a call to the ``sinh`` pure intrinsic on ``data``.

    Args:
        data (tvm.expr.Expr): Input operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``data.dtype``.
    """
    intrin_name = "sinh"
    out_dtype = data.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, data)
-
-
def cosh(data):
    """
    Build a call to the ``cosh`` pure intrinsic on ``data``.

    Args:
        data (tvm.expr.Expr): Input operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``data.dtype``.
    """
    intrin_name = "cosh"
    out_dtype = data.dtype
    return akg.tvm.call_pure_intrin(out_dtype, intrin_name, data)
-
-
def divide_var(data, divisor):
    """
    Build a call to the ``divide_var`` pure intrinsic with ``data`` and ``divisor``.

    Presumably lowered to ``data / divisor`` with a variable divisor; the
    lowering itself is not defined in this module.

    Args:
        data (tvm.expr.Expr): Dividend operand.
        divisor (tvm.expr.Expr): Divisor operand.

    Returns:
        tvm.expr.Expr. The pure-intrinsic call expression; its dtype is ``data.dtype``.
    """
    intrin_name = "divide_var"
    operands = (data, divisor)
    return akg.tvm.call_pure_intrin(data.dtype, intrin_name, *operands)
-
-
def vmadd(x, y, z):
    """
    Call the vmadd instruction to calculate :math:`x * y + z`.

    Args:
        x (tvm.tensor.Tensor): input x; also determines the result dtype.
        y (tvm.tensor.Tensor): input y.
        z (tvm.tensor.Tensor): input z.

    Returns:
        tvm.expr.Expr. The ``vmadd`` pure-intrinsic call expression.
    """
    # Operands are handed to the intrinsic as (y, z, x). x goes last.
    # NOTE(review): presumably x is the destination/accumulator slot of the
    # instruction -- confirm against the vmadd lowering.
    operands = (y, z, x)
    return akg.tvm.call_pure_intrin(x.dtype, "vmadd", *operands)
-
-
def vmla(x, y, z):
    """
    Call the vmla instruction to calculate :math:`x + y * z`.

    Args:
        x (tvm.tensor.Tensor): input x; also determines the result dtype.
        y (tvm.tensor.Tensor): input y.
        z (tvm.tensor.Tensor): input z.

    Returns:
        tvm.expr.Expr. The ``vmla`` pure-intrinsic call expression.
    """
    # Operands are handed to the intrinsic as (y, z, x). x goes last.
    # NOTE(review): presumably x is the destination/accumulator slot of the
    # instruction -- confirm against the vmla lowering.
    operands = (y, z, x)
    return akg.tvm.call_pure_intrin(x.dtype, "vmla", *operands)
|