|
|
@@ -27,17 +27,34 @@ namespace dataset { |
|
|
// Constructs the op from an in-memory vocab and loads the SentencePiece model once, up front.
//
// vocab:     holder of the serialized model proto handed to the processor.
// load_type: stored in load_type_.
// out_type:  stored in out_type_.
//
// A constructor cannot return a Status, so the outcome of the load is recorded in model_status_;
// Compute() checks it and returns it when the load did not succeed.
SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab,
                                                   const SPieceTokenizerLoadType load_type,
                                                   const SPieceTokenizerOutType out_type)
    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {
  // Guard the dereference below: a null vocab must surface as an error status, not a crash.
  if (vocab_ == nullptr) {
    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "vocab is null.");
    return;
  }
  auto status = processor_.LoadFromSerializedProto(vocab_->model_proto());
  if (!status.ok()) {
    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "parse vocab model failed.");
  } else {
    model_status_ = Status::OK();
  }
}
|
|
|
|
|
|
|
|
// Constructs the op from a model file on disk and loads the SentencePiece model once, up front.
//
// model_path / model_filename: passed to GetModelRealPath() before the load.
// load_type: stored in load_type_.
// out_type:  stored in out_type_.
//
// A constructor cannot return a Status, so the outcome of the load is recorded in model_status_;
// Compute() checks it and returns it when the load did not succeed.
SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
                                                   const SPieceTokenizerLoadType load_type,
                                                   const SPieceTokenizerOutType out_type)
    : load_type_(load_type), out_type_(out_type) {
  // NOTE(review): the result of GetModelRealPath() is deliberately discarded here; presumably it
  // populates file_path_ and a resolution failure then shows up as a failed Load() below -- confirm.
  (void)GetModelRealPath(model_path, model_filename);
  auto status = processor_.Load(file_path_);
  if (!status.ok()) {
    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "load vocab model failed.");
  } else {
    model_status_ = Status::OK();
  }
}
|
|
|
|
|
|
|
|
Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { |
|
|
Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { |
|
|
IO_CHECK(input, output); |
|
|
IO_CHECK(input, output); |
|
|
|
|
|
if (!model_status_.IsOk()) { |
|
|
|
|
|
return model_status_; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { |
|
|
if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { |
|
|
RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor"); |
|
|
RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor"); |
|
|
} |
|
|
} |
|
|
@@ -45,18 +62,6 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, s |
|
|
std::string_view sentence_v; |
|
|
std::string_view sentence_v; |
|
|
RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {})); |
|
|
RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {})); |
|
|
std::string sentence{sentence_v}; |
|
|
std::string sentence{sentence_v}; |
|
|
if (load_type_ == SPieceTokenizerLoadType::kFile) { |
|
|
|
|
|
auto status = processor_.Load(file_path_); |
|
|
|
|
|
if (!status.ok()) { |
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("load sentence piece model failed."); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
RETURN_UNEXPECTED_IF_NULL(vocab_); |
|
|
|
|
|
auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto()); |
|
|
|
|
|
if (!status.ok()) { |
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("sentence piece load model failed."); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (out_type_ == SPieceTokenizerOutType::kString) { |
|
|
if (out_type_ == SPieceTokenizerOutType::kString) { |
|
|
std::vector<std::string> pieces; |
|
|
std::vector<std::string> pieces; |
|
|
|