# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: reshape"""
import akg
import akg.topi
from akg.utils import validation_check as vc_util
from akg.utils.format_transform import get_shape
from functools import reduce


def _infer_unknown_dim(in_shape, out_shape):
    """
    Resolve a single -1 placeholder in out_shape from the element count of in_shape.

    Args:
        in_shape (list, tuple): Dimensions of the input tensor.
        out_shape (list, tuple): Requested output dimensions; at most one entry
            may be -1, meaning "infer this dimension".

    Returns:
        list: A new list equal to out_shape with the -1 entry (if any) replaced
        by the inferred size.

    Raises:
        ValueError: If more than one entry is -1, if the known output dimensions
            multiply to zero while a -1 is present (the unknown size would be
            ambiguous), or if the input element count is not divisible by the
            product of the known output dimensions.
    """
    out_shape = list(out_shape)
    unknown_idx = [i for i, dim in enumerate(out_shape) if dim == -1]
    if not unknown_idx:
        return out_shape
    if len(unknown_idx) > 1:
        # Only one dimension can be inferred; previously the extra -1s were
        # silently passed through to topi.reshape.
        raise ValueError("Invalid out_shape ({}): only one dimension can be -1".format(out_shape))

    # Initial value 1 keeps reduce well-defined for empty shapes.
    known_size = reduce(lambda x, y: x * y, (dim for dim in out_shape if dim != -1), 1)
    ori_size = reduce(lambda x, y: x * y, in_shape, 1)
    # Guard against a zero-sized known product, which would otherwise raise
    # ZeroDivisionError in the modulo below.
    if known_size == 0 or ori_size % known_size != 0:
        raise ValueError("Invalid out_shape ({})".format(out_shape))
    # Integer division avoids float precision loss for large element counts.
    out_shape[unknown_idx[0]] = ori_size // known_size
    return out_shape


@vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple))
def reshape(data, out_shape):
    """
    Rearranges input tensor data to new shape out_shape.

    Args:
        data (tvm.tensor.Tensor): The tensor to be reshaped.
        out_shape (list, tuple): The new shape applied on the input tensor data,
            should be compatible with the original shape of data. At most one
            entry may be -1, in which case that dimension is inferred from the
            total element count of data.

    Returns:
        The reshaped akg.tvm.tensor of same type as input tensor data.

    Raises:
        ValueError: If out_shape contains more than one -1, or if a -1 entry
            cannot be resolved to an integer dimension.
    """
    vc_util.check_shape(data.shape)
    in_shape = get_shape(data)
    out_shape = _infer_unknown_dim(in_shape, out_shape)
    res = akg.topi.reshape(data, out_shape)
    return res