| @@ -154,8 +154,8 @@ PYBIND_REGISTER( | |||||
| PYBIND_REGISTER(LookupOperation, 1, ([](const py::module *m) { | PYBIND_REGISTER(LookupOperation, 1, ([](const py::module *m) { | ||||
| (void)py::class_<text::LookupOperation, TensorOperation, std::shared_ptr<text::LookupOperation>>( | (void)py::class_<text::LookupOperation, TensorOperation, std::shared_ptr<text::LookupOperation>>( | ||||
| *m, "LookupOperation") | *m, "LookupOperation") | ||||
| .def(py::init([](const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||||
| const std::string &data_type) { | |||||
| .def(py::init([](const std::shared_ptr<Vocab> &vocab, | |||||
| const std::optional<std::string> &unknown_token, const std::string &data_type) { | |||||
| auto lookup = std::make_shared<text::LookupOperation>(vocab, unknown_token, data_type); | auto lookup = std::make_shared<text::LookupOperation>(vocab, unknown_token, data_type); | ||||
| THROW_IF_ERROR(lookup->ValidateParams()); | THROW_IF_ERROR(lookup->ValidateParams()); | ||||
| return lookup; | return lookup; | ||||
| @@ -87,8 +87,8 @@ std::shared_ptr<JiebaTokenizerOperation> JiebaTokenizer(const std::string &hmm_p | |||||
| return op->ValidateParams() ? op : nullptr; | return op->ValidateParams() ? op : nullptr; | ||||
| } | } | ||||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||||
| const std::string &data_type) { | |||||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, | |||||
| const std::optional<std::string> &unknown_token, const std::string &data_type) { | |||||
| auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type); | auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type); | ||||
| return op->ValidateParams() ? op : nullptr; | return op->ValidateParams() ? op : nullptr; | ||||
| @@ -340,7 +340,7 @@ Status JiebaTokenizerOperation::AddWord(const std::string &word, int64_t freq) { | |||||
| } | } | ||||
| // LookupOperation | // LookupOperation | ||||
| LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||||
| LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token, | |||||
| const std::string &data_type) | const std::string &data_type) | ||||
| : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {} | : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {} | ||||
| @@ -352,10 +352,10 @@ Status LookupOperation::ValidateParams() { | |||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| if (!unknown_token_.empty()) { | |||||
| default_id_ = vocab_->Lookup(unknown_token_); | |||||
| if (unknown_token_ != std::nullopt) { | |||||
| default_id_ = vocab_->Lookup(*unknown_token_); | |||||
| if (default_id_ == Vocab::kNoTokenExists) { | if (default_id_ == Vocab::kNoTokenExists) { | ||||
| std::string err_msg = "Lookup: \"" + unknown_token_ + "\" doesn't exist in vocab."; | |||||
| std::string err_msg = "Lookup: \"" + *unknown_token_ + "\" doesn't exist in vocab."; | |||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TEXT_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TEXT_H_ | ||||
| #include <memory> | #include <memory> | ||||
| #include <optional> | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -143,11 +144,13 @@ std::shared_ptr<JiebaTokenizerOperation> JiebaTokenizer(const std::string &hmm_p | |||||
| /// \brief Lookup operator that looks up a word to an id. | /// \brief Lookup operator that looks up a word to an id. | ||||
| /// \param[in] vocab a Vocab object. | /// \param[in] vocab a Vocab object. | ||||
| /// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov). | /// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov). | ||||
| /// If unknown_token is oov, runtime error will be thrown. | |||||
| /// If unknown_token is oov, a runtime error will be thrown. If unknown_token is {}, no | |||||
| /// unknown_token is used, and looking up an out-of-Vocabulary word raises a runtime error (default={}). | |||||
| /// \param[in] data_type type of the tensor after lookup, typically int32. | /// \param[in] data_type type of the tensor after lookup, typically int32. | ||||
| /// \return Shared pointer to the current TensorOperation. | /// \return Shared pointer to the current TensorOperation. | ||||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, | |||||
| const std::optional<std::string> &unknown_token = {}, | |||||
| const std::string &data_type = "int32"); | const std::string &data_type = "int32"); | ||||
| /// \brief TensorOp to generate n-gram from a 1-D string Tensor. | /// \brief TensorOp to generate n-gram from a 1-D string Tensor. | ||||
| @@ -343,7 +346,7 @@ class JiebaTokenizerOperation : public TensorOperation { | |||||
| class LookupOperation : public TensorOperation { | class LookupOperation : public TensorOperation { | ||||
| public: | public: | ||||
| explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||||
| explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token, | |||||
| const std::string &data_type); | const std::string &data_type); | ||||
| ~LookupOperation(); | ~LookupOperation(); | ||||
| @@ -356,7 +359,7 @@ class LookupOperation : public TensorOperation { | |||||
| private: | private: | ||||
| std::shared_ptr<Vocab> vocab_; | std::shared_ptr<Vocab> vocab_; | ||||
| std::string unknown_token_; | |||||
| std::optional<std::string> unknown_token_; | |||||
| int32_t default_id_; | int32_t default_id_; | ||||
| std::string data_type_; | std::string data_type_; | ||||
| }; | }; | ||||
| @@ -295,7 +295,7 @@ class Lookup(TextTensorOperation): | |||||
| @check_lookup | @check_lookup | ||||
| def __init__(self, vocab, unknown_token=None, data_type=mstype.int32): | def __init__(self, vocab, unknown_token=None, data_type=mstype.int32): | ||||
| self.vocab = vocab | self.vocab = vocab | ||||
| self.unknown_token = replace_none(unknown_token, '') | |||||
| self.unknown_token = unknown_token | |||||
| self.data_type = data_type | self.data_type = data_type | ||||
| def parse(self): | def parse(self): | ||||
| @@ -119,6 +119,31 @@ def test_from_list(): | |||||
| assert "is not of type" in test_config("w1", ["w1", "w2"], ["s1"], True, 123) | assert "is not of type" in test_config("w1", ["w1", "w2"], ["s1"], True, 123) | ||||
| def test_from_list_lookup_empty_string(): | |||||
| # "" is a valid word in vocab, which can be looked up by LookupOp | |||||
| vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True) | |||||
| lookup = text.Lookup(vocab, "") | |||||
| data = ds.TextFileDataset(DATA_FILE, shuffle=False) | |||||
| data = data.map(operations=lookup, input_columns=["text"]) | |||||
| ind = 0 | |||||
| res = [2, 1, 4, 5, 6, 7] | |||||
| for d in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| assert d["text"] == res[ind], ind | |||||
| ind += 1 | |||||
| # unknown_token of Lookup is None; it will be converted to std::nullopt in C++, | |||||
| # so it has nothing to do with "" in vocab, and C++ will skip looking up unknown_token | |||||
| vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True) | |||||
| lookup = text.Lookup(vocab) | |||||
| data = ds.TextFileDataset(DATA_FILE, shuffle=False) | |||||
| data = data.map(operations=lookup, input_columns=["text"]) | |||||
| try: | |||||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||||
| pass | |||||
| except RuntimeError as e: | |||||
| assert "token: \"is\" doesn't exist in vocab and no unknown token is specified" in str(e) | |||||
| def test_from_file(): | def test_from_file(): | ||||
| def gen(texts): | def gen(texts): | ||||
| for word in texts.split(" "): | for word in texts.split(" "): | ||||