|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """operator dsl function: reshape"""
-
- from functools import reduce
- import akg
- import akg.topi
- from akg.utils.validation_check import ops_dtype_check, check_shape, DtypeForDavinci, check_input_type
- from akg.utils.format_transform import get_shape
- from akg.utils import dynamic_shape as ds
-
-
- @check_input_type(akg.tvm.tensor.Tensor, (list, tuple))
- def reshape(data, out_shape):
- """
- Rearranges input tensor data to new shape out_shape.
-
- Args:
- data (tvm.tensor.Tensor): The tensor to be reshaped.
- out_shape (list, tuple): The new shape applied on the input tensor data,
- should be compatible with the original shape of data.
-
- Returns:
- The reshaped akg.tvm.tensor of same type as input tensor data.
- """
- ops_dtype_check(data.dtype, DtypeForDavinci.INT32.value + DtypeForDavinci.ALL_FLOAT.value)
-
- data_shape = data.shape
- check_shape(data_shape)
-
- in_shape = get_shape(data)
- out_shape = list(out_shape)
- is_dynamic = ds.shape_is_dynamic(data)
-
- if -1 in out_shape:
- access_size = 1
- for i, o_shape in enumerate(out_shape):
- if -1 != o_shape:
- access_size *= o_shape
- else:
- hit_idx = i
- ori_size = reduce(lambda x, y: x * y, in_shape)
- if ori_size % access_size != 0:
- raise ValueError(("Invalid out_shape ({})".format(out_shape)))
-
- out_shape[hit_idx] = int(ori_size / access_size)
- else:
- if not is_dynamic:
- if reduce(lambda x, y: x * y, in_shape) != reduce(lambda x, y: x * y, out_shape):
- raise ValueError("the total length of out_shape is not equal to the in_shape")
-
- inputs = akg.tvm.compute(in_shape, lambda *indice: data(*indice), name="inputs")
- res = akg.topi.reshape(inputs, out_shape)
- output = akg.tvm.compute(out_shape, lambda *indice: res(*indice), name="reshape")
- attr_map = {}
- return output, attr_map
|