# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""dsl create helping function"""
import _akg
from _akg.utils import format_transform as ft_util


class TensorUtils:
    """Helpers for tagging tensors with attrs and building in-place buffer binds."""

    # Attr key that marks a tensor as "create schedule only"; tensors carrying
    # it are reported as not being output values by is_output_value().
    CREATE_SCH_ONLY = 'create_sch_only'

    @classmethod
    def get_tensor_attrs(cls, tensor):
        """Return the attrs of ``tensor.op`` as a new dict.

        Args:
            tensor: Tensor whose op may or may not expose an ``attrs`` field.

        Returns:
            dict: A copy of the op's attrs, or an empty dict if the op has none.
        """
        tensor_attrs = dict()
        if "attrs" in dir(tensor.op):
            tensor_attrs = dict(tensor.op.attrs.items())
        return tensor_attrs

    @classmethod
    def update_tensor_attrs(cls, tensor, attrs):
        """Return a new tensor equal to ``tensor`` with ``attrs`` merged into its op attrs.

        The new tensor is built by an element-wise identity compute that keeps the
        original op's name and tag; the input tensor object itself is not modified.

        Args:
            tensor: Source tensor.
            attrs (dict): Attrs to add; they override existing keys with the same name.

        Returns:
            The newly computed tensor.
        """
        tensor_attrs = cls.get_tensor_attrs(tensor)
        tensor_attrs.update(attrs)
        tensor = _akg.tvm.compute(tensor.shape, lambda *indice: tensor[indice],
                                  name=tensor.op.name, tag=tensor.op.tag, attrs=tensor_attrs)
        return tensor

    @classmethod
    def is_create_sch_only(cls, tensor):
        """Return True if ``tensor``'s op attrs contain the CREATE_SCH_ONLY key."""
        return cls.CREATE_SCH_ONLY in cls.get_tensor_attrs(tensor)

    @classmethod
    def is_output_value(cls, tensor):
        """Return True unless ``tensor`` is marked with CREATE_SCH_ONLY."""
        return not cls.is_create_sch_only(tensor)

    @classmethod
    def inplace_set(cls, input_tensor, output_tensor, buffer_name="data_buf"):
        """Bind ``input_tensor`` and ``output_tensor`` to one shared buffer.

        Args:
            input_tensor: Tensor whose shape/dtype define the shared buffer.
            output_tensor: Tensor to be written in place; must have the same shape.
            buffer_name (str): Name of the declared buffer.

        Returns:
            tuple: ``(output_tensor, binds_info)`` where ``output_tensor`` is a copy
            tagged with CREATE_SCH_ONLY and ``binds_info`` maps both the input
            tensor and that tagged output tensor to the same buffer.

        Raises:
            RuntimeError: If the two tensors' shapes differ.
        """
        input_tensor_shape = ft_util.get_shape(input_tensor)
        output_tensor_shape = ft_util.get_shape(output_tensor)
        if not input_tensor_shape == output_tensor_shape:
            raise RuntimeError("Shape of the input_tensor and the output_tensor should be equal, "
                               "but got %s and %s"%(input_tensor_shape, output_tensor_shape))
        output_tensor = cls.update_tensor_attrs(output_tensor, {cls.CREATE_SCH_ONLY: 1})
        data_buf = _akg.tvm.decl_buffer(input_tensor.shape, input_tensor.dtype, name=buffer_name)
        binds_info = {input_tensor: data_buf, output_tensor: data_buf}
        return output_tensor, binds_info

    @classmethod
    def inplace_set_tensors(cls, input_tensors, output_tensors, buffer_names=None):
        """
        inplace set for tensors

        Args:
            input_tensors (Union[list, tuple]): Origin input tensors.
            output_tensors (Union[list, tuple]): Origin output tensors.
            buffer_names (Union[list, tuple] or None): Buffer names used to bind.
                When empty or None, names ``data_buf_0``, ``data_buf_1``, ... are used.

        Return:
            inplace_tensors (list): Output tensors with the inplace info.
            binds_infos (dict): Dictionary that maps the input tensor and the output tensor to buffer.

        Raises:
            RuntimeError: If an argument is not a tuple/list, or the lengths differ.
        """
        # Validate the tensor sequences before len() is used to build default names,
        # so a bad argument reports the intended RuntimeError rather than a TypeError.
        for arg in (input_tensors, output_tensors):
            if not isinstance(arg, (tuple, list)):
                raise RuntimeError("arg must be tuple or list!")
        if not buffer_names:
            buffer_names = ["data_buf_%s" % i for i in range(len(input_tensors))]
        elif not isinstance(buffer_names, (tuple, list)):
            raise RuntimeError("arg must be tuple or list!")
        if len(input_tensors) != len(output_tensors) or len(input_tensors) != len(buffer_names):
            raise RuntimeError("length of the input_tensors, output_tensors and buffer_names must be equal!")

        inplace_tensors = []
        binds_infos = dict()
        for input_tensor, output_tensor, buffer_name in zip(input_tensors, output_tensors, buffer_names):
            inplace_tensor, binds_info = cls.inplace_set(input_tensor, output_tensor, buffer_name)
            inplace_tensors.append(inplace_tensor)
            binds_infos.update(binds_info)
        return inplace_tensors, binds_infos


def produce_shapes(shape1, shape2):
    """Broadcast two shapes against each other.

    The shorter shape is left-padded with 1s to the common rank. Each dim pair
    must be equal or contain a 1; a size-1 dim takes the other side's size.

    Args:
        shape1 (sequence of int): First shape.
        shape2 (sequence of int): Second shape.

    Returns:
        tuple: ``(shape1, shape2, out_shape)`` as lists, where the two input shapes
        are padded to the same rank and ``out_shape`` is the broadcast result.

    Raises:
        RuntimeError: If some dim pair is unequal and neither side is 1.
    """
    shape1 = list(shape1)
    shape2 = list(shape2)
    # Work with shape1 as the longer shape; remember to swap back at the end.
    swapped = len(shape1) < len(shape2)
    if swapped:
        shape1, shape2 = shape2, shape1
    output_shape_len = len(shape1)
    shape2 = [1] * (output_shape_len - len(shape2)) + shape2

    out_shape = []
    for dim1, dim2 in zip(shape1, shape2):
        if (dim1 != dim2) and (dim1 != 1) and (dim2 != 1):
            raise RuntimeError("input shapes not match!")
        # A size-1 dim adopts the other size (including 0); otherwise the dims are
        # equal. Taking the max would wrongly turn a (0, 1) pair into 1.
        out_shape.append(dim2 if dim1 == 1 else dim1)

    if swapped:
        shape1, shape2 = shape2, shape1
    return shape1, shape2, out_shape