|
|
|
@@ -28,6 +28,7 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T |
|
|
|
IO_CHECK(input, output); |
|
|
|
RETURN_UNEXPECTED_IF_NULL(vocab_); |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None string tensor received."); |
|
|
|
|
|
|
|
std::vector<WordIdType> word_ids; |
|
|
|
word_ids.reserve(input->Size()); |
|
|
|
for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) { |
|
|
|
@@ -41,6 +42,8 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T |
|
|
|
|
|
|
|
// type cast to user's requirements if what user wants isn't int32_t |
|
|
|
if ((*output)->type() != type_) { |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), |
|
|
|
"Lookup doesn't support string to string lookup. data_type needs to be numeric"); |
|
|
|
std::shared_ptr<Tensor> cast_to; |
|
|
|
RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_)); |
|
|
|
*output = cast_to; |
|
|
|
|