#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""format transform function"""
import akg

# Maps a substring of a dtype name to its element size in bytes.
# get_bytes tests the keys in this insertion order and the first match wins.
supported_bits = {
    "8": 1,
    "16": 2,
    "32": 4,
    "64": 8,
    "bool": 1
}


def to_tvm_const(x):
    """Convert a Python int to a TVM constant; return any other value unchanged."""
    if isinstance(x, int):
        return akg.tvm.const(x)
    return x


def get_const(expr):
    """
    Get the constant integer value of a TVM expression.

    Args:
        expr (Union[int, tvm.expr.Expr]): a Python int or a TVM expression.

    Returns:
        int: the constant value; Python ints are returned unchanged.

    Raises:
        TypeError: if ``expr`` is not (and does not simplify to) an
            IntImm/UIntImm.
    """
    if isinstance(expr, int):
        return expr
    const_types = (akg.tvm.expr.IntImm, akg.tvm.expr.UIntImm)
    if not isinstance(expr, const_types):
        # Give the simplifier a chance to fold the expression to a constant.
        expr = akg.tvm.ir_pass.Simplify(expr)
    if not isinstance(expr, const_types):
        raise TypeError("Expr is not a const. Get const fail, please use get shape.")
    return expr.value


def get_bytes(dtype, allow_none=False):
    """
    Get the number of bytes per element for a supported dtype.

    ``str(dtype)`` is matched by substring against the keys of
    ``supported_bits`` in dict order; the first matching key wins.

    Args:
        dtype: a dtype name, or an object whose ``str()`` is a dtype name.
        allow_none (bool): return None instead of raising for an unknown dtype.

    Returns:
        int or None: bytes per element.

    Raises:
        RuntimeError: if no key matches and ``allow_none`` is False.
    """
    dtype = str(dtype)
    for bits, nbytes in supported_bits.items():
        if bits in dtype:
            return nbytes
    if allow_none:
        return None
    raise RuntimeError("Invalid dtype, supported bits are {0}".format(supported_bits.keys()))


def refine_shape(shape, reduce_axis=None):
    """
    Refine shape to drop 1 in shape according to reduce axis.

    Note:
        if input is just shape, result is shape, and if inputs are shape and
        axis, result is a tuple of (shape, axis).

    Args:
        shape: shape of data (list or tuple of dims).
        reduce_axis: list, tuple or int, axis want to reduce (negative values
            count from the end).

    Returns:
        shape (list): refined shape, i.e. every dim that is not > 1 removed;
            ``[1]`` if nothing is left.
        reduce_axis (list): only returned when ``reduce_axis`` is given. The
            reduce axes re-indexed into the refined shape, in ascending order.
            Axes that pointed at a dropped dim are removed (reducing over a
            size-1 dim is a no-op), so an empty list ([]) means there is
            nothing left to reduce.
    """
    def _refine_shape_no_reduce():
        refined = [shp for shp in shape if shp > 1]
        return refined or [1]

    if reduce_axis is None:
        return _refine_shape_no_reduce()

    res_reduce_axis = sorted(refine_reduce_axis(shape, reduce_axis))
    if not res_reduce_axis:
        return _refine_shape_no_reduce(), []

    refined_shape = []
    # Original axis index -> index of that dim inside refined_shape
    # (only dims that survive refinement get an entry).
    new_index = {}
    for idx, dim in enumerate(shape):
        if dim > 1:
            new_index[idx] = len(refined_shape)
            refined_shape.append(dim)

    # A reduce axis on a dropped dim must not be remapped onto a neighbouring
    # kept dim, so it is simply discarded.
    refined_axis = [new_index[axs] for axs in res_reduce_axis if axs in new_index]
    if not refined_shape:
        refined_shape = [1]
    return refined_shape, refined_axis


def refine_reduce_axis(input, axis):
    """
    Make reduce axis legal.

    Args:
        input: anything accepted by ``get_shape`` (tensor, tvm Array, int,
            list/tuple, Var).
        axis (Union[None, int, list, tuple]): axes to reduce; None selects all
            axes. Negative values count from the end.

    Returns:
        list: non-negative axes, sorted in descending order.

    Raises:
        TypeError: if ``axis`` is not None, int, tuple or list.
        ValueError: if there are more axes than dims, or an axis lies outside
            ``[-len(shape), len(shape))``.
    """
    shape = get_shape(input)
    ndim = len(shape)
    if axis is None:
        axis = list(range(ndim))
    elif isinstance(axis, int):
        axis = [axis]
    elif not isinstance(axis, (tuple, list)):
        raise TypeError("axis must be one of the type int,tuple,list or None")

    if len(axis) > ndim:
        raise ValueError("axis size must not larger than shape size")

    normalized = []
    for ax in axis:
        pos = ax + ndim if ax < 0 else ax
        # Reject both too-large and too-negative axes.
        if pos < 0 or pos >= ndim:
            raise ValueError("axis value-{} exceeds len(shape) which is invalid".format(ax))
        normalized.append(pos)
    normalized.sort(reverse=True)
    return normalized


def get_shape_from_tensor(data):
    """Return ``data.shape`` as a Python list, unwrapping IntImm dims to ints."""
    return [dim.value if isinstance(dim, akg.tvm.expr.IntImm) else dim
            for dim in data.shape]


def tvm_shape_to_list(tvm_shape):
    """Translate a TVM shape to a list: Var dims are kept, other dims are replaced by ``.value``."""
    return [dim if isinstance(dim, akg.tvm.expr.Var) else dim.value
            for dim in tvm_shape]


def tvm_array_to_list(tvm_array):
    """
    Translate an akg.tvm array of tensors to a Python list.

    Raises:
        ValueError: if any element is not an ``akg.tvm.tensor.Tensor``.
    """
    tensor_list = []
    for item in tvm_array:
        if not isinstance(item, akg.tvm.tensor.Tensor):
            raise ValueError("Only support akg.tvm.tensor.Tensor.")
        tensor_list.append(item)
    return tensor_list


def get_shape(data):
    """
    Get shape and save it as list.

    Args:
        data: a tensor, tvm Array, int, list/tuple or tvm Var.

    Returns:
        list: the shape; an int or Var becomes a one-element list.

    Raises:
        TypeError: for any other input type.
    """
    if isinstance(data, akg.tvm.tensor.Tensor):
        return get_shape_from_tensor(data)
    if isinstance(data, akg.tvm.container.Array):
        return tvm_shape_to_list(data)
    if isinstance(data, int):
        return [data]
    if isinstance(data, (tuple, list)):
        return list(data)
    if isinstance(data, akg.tvm.expr.Var):
        return [data]
    raise TypeError("Refine axis does not support type {} for now.".format(type(data)))


def convert_to_list(something, convert_all=True):
    """
    Recursively convert lists/tuples into (nested) lists.

    Args:
        something: value to convert.
        convert_all (bool): when True, a non-list/tuple value is wrapped in a
            one-element list; when False it is returned unchanged (this is how
            nested scalar elements are handled).

    Returns:
        A list, or ``something`` itself when it is not a list/tuple and
        ``convert_all`` is False.
    """
    if isinstance(something, (list, tuple)):
        return [convert_to_list(x, convert_all=False) for x in something]
    if convert_all:
        return [something]
    return something


def to_tvm_nd_array(data, ctx=None):
    """
    Convert data to TVM NDArray(s) on the given context.

    Args:
        data: an array-like, or a list/tuple of array-likes.
        ctx: target TVM context; defaults to ``akg.tvm.context("cuda", 0)``.

    Returns:
        A single NDArray, or a list/tuple of NDArrays matching the container
        type of ``data``.
    """
    if ctx is None:
        ctx = akg.tvm.context("cuda", 0)
    if isinstance(data, list):
        return [akg.tvm.nd.array(d, ctx) for d in data]
    if isinstance(data, tuple):
        # Materialize a real tuple; a bare generator expression would be
        # single-use and not indexable.
        return tuple(akg.tvm.nd.array(d, ctx) for d in data)
    return akg.tvm.nd.array(data, ctx)