#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: strided_slice"""
import copy
import numpy as np
import akg.topi
import akg.tvm
from akg.utils import validation_check as vc_util


def _as_list(value):
    """
    Return ``value`` as a fresh, mutable list.

    A list/tuple is copied into a new list, ``None`` is returned unchanged, and any other
    value (e.g. a bare int, which ``strided_slice``'s input check accepts) becomes a
    one-element list.
    """
    if value is None:
        return None
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def check_args(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask):
    """
    Validate strided-slice arguments.

    Args:
        begin (Union[list, tuple, int]): per-spec start indexes.
        end (Union[list, tuple, int]): per-spec stop indexes; must have the same length as ``begin``.
        strides (Union[list, tuple, int, None]): per-spec steps; if given, same length as ``begin``
            and no element may be 0.
        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask (int): bit ``i`` refers
            to spec ``i``; each must lie in ``[0, 2 ** len(begin))``.

    Raises:
        Exception: on a length mismatch, a zero stride, an out-of-range mask, or an
            ``ellipsis_mask`` with more than one bit set.
    """
    begin, end, strides = _as_list(begin), _as_list(end), _as_list(strides)
    if len(begin) != len(end):
        raise Exception("len(begin) is {}, len(end) is {}. They must be identical!".format(len(begin), len(end)))
    if strides is not None:
        if len(begin) != len(strides):
            raise Exception("len(begin) is {}, len(strides) is {}. They must be identical!".
                            format(len(begin), len(strides)))
        for s in strides:
            if s == 0:
                raise Exception("Value in strides[{}] must not be 0!".format(strides))

    mask_limit = 2 ** len(begin)
    if begin_mask < 0 or begin_mask >= mask_limit:
        raise Exception("Illegal begin_mask[{}]".format(begin_mask))
    if end_mask < 0 or end_mask >= mask_limit:
        raise Exception("Illegal end_mask[{}]".format(end_mask))
    if ellipsis_mask < 0 or ellipsis_mask >= mask_limit:
        raise Exception("Illegal ellipsis_mask[{}]".format(ellipsis_mask))
    if ellipsis_mask != 0:
        # ellipsis_mask must be a power of two (only one ellipsis)
        if ellipsis_mask & (ellipsis_mask - 1) != 0:
            raise Exception("ellipsis_mask[{}] is not power of two (only one ellipsis).".format(ellipsis_mask))
    if new_axis_mask < 0 or new_axis_mask >= mask_limit:
        raise Exception("Illegal new_axis_mask[{}]".format(new_axis_mask))
    if shrink_axis_mask < 0 or shrink_axis_mask >= mask_limit:
        raise Exception("Illegal shrink_axis_mask[{}]".format(shrink_axis_mask))


def args_to_slices(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask):
    """
    Convert strided-slice arguments into a list of Python/NumPy index objects.

    Per spec ``dim`` (in this priority order): an ellipsis bit yields ``Ellipsis``, a new-axis bit
    yields ``np.newaxis``, a shrink bit yields the plain integer ``begin[dim]``, otherwise a
    ``slice(start, stop, step)`` where a set begin/end mask bit turns start/stop into ``None``.
    ``strides=None`` means a step of 1 for every spec.

    Returns:
        list: index objects suitable for ``ndarray[tuple(slices)]``.
    """
    begin, end, strides = _as_list(begin), _as_list(end), _as_list(strides)
    check_args(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
    slices = []
    for dim, bgn in enumerate(begin):
        if (ellipsis_mask >> dim) & 1:
            slices.append(Ellipsis)
        elif (new_axis_mask >> dim) & 1:
            slices.append(np.newaxis)
        elif (shrink_axis_mask >> dim) & 1:
            slices.append(bgn)
        else:
            start = None if (begin_mask >> dim) & 1 else bgn
            stop = None if (end_mask >> dim) & 1 else end[dim]
            step = 1 if strides is None else strides[dim]
            slices.append(slice(start, stop, step))
    return slices


def slices_to_args(slices=()):
    """
    Convert a sequence of index objects back into strided-slice arguments (inverse of args_to_slices).

    Supported items: ``slice`` objects, ``np.newaxis`` (``None``), ``Ellipsis`` and integers
    (Python ``int`` or NumPy integer scalars; an integer ``k`` becomes ``[k, k + 1)`` with the
    shrink bit set).

    Returns:
        tuple: (begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)

    Raises:
        Exception: if an item is of any other type.
    """
    begin = []
    end = []
    strides = []
    begin_mask = 0
    end_mask = 0
    ellipsis_mask = 0
    new_axis_mask = 0
    shrink_axis_mask = 0
    for i, arg in enumerate(slices):
        bit = 1 << i
        if isinstance(arg, slice):
            begin.append(0 if arg.start is None else arg.start)
            if arg.start is None:
                begin_mask |= bit
            end.append(0 if arg.stop is None else arg.stop)
            if arg.stop is None:
                end_mask |= bit
            strides.append(1 if arg.step is None else arg.step)
        elif arg is np.newaxis:
            # checked before the integer case: np.newaxis is None, not an int
            begin.append(0)
            end.append(0)
            strides.append(1)
            new_axis_mask |= bit
        elif arg is Ellipsis:
            begin.append(0)
            end.append(0)
            strides.append(1)
            ellipsis_mask |= bit
        elif isinstance(arg, (int, np.integer)):
            index = int(arg)
            begin.append(index)
            end.append(index + 1)
            strides.append(1)
            shrink_axis_mask |= bit
        else:
            raise Exception("arg {} is invalid".format(arg))
    return begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask


def complete_args(inputs_shape, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
                  shrink_axis_mask):
    """
    Resolve masks into explicit, per-input-dimension begin/end/strides lists.

    Steps: expand the (single) ellipsis into full-range specs, drop new-axis specs (remembering
    their positions), then apply shrink / begin / end masks to concrete values.

    Args:
        inputs_shape (list[int]): static shape of the sliced tensor.
        begin, end, strides: per-spec values (list, tuple or bare int; ``strides`` may be None,
            meaning all ones). The caller's sequences are never modified.
        *_mask (int): bit masks, bit ``i`` referring to spec ``i``.

    Returns:
        tuple: (begin, end, strides, new_axis_index, shrink_list) where ``new_axis_index`` holds the
        positions of new axes in the ellipsis-expanded spec list, and ``shrink_list`` flags
        (0/1) the remaining specs that are shrunk.

    Raises:
        Exception: propagated from check_args on invalid arguments.
    """
    # step0: take private, mutable list copies (tuples must become lists so they can be edited)
    begin = _as_list(begin)
    end = _as_list(end)
    strides = _as_list(strides)
    if strides is None:
        strides = [1] * len(begin)
    check_args(begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)

    # step1: unpack every mask into a per-spec 0/1 list and count new axes
    num_specs = len(begin)
    begin_list = [(begin_mask >> dim) & 1 for dim in range(num_specs)]
    end_list = [(end_mask >> dim) & 1 for dim in range(num_specs)]
    ellipsis_list = [(ellipsis_mask >> dim) & 1 for dim in range(num_specs)]
    new_axis_list = [(new_axis_mask >> dim) & 1 for dim in range(num_specs)]
    shrink_list = [(shrink_axis_mask >> dim) & 1 for dim in range(num_specs)]
    new_axis_count = sum(new_axis_list)

    # step2: replace the ellipsis spec by enough full-range specs to cover the unspecified input dims
    if 1 in ellipsis_list:
        idx = ellipsis_list.index(1)
        ellipsis_length = max(len(inputs_shape) - (num_specs - 1 - new_axis_count), 0)
        fills = ((begin, None), (end, None), (strides, 1), (begin_list, 1), (end_list, 1),
                 (ellipsis_list, 0), (new_axis_list, 0), (shrink_list, 0))
        for values, fill in fills:
            values[idx:idx + 1] = [fill] * ellipsis_length

    # step3: remove new-axis specs (they do not consume an input dim); pop from the back so
    # earlier indexes stay valid
    new_axis_index = [idx for idx, flag in enumerate(new_axis_list) if flag]
    for idx in reversed(new_axis_index):
        for values in (begin, end, strides, begin_list, end_list, ellipsis_list, shrink_list, new_axis_list):
            del values[idx]

    # step4: turn shrink / begin / end masks into concrete (begin, end, strides) values
    for dim, bgn in enumerate(begin):
        if shrink_list[dim]:
            # normalise a negative index first, otherwise index -1 would give end = 0 (an empty range)
            if bgn < 0:
                bgn += inputs_shape[dim]
            begin[dim] = bgn
            end[dim] = bgn + 1
            strides[dim] = 1
            continue
        if begin_list[dim]:
            # a masked begin starts at the first element in the stepping direction
            begin[dim] = 0 if strides[dim] > 0 else inputs_shape[dim] - 1
        if end_list[dim]:
            # a masked end runs past the last element in the stepping direction; for a negative
            # stride -(dim + 1) means "before index 0" under negative-index normalisation
            end[dim] = inputs_shape[dim] if strides[dim] > 0 else -(inputs_shape[dim] + 1)
    return begin, end, strides, new_axis_index, shrink_list


@vc_util.check_input_type(akg.tvm.tensor.Tensor, ((list, tuple), int), ((list, tuple), int),
                          ((list, tuple), int), int, int, int, int, int)
def strided_slice(inputs, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask):
    """
    Generate an array by slicing input tensor

    Args:
        inputs (tvm.tensor.Tensor): Tensor of type float16, float32 (uint8 is also handled, via an
            int8 round-trip).
        begin (Union[list, tuple, int]): The start indexes for slicing.
        end (Union[list, tuple, int]): The end indexes for slicing.
        strides (Union[list, tuple, int]): The strides for slicing.
        begin_mask (int): int32 mask for begin indexes.
        end_mask (int): int32 mask for end indexes.
        ellipsis_mask (int): int32 mask for inserting unspecified dimensions.
        new_axis_mask (int): int32 mask for new dim with length 1.
        shrink_axis_mask (int): int32 mask for shrinking the dims.

    Returns:
        tvm.tensor.Tensor, with the same dtype as inputs.
    """
    shape = [x.value for x in inputs.shape]

    # step0~4: complete begin, end, strides
    begin, end, strides, new_axis_index, shrink_list = complete_args(shape, begin, end, strides,
                                                                     begin_mask, end_mask, ellipsis_mask,
                                                                     new_axis_mask, shrink_axis_mask)

    # special case kept from the original implementation: a single-element input with begin == end
    # is returned as a plain copy
    if shape == [1] and begin == end:
        return akg.tvm.compute(shape, lambda *i: inputs(*i), name="out")

    # step5: use topi to do strided_slice using begin, end, strides
    # NOTE(review): uint8 is round-tripped through int8, presumably because the backend lacks a
    # uint8 path for this op — confirm
    inputs_cast = akg.topi.cast(inputs, "int8") if inputs.dtype == "uint8" else inputs
    out1 = akg.topi.strided_slice(inputs_cast, begin, end, strides)

    # step6/7: re-insert the new axes (size 1) into the output shape, and matching 0 flags into
    # shrink_list. The indexes refer to positions in the final list, so they must be inserted in
    # ascending order (descending order misplaces consecutive leading new axes).
    new_shape = list(out1.shape)
    for idx in new_axis_index:
        new_shape.insert(idx, 1)
        shrink_list.insert(idx, 0)

    # drop the shrunk dims; pop from the back so earlier indexes stay valid
    shrink_axis_index = [idx for idx, flag in enumerate(shrink_list) if flag]
    for idx in reversed(shrink_axis_index):
        new_shape.pop(idx)

    # step8: reshape out_tensor
    out2 = akg.topi.reshape(out1, tuple(new_shape))
    if inputs.dtype == "uint8":
        out2 = akg.topi.cast(out2, "uint8")
    return out2