|
|
|
@@ -101,13 +101,20 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{ |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32}, |
|
|
|
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}}; |
|
|
|
|
|
|
|
template <typename SrcT, typename DstT> |
|
|
|
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { |
|
|
|
auto src_id = TypeIdSize(args.src_type); |
|
|
|
auto dst_id = TypeIdSize(args.dst_type); |
|
|
|
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { |
|
|
|
void CheckMemSize(const TypeIdArgs &args) { |
|
|
|
auto src_type_size = TypeIdSize(args.host_data_type); |
|
|
|
auto dst_type_size = TypeIdSize(args.device_data_type); |
|
|
|
if (src_type_size < 1 || dst_type_size < 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid src or dst data type."; |
|
|
|
} |
|
|
|
if (args.data_size / src_type_size != args.host_shape_size) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid src or dst data size."; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename SrcT, typename DstT> |
|
|
|
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { |
|
|
|
CheckMemSize(args); |
|
|
|
for (size_t idx = 0; idx != data_size; idx++) { |
|
|
|
SrcT src_data = static_cast<const SrcT *>(args.data)[idx]; |
|
|
|
static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data); |
|
|
|
@@ -116,11 +123,7 @@ void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) |
|
|
|
|
|
|
|
template <typename SrcT> |
|
|
|
void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) { |
|
|
|
auto src_id = TypeIdSize(args.src_type); |
|
|
|
auto dst_id = TypeIdSize(args.dst_type); |
|
|
|
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid src or dst data size."; |
|
|
|
} |
|
|
|
CheckMemSize(args); |
|
|
|
auto src_data = static_cast<const SrcT *>(args.data); |
|
|
|
auto half_data = static_cast<Eigen::half *>(dst); |
|
|
|
for (size_t i = 0; i < data_size; i++) { |
|
|
|
@@ -394,27 +397,18 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { |
|
|
|
} |
|
|
|
|
|
|
|
bool TransDataType(const TypeIdArgs &args, void *result) { |
|
|
|
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type); |
|
|
|
MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " |
|
|
|
<< TypeIdLabel(args.device_data_type); |
|
|
|
MS_EXCEPTION_IF_NULL(result); |
|
|
|
std::pair<TypeId, TypeId> type_info(args.src_type, args.dst_type); |
|
|
|
std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type); |
|
|
|
auto iter = mode_map.find(type_info); |
|
|
|
if (iter == mode_map.end()) { |
|
|
|
MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.src_type) |
|
|
|
<< ", dst_type:" << TypeIdLabel(args.dst_type); |
|
|
|
MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type) |
|
|
|
<< ", dst_type:" << TypeIdLabel(args.device_data_type); |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto trans_mode = iter->second; |
|
|
|
auto src_id = TypeIdSize(args.src_type); |
|
|
|
auto dst_id = TypeIdSize(args.dst_type); |
|
|
|
if (src_id < 1 || dst_id < 1) { |
|
|
|
MS_LOG(ERROR) << "Invalid src or dst data type."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { |
|
|
|
MS_LOG(ERROR) << "Invalid src or dst data size."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) { |
|
|
|
if (!CastKernel(args, result, args.host_shape_size, trans_mode)) { |
|
|
|
MS_LOG(ERROR) << "Failed to trans datatype.."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|