#!/usr/bin/env python3
# coding: utf-8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dynamic shape helpers.

Detect tensors whose shapes are not fully constant, and build
``DynamicShapeNode`` annotations carrying per-dimension shape limits or
poly upper bounds.
"""
import akg
import akg.tvm
from akg.utils.format_transform import get_shape

# Registered node type name passed to akg.tvm.make.node.
NODE_TYPE = "DynamicShapeNode"


def to_expanded_list(data):
    """Flatten arbitrarily nested lists/tuples into one flat list.

    Order is preserved (depth-first). Any value that is not a list or tuple
    is treated as a leaf, so a non-sequence ``data`` returns ``[data]``.
    """
    data_list = []
    if isinstance(data, (list, tuple)):
        for item in data:
            data_list.extend(to_expanded_list(item))
    else:
        data_list.append(data)
    return data_list


def shape_is_dynamic(data):
    """Return True if any input has a non-constant shape dimension.

    ``data`` may be a single tensor or an arbitrarily nested list/tuple of
    them. Each leaf's shape is obtained via ``get_shape``. The input counts as
    dynamic if any dimension is neither a Python ``int`` nor an
    ``akg.tvm.expr.IntImm``.
    """
    for item in to_expanded_list(data):
        shape = get_shape(item)
        if not all(isinstance(s, (int, akg.tvm.expr.IntImm)) for s in shape):
            return True
    return False


def preprocess_position(position):
    """Validate ``position`` and normalize it to a list of ints.

    Args:
        position: an int, or a list/tuple of ints (dimension indices).

    Returns:
        ``[position]`` for an int. A list/tuple is returned unchanged.

    Raises:
        TypeError: if ``position`` is not an int/list/tuple, or a list/tuple
            element is not an int.
    """
    if isinstance(position, (list, tuple)):
        for p in position:
            if not isinstance(p, int):
                raise TypeError("Position of tensor should be a integer")
    elif isinstance(position, int):
        position = [position]
    else:
        raise TypeError(
            "Position of tensor should be a integer, list or a tuple")
    return position


def preprocess_value_with_position(values, position):
    """Validate ``values`` against ``position`` and normalize to a list.

    Args:
        values: an int applied to every position, or a list/tuple of ints
            with exactly one entry per position.
        position: the already-normalized list of positions
            (see ``preprocess_position``).

    Returns:
        A list/tuple with one value per entry in ``position``. A scalar is
        broadcast so callers can safely index ``values[i]`` for each
        position index ``i``.

    Raises:
        ValueError: if a list/tuple ``values`` differs in length from
            ``position``.
        TypeError: if ``values`` is not an int or a list/tuple of ints.
    """
    if isinstance(values, (list, tuple)):
        if len(values) != len(position):
            raise ValueError(
                "Length of values is not compatible with position.")
        for l in values:
            if not isinstance(l, int):
                raise TypeError(
                    "Dynamic shape values of tensor should be a integer or a list/tuple of integer")
    elif isinstance(values, int):
        # Broadcast the scalar. Returning a single-element list here would
        # make callers that iterate over every position raise IndexError
        # once more than one position is requested.
        values = [values] * len(position)
    else:
        raise TypeError(
            "Dynamic shape values of tensor should be a integer or a list/tuple of integer")
    return values


def set_poly_upper_bound_for_tensor(tensor, upper_bound, position=None):
    """API for DSL to set the poly upper bound of dimensions of ``tensor``.

    Args:
        tensor: an ``akg.tvm.tensor.Tensor``.
        upper_bound: an int (applied to every selected position) or a
            list/tuple of ints, one per position.
        position: dimension index or list/tuple of indices. Defaults to all
            dimensions of ``tensor``.

    Returns:
        A list of ``DynamicShapeNode``. A node is created only for selected
        dimensions whose shape entry is an ``akg.tvm.expr.Var``. That node's
        ``tensor_name`` is the Var's name, not the tensor's name.

    Raises:
        TypeError: if ``tensor`` is not a tvm Tensor, or ``position`` /
            ``upper_bound`` has an invalid type.
        ValueError: if the ``upper_bound`` list length mismatches ``position``.
    """
    if not isinstance(tensor, akg.tvm.tensor.Tensor):
        raise TypeError("Tensor should be tvm.tensor.Tensor")
    if position is None:
        position = list(range(len(tensor.shape)))
    position = preprocess_position(position)
    upper_bound = preprocess_value_with_position(upper_bound, position)
    tensor_shape = get_shape(tensor)
    ret = []
    for i, p in enumerate(position):
        # Creating a limit for the shape var helps poly determine the
        # upper bound; constant dimensions need no annotation.
        if isinstance(tensor_shape[p], akg.tvm.expr.Var):
            ret.append(create_dynamic_shape_node(
                tensor_name=tensor_shape[p].name, pos=p,
                poly_upper_bound=upper_bound[i]))
    return ret


def set_dynamic_shape_limit_for_tensor(tensor, limit, position=None):
    """API for DSL to set the dynamic shape limit of dimensions of ``tensor``.

    Args:
        tensor: an ``akg.tvm.tensor.Tensor``.
        limit: an int (applied to every selected position) or a list/tuple
            of ints, one per position.
        position: dimension index or list/tuple of indices. Defaults to all
            dimensions of ``tensor``.

    Returns:
        A list with one ``DynamicShapeNode`` per selected position, each keyed
        by ``tensor.op.name``.

    Raises:
        TypeError: if ``tensor`` is not a tvm Tensor, or ``position`` /
            ``limit`` has an invalid type.
        ValueError: if the ``limit`` list length mismatches ``position``.
    """
    if not isinstance(tensor, akg.tvm.tensor.Tensor):
        raise TypeError("Tensor should be tvm.tensor.Tensor")
    if position is None:
        position = list(range(len(tensor.shape)))
    position = preprocess_position(position)
    limit = preprocess_value_with_position(limit, position)
    tensor_name = tensor.op.name
    # Creating a limit for the tensor at position p helps inferbound
    # determine the max bound.
    return [
        create_dynamic_shape_node(
            tensor_name=tensor_name, pos=p, dyn_shape_limit=limit[i])
        for i, p in enumerate(position)
    ]


def create_dynamic_shape_node(tensor_name, pos, dyn_shape_limit=-1,
                              poly_upper_bound=-1):
    """Build a ``DynamicShapeNode`` via ``akg.tvm.make.node``.

    Args:
        tensor_name: name the annotation is keyed by.
        pos: dimension index.
        dyn_shape_limit: shape limit. Defaults to -1 (presumably "unset";
            confirm against the C++ node definition).
        poly_upper_bound: poly upper bound. Defaults to -1 (presumably
            "unset"; confirm against the C++ node definition).
    """
    return akg.tvm.make.node(NODE_TYPE, tensor_name=tensor_name, pos=pos,
                             dyn_shape_limit=dyn_shape_limit,
                             poly_upper_bound=poly_upper_bound)