| @@ -154,8 +154,8 @@ PYBIND_REGISTER( | |||
| PYBIND_REGISTER(LookupOperation, 1, ([](const py::module *m) { | |||
| (void)py::class_<text::LookupOperation, TensorOperation, std::shared_ptr<text::LookupOperation>>( | |||
| *m, "LookupOperation") | |||
| .def(py::init([](const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| const std::string &data_type) { | |||
| .def(py::init([](const std::shared_ptr<Vocab> &vocab, | |||
| const std::optional<std::string> &unknown_token, const std::string &data_type) { | |||
| auto lookup = std::make_shared<text::LookupOperation>(vocab, unknown_token, data_type); | |||
| THROW_IF_ERROR(lookup->ValidateParams()); | |||
| return lookup; | |||
| @@ -87,8 +87,8 @@ std::shared_ptr<JiebaTokenizerOperation> JiebaTokenizer(const std::string &hmm_p | |||
| return op->ValidateParams() ? op : nullptr; | |||
| } | |||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| const std::string &data_type) { | |||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, | |||
| const std::optional<std::string> &unknown_token, const std::string &data_type) { | |||
| auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type); | |||
| return op->ValidateParams() ? op : nullptr; | |||
| @@ -340,7 +340,7 @@ Status JiebaTokenizerOperation::AddWord(const std::string &word, int64_t freq) { | |||
| } | |||
| // LookupOperation | |||
| LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token, | |||
| const std::string &data_type) | |||
| : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {} | |||
| @@ -352,10 +352,10 @@ Status LookupOperation::ValidateParams() { | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (!unknown_token_.empty()) { | |||
| default_id_ = vocab_->Lookup(unknown_token_); | |||
| if (unknown_token_ != std::nullopt) { | |||
| default_id_ = vocab_->Lookup(*unknown_token_); | |||
| if (default_id_ == Vocab::kNoTokenExists) { | |||
| std::string err_msg = "Lookup: \"" + unknown_token_ + "\" doesn't exist in vocab."; | |||
| std::string err_msg = "Lookup: \"" + *unknown_token_ + "\" doesn't exist in vocab."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TEXT_H_ | |||
| #include <memory> | |||
| #include <optional> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| @@ -143,11 +144,13 @@ std::shared_ptr<JiebaTokenizerOperation> JiebaTokenizer(const std::string &hmm_p | |||
| /// \brief Lookup operator that looks up a word to an id. | |||
| /// \param[in] vocab a Vocab object. | |||
| /// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov). | |||
| /// If unknown_token is oov, runtime error will be thrown. | |||
| /// If unknown_token is oov, runtime error will be thrown. If unknown_token is {}, no unknown_token will | |||
| /// be specified, and looking up an out-of-vocabulary word will raise a runtime error (default={}). | |||
| /// \param[in] data_type type of the tensor after lookup, typically int32. | |||
| /// \return Shared pointer to the current TensorOperation. | |||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, | |||
| const std::optional<std::string> &unknown_token = {}, | |||
| const std::string &data_type = "int32"); | |||
| /// \brief TensorOp to generate n-gram from a 1-D string Tensor. | |||
| @@ -343,7 +346,7 @@ class JiebaTokenizerOperation : public TensorOperation { | |||
| class LookupOperation : public TensorOperation { | |||
| public: | |||
| explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token, | |||
| const std::string &data_type); | |||
| ~LookupOperation(); | |||
| @@ -356,7 +359,7 @@ class LookupOperation : public TensorOperation { | |||
| private: | |||
| std::shared_ptr<Vocab> vocab_; | |||
| std::string unknown_token_; | |||
| std::optional<std::string> unknown_token_; | |||
| int32_t default_id_; | |||
| std::string data_type_; | |||
| }; | |||
| @@ -295,7 +295,7 @@ class Lookup(TextTensorOperation): | |||
| @check_lookup | |||
| def __init__(self, vocab, unknown_token=None, data_type=mstype.int32): | |||
| self.vocab = vocab | |||
| self.unknown_token = replace_none(unknown_token, '') | |||
| self.unknown_token = unknown_token | |||
| self.data_type = data_type | |||
| def parse(self): | |||
| @@ -119,6 +119,31 @@ def test_from_list(): | |||
| assert "is not of type" in test_config("w1", ["w1", "w2"], ["s1"], True, 123) | |||
| def test_from_list_lookup_empty_string(): | |||
| # "" is a valid word in vocab, which can be looked up by LookupOp | |||
| vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True) | |||
| lookup = text.Lookup(vocab, "") | |||
| data = ds.TextFileDataset(DATA_FILE, shuffle=False) | |||
| data = data.map(operations=lookup, input_columns=["text"]) | |||
| ind = 0 | |||
| res = [2, 1, 4, 5, 6, 7] | |||
| for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| assert d["text"] == res[ind], ind | |||
| ind += 1 | |||
| # unknown_token of Lookup is None; it will be converted to std::nullopt in C++, | |||
| # so it has nothing to do with "" in vocab, and C++ will skip looking up unknown_token | |||
| vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True) | |||
| lookup = text.Lookup(vocab) | |||
| data = ds.TextFileDataset(DATA_FILE, shuffle=False) | |||
| data = data.map(operations=lookup, input_columns=["text"]) | |||
| try: | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| pass | |||
| except RuntimeError as e: | |||
| assert "token: \"is\" doesn't exist in vocab and no unknown token is specified" in str(e) | |||
| def test_from_file(): | |||
| def gen(texts): | |||
| for word in texts.split(" "): | |||