|
|
|
@@ -53,6 +53,7 @@ enum DataTypeTransMode { |
|
|
|
FROM_INT8_TO_FLOAT, |
|
|
|
FROM_INT8_TO_INT32, |
|
|
|
FROM_INT64_TO_INT32, |
|
|
|
FROM_UINT16_TO_INT32, |
|
|
|
}; |
|
|
|
|
|
|
|
const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{ |
|
|
|
@@ -68,7 +69,8 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{ |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32}}; |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}}; |
|
|
|
|
|
|
|
template <typename SrcT, typename DstT> |
|
|
|
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { |
|
|
|
@@ -116,6 +118,9 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const |
|
|
|
case FROM_INT64_TO_INT32: |
|
|
|
TransDataSrc2Dst<int64_t, int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
case FROM_UINT16_TO_INT32: |
|
|
|
TransDataSrc2Dst<uint16_t, int32_t>(args, dst, data_size); |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(ERROR) << "unsupported datatype trans"; |
|
|
|
return false; |
|
|
|
|