#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: cast"""
import akg.tvm
import akg.topi
from akg.utils import kernel_exec as utils
from akg.utils import validation_check as vc_util
from akg.utils.format_transform import get_shape

# Destination dtypes accepted by cast(); order matters only for the error message.
_DST_TYPES = ("int8", "float32", "float16", "uint8", "int32")

# (source, destination) dtype pairs converted with a single cast.
_DIRECT_CAST_PAIRS = (
    ("float32", "float16"),
    ("float16", "float32"),
    ("float16", "int8"),
    ("float16", "uint8"),
    ("int32", "float16"),
    ("int32", "float32"),
    ("float16", "int32"),
    ("float32", "int32"),
    ("uint8", "float16"),
    ("int8", "float16"),
)

# (source, destination) dtype pairs converted via an intermediate float16 tensor.
_TMP_TRANSFER_PAIRS = (
    ("int8", "float32"),
    ("float32", "int8"),
    ("bool", "float32"),
    ("uint8", "float32"),
    ("uint8", "int32"),
    ("bool", "int32"),
    ("float32", "uint8"),
)

# Pairs that need a float16 hop on the "mini" product.
_MINI_FP16_HOP_PAIRS = (("float32", "int32"), ("int32", "float32"))

_FLOAT_TYPES = ("float16", "float32")
_TRUNC_INT_TYPES = ("int8", "int32")


@vc_util.check_input_type(akg.tvm.tensor.Tensor, str)
def cast(data, dst_type):
    """
    Cast data to the target type.

    Args:
        data (tvm.tensor.Tensor): Tensor to be cast, type can be float32, float16,
            int32, int8, uint8 and bool.
        dst_type (str): target cast type, one of int8, float32, float16, uint8, int32.

    Returns:
        tvm.tensor.Tensor, type is dst_type.

    Raises:
        RuntimeError: if dst_type is not a supported destination type, or if the
            (data.dtype, dst_type) pair is not a supported conversion.

    Notes:
        Conversions are handled in one of three ways.
        - Direct pairs use a single cast.
        - Pairs in _TMP_TRANSFER_PAIRS go through an intermediate float16 tensor.
        - Identity casts (same source and destination dtype) are always allowed.
        Off the "mini" product, float -> int8/int32 conversions use tvm.trunc
        before the dtype change. On "mini", float32 <-> int32 goes through float16.
    """
    ori_type = data.dtype
    shape = get_shape(data)

    # dtype check
    if dst_type not in _DST_TYPES:
        raise RuntimeError("cast only support cast to %s while dtype is %s"
                           % (",".join(_DST_TYPES), dst_type))

    is_mini = utils.product_is_mini()
    dtype_pair = (ori_type, dst_type)

    if is_mini and dtype_pair in _MINI_FP16_HOP_PAIRS:
        # Product mini has no conversion instruction between float32 and int32,
        # so route the conversion through float16.
        return akg.topi.cast(akg.topi.cast(data, "float16"), dst_type)

    need_tmp_transfer = dtype_pair in _TMP_TRANSFER_PAIRS
    if (ori_type != dst_type and not need_tmp_transfer
            and dtype_pair not in _DIRECT_CAST_PAIRS):
        raise RuntimeError("Don't support cast from %s to %s" % (ori_type, dst_type))

    if need_tmp_transfer:
        if dtype_pair == ("float32", "int8") and not is_mini:
            # Truncate to int32, hop through float16, then truncate into int8.
            # NOTE(review): presumably because there is no direct float32 -> int8
            # path on this target — confirm against the hardware instruction set.
            as_int32 = akg.tvm.compute(
                shape, lambda *indice: akg.tvm.trunc(data(*indice)).astype("int32"))
            as_fp16 = akg.topi.cast(as_int32, "float16")
            return akg.tvm.compute(
                shape, lambda *indice: akg.tvm.trunc(as_fp16(*indice)).astype(dst_type))
        return akg.topi.cast(akg.topi.cast(data, "float16"), dst_type)

    # A real tuple here (the original used the string 'int8, int32', which made
    # this a substring test rather than a membership test).
    if ori_type in _FLOAT_TYPES and dst_type in _TRUNC_INT_TYPES and not is_mini:
        return akg.tvm.compute(
            shape, lambda *indice: akg.tvm.trunc(data(*indice)).astype(dst_type))
    return akg.topi.cast(data, dst_type)