|
- #!/usr/bin/env python3
- # coding: utf-8
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """operator dsl function:one hot"""
- import akg.tvm
- from akg.tvm.hybrid import script
- from akg.utils import custom_tiling as ct_util
- from akg.utils.validation_check import ops_dtype_check, check_shape, check_input_type, DtypeForDavinci
-
-
def onehot_tiling_strategy(tensor, axis):
    """
    Build the custom tiling constraint used by the one-hot operators.

    Args:
        tensor: tensor the constraint is attached to (callers in this file pass the one-hot output).
        axis (int): forwarded as ``tensor_pos``; callers in this file pass the position of the one-hot axis.

    Returns:
        Whatever ``ct_util.create_constraint_on_tensor`` returns for a ``SET_PRIORITY``
        constraint with ``values=0`` at ``tensor_pos=axis``.
    """
    # NOTE(review): values=0 with SET_PRIORITY presumably gives this axis the top tiling priority - confirm.
    return ct_util.create_constraint_on_tensor(
        tensor=tensor,
        values=0,
        constraints=ct_util.TileConstraint.SET_PRIORITY,
        tensor_pos=axis,
    )
-
-
@check_input_type(akg.tvm.tensor.Tensor, int, str, (int, float, type(None)),
                  (int, float, type(None)), (int, type(None)))
def one_hot(indices, depth, dtype, on_value=None, off_value=None, axis=None):
    """
    Generate the one-hot encoding of the input indices.

    Args:
        indices (tvm.tensor.Tensor): input indices; at most 3 dimensions are supported.
        depth (int): size of the one-hot dimension inserted into the output shape.
        dtype (str): output dtype. It is checked together with ``indices.dtype`` against
            ``DtypeForDavinci.INT32`` and ``DtypeForDavinci.ALL_FLOAT``.
        on_value (int, float or None): value written where the index matches. Defaults to 1.
        off_value (int, float or None): value written everywhere else. Defaults to 0.
        axis (int or None): position of the one-hot dimension in the output. ``None`` and ``-1``
            both mean a new inner-most axis. Other negative values are rejected.

    Returns:
        tuple ``(out, attr_map)``: ``out`` is the one-hot tensor whose shape is ``indices.shape``
        with ``depth`` inserted at ``axis``; ``attr_map`` is a dict that always contains
        ``"RewriteVarTensorIdx": True`` and contains ``"custom_tiling"`` when a tiling strategy
        was produced.

    Raises:
        RuntimeError: if ``indices`` has more than 3 dimensions, if ``axis`` is outside
            ``[0, len(indices.shape)]`` after normalisation, or if the rank is not 1, 2 or 3.
        The validation helpers (``ops_dtype_check``, ``check_shape``) may raise as well.

    Note:
        Only indices ``>= 0`` are written. No comparison against ``depth`` is made in this function.
        NOTE(review): confirm how indices ``>= depth`` are handled downstream.
    """

    ops_dtype_check([indices.dtype, dtype], DtypeForDavinci.INT32.value + DtypeForDavinci.ALL_FLOAT.value)

    shape = [x.value for x in indices.shape]
    check_shape(shape)
    rank = len(shape)

    # Indexing a tensor by a tensor is only implemented below for ranks up to 3.
    if rank > 3:
        raise RuntimeError("one_hot do not support input shape %d dimensions which is more than 3" % rank)

    on_value_const = akg.tvm.const(1 if on_value is None else on_value, dtype)
    off_value_const = akg.tvm.const(0 if off_value is None else off_value, dtype)

    # None and -1 both select a new inner-most axis.
    if axis is None or axis == -1:
        axis = rank

    if axis <= -2 or axis > rank:
        raise RuntimeError("axis(%s) is not an valid index" % axis)

    # Output shape is the input shape with `depth` inserted at `axis`.
    in_shape = list(indices.shape)
    in_shape.insert(axis, depth)
    out_shape = tuple(in_shape)

    # The hybrid kernels below read `out_shape` and `axis` from the enclosing scope.
    # Each one first fills the whole output with the off value, then writes the on
    # value at the position given by every non-negative index.

    @script
    def one_hot_hybrid_1(indices_in, on_value_const_in, off_value_const_in):
        out = output_tensor(out_shape, on_value_const_in.dtype)

        m, n = out_shape

        for i in range(m):
            for j in range(n):
                out[i, j] = off_value_const_in

        if axis == 0:
            for i in range(n):
                if indices_in[i] >= 0:
                    out[indices_in[i], i] = on_value_const_in
        else:
            for i in range(m):
                if indices_in[i] >= 0:
                    out[i, indices_in[i]] = on_value_const_in

        return out

    @script
    def one_hot_hybrid_2(indices_in, on_value_const_in, off_value_const_in):
        out = output_tensor(out_shape, on_value_const_in.dtype)

        m, n, k = out.shape

        for x in range(m):
            for y in range(n):
                for z in range(k):
                    out[x, y, z] = off_value_const_in

        if axis == 0:
            for i in range(n):
                for j in range(k):
                    if indices_in[i, j] >= 0:
                        out[indices_in[i, j], i, j] = on_value_const_in
        elif axis == 1:
            for i in range(m):
                for j in range(k):
                    if indices_in[i, j] >= 0:
                        out[i, indices_in[i, j], j] = on_value_const_in
        else:
            for i in range(m):
                for j in range(n):
                    if indices_in[i, j] >= 0:
                        out[i, j, indices_in[i, j]] = on_value_const_in

        return out

    @script
    def one_hot_hybrid_3(indices_in, on_value_const_in, off_value_const_in):
        out = output_tensor(out_shape, on_value_const_in.dtype)
        m, n, k, t = out.shape

        for x in range(m):
            for y in range(n):
                for z in range(k):
                    for u in range(t):
                        out[x, y, z, u] = off_value_const_in

        if axis == 0:
            for i in range(n):
                for j in range(k):
                    for c in range(t):
                        if indices_in[i, j, c] >= 0:
                            out[indices_in[i, j, c], i, j, c] = on_value_const_in
        elif axis == 1:
            for i in range(m):
                for j in range(k):
                    for c in range(t):
                        if indices_in[i, j, c] >= 0:
                            out[i, indices_in[i, j, c], j, c] = on_value_const_in
        elif axis == 2:
            for i in range(m):
                for j in range(n):
                    for c in range(t):
                        if indices_in[i, j, c] >= 0:
                            out[i, j, indices_in[i, j, c], c] = on_value_const_in
        else:
            for i in range(m):
                for j in range(n):
                    for c in range(k):
                        if indices_in[i, j, c] >= 0:
                            out[i, j, c, indices_in[i, j, c]] = on_value_const_in
        return out

    if rank == 1:
        out = one_hot_hybrid_1(indices, on_value_const, off_value_const)
    elif rank == 2:
        out = one_hot_hybrid_2(indices, on_value_const, off_value_const)
    elif rank == 3:
        out = one_hot_hybrid_3(indices, on_value_const, off_value_const)
    else:
        # Fail explicitly instead of reaching the code below with `out` unbound.
        raise RuntimeError("one_hot do not support input shape with %d dimensions" % rank)

    strategy = onehot_tiling_strategy(out, axis)
    attr_map = {"RewriteVarTensorIdx": True}
    if strategy:
        attr_map["custom_tiling"] = strategy

    return out, attr_map
-
-
@check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor,
                  int, (int, type(None)))
def one_hot_v2(indices, on_value, off_value, depth, axis=None):
    """
    Generate the one-hot encoding of the input indices, taking the fill values from tensors.

    Args:
        indices (akg.tvm.tensor.Tensor): input indices; dtype is checked against
            ``DtypeForDavinci.INT32`` and at most 3 dimensions are supported.
        on_value (akg.tvm.tensor.Tensor): tensor whose element ``[0]`` is written where the index
            matches. Its dtype is also the output dtype.
        off_value (akg.tvm.tensor.Tensor): tensor whose element ``[0]`` is written everywhere else.
        depth (int): size of the one-hot dimension inserted into the output shape.
        axis (int or None): position of the one-hot dimension in the output. ``None`` and ``-1``
            both mean a new inner-most axis. Other negative values are rejected.

    Returns:
        tuple ``(out, attr_map)``: ``out`` is the one-hot tensor whose shape is ``indices.shape``
        with ``depth`` inserted at ``axis``; ``attr_map`` is a dict that always contains
        ``"RewriteVarTensorIdx": True`` and contains ``"custom_tiling"`` when a tiling strategy
        was produced.

    Raises:
        RuntimeError: if ``indices`` has more than 3 dimensions, if ``axis`` is outside
            ``[0, len(indices.shape)]`` after normalisation, or if the rank is not 1, 2 or 3.
        The validation helpers (``ops_dtype_check``, ``check_shape``) may raise as well.

    Note:
        Only indices ``>= 0`` are written. No comparison against ``depth`` is made in this function.
        NOTE(review): confirm how indices ``>= depth`` are handled downstream.
    """

    ops_dtype_check(indices.dtype, DtypeForDavinci.INT32)
    ops_dtype_check([on_value.dtype, off_value.dtype], [DtypeForDavinci.INT32, DtypeForDavinci.ALL_FLOAT])

    shape = [x.value for x in indices.shape]
    check_shape(shape)
    rank = len(shape)

    # The hybrid kernels below are only implemented for ranks up to 3.
    if rank > 3:
        raise RuntimeError("one_hot do not support input shape %d dimensions which is more than 3" % rank)

    # None and -1 both select a new inner-most axis.
    if axis is None or axis == -1:
        axis = rank

    if axis <= -2 or axis > rank:
        raise RuntimeError("axis(%s) is not an valid index" % axis)

    # Output shape is the input shape with `depth` inserted at `axis`.
    in_shape = list(indices.shape)
    in_shape.insert(axis, depth)
    out_shape = tuple(in_shape)

    # The hybrid kernels below read `out_shape` and `axis` from the enclosing scope.
    # Each one first fills the whole output with off_value[0], then writes on_value[0]
    # at the position given by every non-negative index.

    @script
    def one_hot_hybrid_1(indices_in, on_value_const_in, off_value_const_in):
        out = output_tensor(out_shape, on_value_const_in.dtype)

        m, n = out_shape

        for i in range(m):
            for j in range(n):
                out[i, j] = off_value_const_in[0]

        if axis == 0:
            for i in range(n):
                if indices_in[i] >= 0:
                    out[indices_in[i], i] = on_value_const_in[0]
        else:
            for i in range(m):
                if indices_in[i] >= 0:
                    out[i, indices_in[i]] = on_value_const_in[0]

        return out

    @script
    def one_hot_hybrid_2(indices_in, on_value_const_in, off_value_const_in):

        out = output_tensor(out_shape, on_value_const_in.dtype)

        m, n, k = out.shape

        for x in range(m):
            for y in range(n):
                for z in range(k):
                    out[x, y, z] = off_value_const_in[0]

        if axis == 0:
            for i in range(n):
                for j in range(k):
                    if indices_in[i, j] >= 0:
                        out[indices_in[i, j], i, j] = on_value_const_in[0]
        elif axis == 1:
            for i in range(m):
                for j in range(k):
                    if indices_in[i, j] >= 0:
                        out[i, indices_in[i, j], j] = on_value_const_in[0]
        else:
            for i in range(m):
                for j in range(n):
                    if indices_in[i, j] >= 0:
                        out[i, j, indices_in[i, j]] = on_value_const_in[0]

        return out

    @script
    def one_hot_hybrid_3(indices_in, on_value_const_in, off_value_const_in):
        out = output_tensor(out_shape, on_value_const_in.dtype)
        m, n, k, t = out.shape

        for x in range(m):
            for y in range(n):
                for z in range(k):
                    for u in range(t):
                        out[x, y, z, u] = off_value_const_in[0]

        if axis == 0:
            for i in range(n):
                for j in range(k):
                    for c in range(t):
                        if indices_in[i, j, c] >= 0:
                            out[indices_in[i, j, c], i, j, c] = on_value_const_in[0]
        elif axis == 1:
            for i in range(m):
                for j in range(k):
                    for c in range(t):
                        if indices_in[i, j, c] >= 0:
                            out[i, indices_in[i, j, c], j, c] = on_value_const_in[0]
        elif axis == 2:
            for i in range(m):
                for j in range(n):
                    for c in range(t):
                        if indices_in[i, j, c] >= 0:
                            out[i, j, indices_in[i, j, c], c] = on_value_const_in[0]
        else:
            for i in range(m):
                for j in range(n):
                    for c in range(k):
                        if indices_in[i, j, c] >= 0:
                            out[i, j, c, indices_in[i, j, c]] = on_value_const_in[0]
        return out

    if rank == 1:
        out = one_hot_hybrid_1(indices, on_value, off_value)
    elif rank == 2:
        out = one_hot_hybrid_2(indices, on_value, off_value)
    elif rank == 3:
        out = one_hot_hybrid_3(indices, on_value, off_value)
    else:
        # Fail explicitly instead of reaching the code below with `out` unbound.
        raise RuntimeError("one_hot do not support input shape with %d dimensions" % rank)

    strategy = onehot_tiling_strategy(out, axis)
    attr_map = {"RewriteVarTensorIdx": True}
    if strategy:
        attr_map["custom_tiling"] = strategy

    return out, attr_map
|