|
|
|
@@ -281,9 +281,11 @@ void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out |
|
|
|
case DataType::DE_UINT64: |
|
|
|
Cast<T, uint64_t>(input, output); |
|
|
|
break; |
|
|
|
#ifndef ENABLE_MD_LITE_X86_64 |
|
|
|
case DataType::DE_FLOAT16: |
|
|
|
Cast<T, float16>(input, output); |
|
|
|
break; |
|
|
|
#endif |
|
|
|
case DataType::DE_FLOAT32: |
|
|
|
Cast<T, float>(input, output); |
|
|
|
break; |
|
|
|
@@ -328,9 +330,11 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o |
|
|
|
case DataType::DE_UINT64: |
|
|
|
CastFrom<uint64_t>(input, output); |
|
|
|
break; |
|
|
|
#ifndef ENABLE_MD_LITE_X86_64 |
|
|
|
case DataType::DE_FLOAT16: |
|
|
|
CastFrom<float16>(input, output); |
|
|
|
break; |
|
|
|
#endif |
|
|
|
case DataType::DE_FLOAT32: |
|
|
|
CastFrom<float>(input, output); |
|
|
|
break; |
|
|
|
@@ -344,6 +348,7 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
#ifndef ENABLE_MD_LITE_X86_64 |
|
|
|
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { |
|
|
|
// initiate new tensor for type cast |
|
|
|
DataType new_type = DataType("float16"); |
|
|
|
@@ -367,6 +372,9 @@ Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
#else |
|
|
|
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { return Status::OK(); } |
|
|
|
#endif |
|
|
|
|
|
|
|
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, |
|
|
|
const std::shared_ptr<Tensor> &pad_val) { |
|
|
|
@@ -410,9 +418,13 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> |
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val)); |
|
|
|
} else if (tensor_type == DataType::DE_INT16) { |
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val)); |
|
|
|
} else if (tensor_type == DataType::DE_FLOAT16) { |
|
|
|
} |
|
|
|
#ifndef ENABLE_MD_LITE_X86_64 |
|
|
|
else if (tensor_type == DataType::DE_FLOAT16) { // NOLINT |
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val))); |
|
|
|
} else if (tensor_type == DataType::DE_UINT16) { |
|
|
|
} |
|
|
|
#endif |
|
|
|
else if (tensor_type == DataType::DE_UINT16) { // NOLINT |
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val)); |
|
|
|
} else if (tensor_type == DataType::DE_INT32) { |
|
|
|
RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val)); |
|
|
|
@@ -570,9 +582,11 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu |
|
|
|
case DataType::DE_INT64: |
|
|
|
RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op)); |
|
|
|
break; |
|
|
|
#ifndef ENABLE_MD_LITE_X86_64 |
|
|
|
case DataType::DE_FLOAT16: |
|
|
|
RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op)); |
|
|
|
break; |
|
|
|
#endif |
|
|
|
case DataType::DE_FLOAT32: |
|
|
|
RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op)); |
|
|
|
break; |
|
|
|
@@ -732,6 +746,7 @@ struct UniqueOpHashMap<float16> { |
|
|
|
}; |
|
|
|
|
|
|
|
#else |
|
|
|
#ifndef ENABLE_MD_LITE_X86_64 |
|
|
|
struct gn_hash { |
|
|
|
size_t operator()(const float16 &f) const { return static_cast<std::size_t>(f); } |
|
|
|
}; |
|
|
|
@@ -740,7 +755,7 @@ template <> |
|
|
|
struct UniqueOpHashMap<float16> { |
|
|
|
using map_type = std::unordered_map<float16, int32_t, gn_hash>; |
|
|
|
}; |
|
|
|
|
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
template <> |
|
|
|
@@ -809,9 +824,13 @@ Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out |
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<uint16_t>(input, output, output_idx, output_cnt)); |
|
|
|
} else if (input->type() == DataType::DE_UINT8) { |
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<uint8_t>(input, output, output_idx, output_cnt)); |
|
|
|
} else if (input->type() == DataType::DE_FLOAT16) { |
|
|
|
} |
|
|
|
#ifndef ENABLE_MD_LITE_X86_64 |
|
|
|
else if (input->type() == DataType::DE_FLOAT16) { // NOLINT |
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<float16>(input, output, output_idx, output_cnt)); |
|
|
|
} else if (input->type() == DataType::DE_FLOAT32) { |
|
|
|
} |
|
|
|
#endif |
|
|
|
else if (input->type() == DataType::DE_FLOAT32) { // NOLINT |
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<float>(input, output, output_idx, output_cnt)); |
|
|
|
} else if (input->type() == DataType::DE_FLOAT64) { |
|
|
|
RETURN_IF_NOT_OK(UniqueHelper<double>(input, output, output_idx, output_cnt)); |
|
|
|
|