|
|
|
@@ -1,4 +1,3 @@ |
|
|
|
|
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
@@ -95,33 +94,53 @@ enum DataTypeTransMode { |
|
|
|
FROM_FLOAT_TO_INT32, |
|
|
|
FROM_FLOAT16_TO_FLOAT, |
|
|
|
FROM_FLOAT16_TO_INT32, |
|
|
|
FROM_FLOAT16_TO_UINT8, |
|
|
|
FROM_INT32_TO_FLOAT, |
|
|
|
FROM_INT32_TO_FLOAT16, |
|
|
|
FROM_INT32_TO_UINT8, |
|
|
|
FROM_INT32_TO_INT8, |
|
|
|
FROM_INT32_TO_BOOL, |
|
|
|
FROM_UINT8_TO_FLOAT, |
|
|
|
FROM_UINT8_TO_INT32, |
|
|
|
FROM_UINT8_TO_FLOAT16, |
|
|
|
FROM_INT8_TO_FLOAT, |
|
|
|
FROM_INT8_TO_FLOAT16, |
|
|
|
FROM_INT8_TO_INT32, |
|
|
|
FROM_INT64_TO_INT32, |
|
|
|
FROM_UINT16_TO_INT32, |
|
|
|
FROM_BOOL_TO_FLOAT, |
|
|
|
FROM_BOOL_TO_INT32, |
|
|
|
FROM_BOOL_TO_UINT8, |
|
|
|
FROM_BOOL_TO_FLOAT16, |
|
|
|
FROM_FLOAT64_TO_FLOAT32, |
|
|
|
FROM_FLOAT32_TO_FLOAT64 |
|
|
|
}; |
|
|
|
|
|
|
|
const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{ |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), FROM_FLOAT64_TO_FLOAT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), FROM_FLOAT32_TO_FLOAT64}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), FROM_FLOAT_TO_FLOAT16}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), FROM_FLOAT_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), FROM_FLOAT16_TO_FLOAT}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), FROM_FLOAT16_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), FROM_FLOAT16_TO_UINT8}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), FROM_INT32_TO_FLOAT}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), FROM_INT32_TO_FLOAT16}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), FROM_INT32_TO_UINT8}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), FROM_INT32_TO_INT8}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), FROM_INT32_TO_BOOL}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), FROM_UINT8_TO_FLOAT}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), FROM_UINT8_TO_FLOAT16}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), FROM_INT8_TO_FLOAT16}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}}; |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), FROM_BOOL_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), FROM_BOOL_TO_FLOAT}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), FROM_BOOL_TO_UINT8}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}}; |
|
|
|
|
|
|
|
void CheckMemSize(const TypeIdArgs &args) { |
|
|
|
auto src_type_size = TypeIdSize(args.host_data_type); |
|
|
|
@@ -154,54 +173,46 @@ void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size |
|
|
|
} |
|
|
|
|
|
|
|
bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) { |
|
|
|
switch (mode) { |
|
|
|
case FROM_FLOAT_TO_FLOAT16: |
|
|
|
device::FloatToHalf(dst, args.data, data_size); |
|
|
|
break; |
|
|
|
case FROM_INT32_TO_FLOAT16: |
|
|
|
TransDataSrc2Fp16<int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_FLOAT16_TO_FLOAT: |
|
|
|
device::HalfToFloat(dst, args.data, data_size); |
|
|
|
break; |
|
|
|
case FROM_FLOAT_TO_INT32: |
|
|
|
TransDataSrc2Dst<float, int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_FLOAT16_TO_INT32: |
|
|
|
TransDataSrc2Dst<float16, int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_INT32_TO_FLOAT: |
|
|
|
TransDataSrc2Dst<int32_t, float>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_INT32_TO_INT8: |
|
|
|
TransDataSrc2Dst<int32_t, int8_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_INT32_TO_UINT8: |
|
|
|
TransDataSrc2Dst<int32_t, uint8_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_UINT8_TO_INT32: |
|
|
|
TransDataSrc2Dst<uint8_t, int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_UINT8_TO_FLOAT: |
|
|
|
TransDataSrc2Dst<uint8_t, float>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_INT8_TO_FLOAT: |
|
|
|
TransDataSrc2Dst<int8_t, float>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_INT8_TO_INT32: |
|
|
|
TransDataSrc2Dst<int8_t, int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_INT64_TO_INT32: |
|
|
|
TransDataSrc2Dst<int64_t, int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_UINT16_TO_INT32: |
|
|
|
TransDataSrc2Dst<uint16_t, int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "Unsupported datatype trans"; |
|
|
|
return false; |
|
|
|
using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>; |
|
|
|
const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{ |
|
|
|
{FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>}, |
|
|
|
{FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>}, |
|
|
|
{FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>}, |
|
|
|
{FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>}, |
|
|
|
{FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>}, |
|
|
|
{FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>}, |
|
|
|
{FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>}, |
|
|
|
{FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>}, |
|
|
|
{FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>}, |
|
|
|
{FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>}, |
|
|
|
{FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>}, |
|
|
|
{FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>}, |
|
|
|
{FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>}, |
|
|
|
{FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>}, |
|
|
|
{FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}, |
|
|
|
{FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>}, |
|
|
|
{FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>}, |
|
|
|
{FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>}, |
|
|
|
{FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>}, |
|
|
|
{FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>}, |
|
|
|
{FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>}, |
|
|
|
{FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}}; |
|
|
|
|
|
|
|
if (mode == FROM_FLOAT_TO_FLOAT16) { |
|
|
|
device::FloatToHalf(dst, args.data, data_size); |
|
|
|
return true; |
|
|
|
} else if (mode == FROM_FLOAT16_TO_FLOAT) { |
|
|
|
device::HalfToFloat(dst, args.data, data_size); |
|
|
|
return true; |
|
|
|
} |
|
|
|
auto iter = cast_kernel_map.find(mode); |
|
|
|
if (iter != cast_kernel_map.end()) { |
|
|
|
iter->second(args, dst, data_size); |
|
|
|
return true; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported datatype trans"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
size_t CubeSizeByType(const TypeId data_type) { |
|
|
|
|