|
|
|
@@ -26,7 +26,7 @@ LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id) |
|
|
|
Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { |
|
|
|
IO_CHECK(input, output); |
|
|
|
RETURN_UNEXPECTED_IF_NULL(vocab_); |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor."); |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None string tensor received."); |
|
|
|
std::vector<WordIdType> word_ids; |
|
|
|
word_ids.reserve(input->Size()); |
|
|
|
for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) { |
|
|
|
@@ -34,7 +34,7 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T |
|
|
|
word_ids.emplace_back(word_id == Vocab::kNoTokenExists ? default_id_ : word_id); |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED( |
|
|
|
word_ids.back() != Vocab::kNoTokenExists, |
|
|
|
"Lookup Error: token" + std::string(*itr) + "doesn't exist in vocab and no unknown token is specified."); |
|
|
|
"Lookup Error: token: " + std::string(*itr) + " doesn't exist in vocab and no unknown token is specified."); |
|
|
|
} |
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_, |
|
|
|
@@ -42,8 +42,8 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match"); |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type"); |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match."); |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type."); |
|
|
|
outputs[0] = type_; |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|