ci add test case address some of the review cmts address review cmtstags/v1.0.0
| @@ -121,12 +121,13 @@ PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) { | |||
| PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) { | |||
| (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp") | |||
| .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) { | |||
| .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word, | |||
| const DataType &data_type) { | |||
| if (vocab == nullptr) { | |||
| THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null.")); | |||
| } | |||
| if (py_word.is_none()) { | |||
| return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists); | |||
| return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists, data_type); | |||
| } | |||
| std::string word = py::reinterpret_borrow<py::str>(py_word); | |||
| WordIdType default_id = vocab->Lookup(word); | |||
| @@ -134,7 +135,7 @@ PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) { | |||
| THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, | |||
| "default unknown token: " + word + " doesn't exist in vocab.")); | |||
| } | |||
| return std::make_shared<LookupOp>(vocab, default_id); | |||
| return std::make_shared<LookupOp>(vocab, default_id, data_type); | |||
| })); | |||
| })); | |||
| @@ -22,8 +22,9 @@ namespace dataset { | |||
| namespace api { | |||
| namespace text { | |||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token) { | |||
| auto op = std::make_shared<LookupOperation>(vocab, unknown_token); | |||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| const DataType &data_type) { | |||
| auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type); | |||
| if (!op->ValidateParams()) { | |||
| return nullptr; | |||
| @@ -32,8 +33,9 @@ std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, con | |||
| } | |||
| // LookupOperation | |||
| LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token) | |||
| : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists) {} | |||
| LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| const DataType &data_type) | |||
| : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {} | |||
| bool LookupOperation::ValidateParams() { | |||
| if (vocab_ == nullptr) { | |||
| @@ -54,7 +56,7 @@ bool LookupOperation::ValidateParams() { | |||
| } | |||
| std::shared_ptr<TensorOp> LookupOperation::Build() { | |||
| std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_); | |||
| std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_, data_type_); | |||
| return tensor_op; | |||
| } | |||
| @@ -20,9 +20,11 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/include/transforms.h" | |||
| #include "minddata/dataset/text/vocab.h" | |||
| #include "mindspore/ccsrc/minddata/dataset/core/data_type.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -37,15 +39,18 @@ class LookupOperation; | |||
| /// \brief Lookup operator that looks up a word to an id. | |||
| /// \param[in] vocab a Vocab object. | |||
| /// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov). | |||
| /// If unknown_token is oov, runtime error will be thrown | |||
| /// If unknown_token is oov, runtime error will be thrown. | |||
| /// \param[in] DataType type of the tensor after lookup, typically int32. | |||
| /// \return Shared pointer to the current TensorOperation. | |||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token); | |||
| std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| const mindspore::dataset::DataType &data_type = DataType("int32")); | |||
| /* ####################################### Derived TensorOperation classes ################################# */ | |||
| class LookupOperation : public TensorOperation { | |||
| public: | |||
| explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token); | |||
| explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token, | |||
| const DataType &data_type); | |||
| ~LookupOperation() = default; | |||
| @@ -57,6 +62,7 @@ class LookupOperation : public TensorOperation { | |||
| std::shared_ptr<Vocab> vocab_; | |||
| std::string unknown_token_; | |||
| int32_t default_id_; | |||
| DataType data_type_; | |||
| }; | |||
| } // namespace text | |||
| } // namespace api | |||
| @@ -13,15 +13,16 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/text/kernels/lookup_op.h" | |||
| #include <string> | |||
| #include "minddata/dataset/kernels/data/data_utils.h" | |||
| #include "minddata/dataset/text/kernels/lookup_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id) | |||
| : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {} | |||
| LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type) | |||
| : vocab_(vocab), default_id_(default_id), type_(data_type) {} | |||
| Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| @@ -37,6 +38,14 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T | |||
| "Lookup Error: token: " + std::string(*itr) + " doesn't exist in vocab and no unknown token is specified."); | |||
| } | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromVector(word_ids, input->shape(), output)); | |||
| // type cast to user's requirements if what user wants isn't int32_t | |||
| if ((*output)->type() != type_) { | |||
| std::shared_ptr<Tensor> cast_to; | |||
| RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_)); | |||
| *output = cast_to; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { | |||
| @@ -18,9 +18,9 @@ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| @@ -31,26 +31,27 @@ namespace mindspore { | |||
| namespace dataset { | |||
| class LookupOp : public TensorOp { | |||
| public: | |||
| // constructor for lookup, takes in a vocab object | |||
| // @param std::shared_ptr<Vocab> vocab - | |||
| // @param WordIdType default_id, id to lookup if a word is not in vocab | |||
| explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1); | |||
| /// \brief constructor for lookup, takes in a vocab object. | |||
| /// \param[in] std::shared_ptr<Vocab> vocab - vocab used for lookup. | |||
| /// \param[in] WordIdType default_id, id to lookup if a word is not in vocab. | |||
| /// \param[in] DataType type of the tensor after lookup, mostly int32. | |||
| explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id, const DataType &data_type); | |||
| ~LookupOp() = default; | |||
| // perform actual lookup on each tensor | |||
| // @param const std::shared_ptr<Tensor> &input | |||
| // @param std::shared_ptr<Tensor> *output | |||
| // @return error code | |||
| /// \brief perform actual lookup on each tensor. | |||
| /// \param[in] const std::shared_ptr<Tensor> &input | |||
| /// \param[in] std::shared_ptr<Tensor> *output | |||
| /// \return[out] error code. | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| // print method | |||
| // @param std::ostream out | |||
| /// \brief print method. | |||
| /// \param[in] std::ostream out | |||
| void Print(std::ostream &out) const override; | |||
| // @param std::vector<DataType> &inputs - | |||
| // @param std::vector<DataType> &outputs - | |||
| // @return error code | |||
| /// \param[in] std::vector<DataType> &inputs - | |||
| /// \param[in] std::vector<DataType> &outputs - | |||
| /// \return[out] error code. | |||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||
| std::string Name() const override { return kLookupOp; } | |||
| @@ -49,6 +49,7 @@ import platform | |||
| import numpy as np | |||
| import mindspore._c_dataengine as cde | |||
| import mindspore.common.dtype as mstype | |||
| from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType | |||
| from .validators import check_lookup, check_jieba_add_dict, \ | |||
| @@ -66,11 +67,12 @@ class Lookup(cde.LookupOp): | |||
| vocab(Vocab): a Vocab object. | |||
| unknown_token(str, optional): word to use for lookup if the word being looked up is out of Vocabulary (oov). | |||
| If unknown_token is oov, runtime error will be thrown (default=None). | |||
| data_type (mindspore.dtype, optional): mindspore.dtype lookup maps string to (default=mstype.int32) | |||
| """ | |||
| @check_lookup | |||
| def __init__(self, vocab, unknown_token=None): | |||
| super().__init__(vocab, unknown_token) | |||
| def __init__(self, vocab, unknown_token=None, data_type=mstype.int32): | |||
| super().__init__(vocab, unknown_token, mstype_to_detype(data_type)) | |||
| class SlidingWindow(cde.SlidingWindowOp): | |||
| @@ -103,7 +105,6 @@ class SlidingWindow(cde.SlidingWindowOp): | |||
| super().__init__(width, axis) | |||
| class Ngram(cde.NgramOp): | |||
| """ | |||
| TensorOp to generate n-gram from a 1-D string Tensor. | |||
| @@ -44,12 +44,13 @@ def check_lookup(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [vocab, unknown_token], _ = parse_user_args(method, *args, **kwargs) | |||
| [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs) | |||
| if unknown_token is not None: | |||
| type_check(unknown_token, (str,), "unknown_token") | |||
| type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.") | |||
| type_check(data_type, (typing.Type,), "data_type") | |||
| return method(self, *args, **kwargs) | |||
| @@ -327,6 +328,7 @@ def check_from_dataset(method): | |||
| return new_method | |||
| def check_slidingwindow(method): | |||
| """A wrapper that wraps a parameter checker to the original function(sliding window operation).""" | |||
| @@ -339,6 +341,7 @@ def check_slidingwindow(method): | |||
| return new_method | |||
| def check_ngram(method): | |||
| """A wrapper that wraps a parameter checker to the original function.""" | |||
| @@ -26,9 +26,10 @@ | |||
| #include "minddata/dataset/include/text.h" | |||
| using namespace mindspore::dataset::api; | |||
| using mindspore::dataset::DataType; | |||
| using mindspore::dataset::ShuffleMode; | |||
| using mindspore::dataset::Tensor; | |||
| using mindspore::dataset::Status; | |||
| using mindspore::dataset::Tensor; | |||
| using mindspore::dataset::Vocab; | |||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | |||
| @@ -50,7 +51,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOp) { | |||
| EXPECT_EQ(s, Status::OK()); | |||
| // Create Lookup operation on ds | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>"); | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32")); | |||
| EXPECT_NE(lookup, nullptr); | |||
| // Create Map operation on ds | |||
| @@ -94,7 +95,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail1) { | |||
| // Create lookup op for ds | |||
| // Expected failure: "<unk>" is not a word of vocab | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>"); | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32")); | |||
| EXPECT_EQ(lookup, nullptr); | |||
| } | |||
| @@ -105,7 +106,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail2) { | |||
| // Create lookup op | |||
| // Expected failure: vocab is null | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, ""); | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "", DataType("int32")); | |||
| EXPECT_EQ(lookup, nullptr); | |||
| } | |||
| @@ -126,7 +127,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpWithEmptyUnknownToken) { | |||
| // Create Lookup operation on ds | |||
| // Expected failure: "" is not a word of vocab | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, ""); | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "", DataType("int32")); | |||
| EXPECT_EQ(lookup, nullptr); | |||
| } | |||
| @@ -148,7 +149,7 @@ TEST_F(MindDataTestPipeline, TestVocabFromDataset) { | |||
| EXPECT_EQ(home_index, 4); | |||
| // Create Lookup operation on ds | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>"); | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "<unk>", DataType("int32")); | |||
| EXPECT_NE(lookup, nullptr); | |||
| // Create Map operation on ds | |||
| @@ -212,12 +213,15 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetDefault) { | |||
| uint64_t i = 0; | |||
| std::vector<int32_t> expected = {2, 3, 1, 4, 5, 0}; | |||
| std::vector<int64_t> not_expected = {2, 3, 1, 4, 5, 0}; | |||
| while (row.size() != 0) { | |||
| auto ind = row["text"]; | |||
| MS_LOG(INFO) << ind->shape() << " " << *ind; | |||
| std::shared_ptr<Tensor> expected_item; | |||
| std::shared_ptr<Tensor> expected_item, not_expected_item; | |||
| Tensor::CreateScalar(expected[i], &expected_item); | |||
| Tensor::CreateScalar(not_expected[i], ¬_expected_item); | |||
| EXPECT_EQ(*ind, *expected_item); | |||
| EXPECT_NE(*ind, *not_expected_item); | |||
| iter->GetNextRow(&row); | |||
| i++; | |||
| } | |||
| @@ -233,8 +237,8 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail1) { | |||
| // Create vocab from dataset | |||
| // Expected failure: top_k can not be negative | |||
| std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, | |||
| -2, {"<pad>", "<unk>"}, true); | |||
| std::shared_ptr<Vocab> vocab = | |||
| ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, -2, {"<pad>", "<unk>"}, true); | |||
| EXPECT_EQ(vocab, nullptr); | |||
| } | |||
| @@ -247,9 +251,9 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail2) { | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create vocab from dataset | |||
| // Expected failure: requency_range [a,b] should be 0 <= a <= b | |||
| std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {4, 1}, | |||
| std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"}, true); | |||
| // Expected failure: frequency_range [a,b] should be 0 <= a <= b | |||
| std::shared_ptr<Vocab> vocab = | |||
| ds->BuildVocab({"text"}, {4, 1}, std::numeric_limits<int64_t>::max(), {"<pad>", "<unk>"}, true); | |||
| EXPECT_EQ(vocab, nullptr); | |||
| } | |||
| @@ -266,3 +270,52 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) { | |||
| std::shared_ptr<Vocab> vocab = ds->BuildVocab({"ColumnNotExist"}); | |||
| EXPECT_EQ(vocab, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64."; | |||
| // Create a TextFile dataset | |||
| std::string data_file = datasets_root_path_ + "/testVocab/words.txt"; | |||
| std::shared_ptr<Dataset> ds = TextFile({data_file}, 0, ShuffleMode::kFalse); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create vocab from dataset | |||
| std::shared_ptr<Vocab> vocab = ds->BuildVocab(); | |||
| EXPECT_NE(vocab, nullptr); | |||
| // Check if vocab has words or not | |||
| int32_t home_index = vocab->Lookup("home"); | |||
| EXPECT_EQ(home_index, 2); | |||
| // Create Lookup operation on ds | |||
| std::shared_ptr<TensorOperation> lookup = text::Lookup(vocab, "home", DataType("int64")); | |||
| EXPECT_NE(lookup, nullptr); | |||
| // Create Map operation on ds | |||
| ds = ds->Map({lookup}); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| std::vector<int64_t> expected = {2, 3, 1, 4, 5, 0}; | |||
| std::vector<int8_t> not_expected = {2, 3, 1, 4, 5, 0}; | |||
| while (row.size() != 0) { | |||
| auto ind = row["text"]; | |||
| MS_LOG(INFO) << ind->shape() << " " << *ind; | |||
| std::shared_ptr<Tensor> expected_item, not_expected_item; | |||
| Tensor::CreateScalar(expected[i], &expected_item); | |||
| Tensor::CreateScalar(not_expected[i], ¬_expected_item); | |||
| EXPECT_EQ(*ind, *expected_item); | |||
| EXPECT_NE(*ind, *not_expected_item); | |||
| iter->GetNextRow(&row); | |||
| i++; | |||
| } | |||
| } | |||
| @@ -17,6 +17,7 @@ import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.text as text | |||
| import mindspore.common.dtype as mstype | |||
| # this file contains "home is behind the world head" each word is 1 line | |||
| DATA_FILE = "../data/dataset/testVocab/words.txt" | |||
| @@ -137,6 +138,36 @@ def test_from_file(): | |||
| assert "Input vocab_size must be greater than 0" in test_config("w1 w2", 0, [], True) | |||
| assert "Input vocab_size must be greater than 0" in test_config("w1 w2", -1, [], True) | |||
| def test_lookup_cast_type(): | |||
| def gen(texts): | |||
| for word in texts.split(" "): | |||
| yield (np.array(word, dtype='S'),) | |||
| def test_config(lookup_str, data_type=None): | |||
| try: | |||
| vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=["<unk>"], special_first=True) | |||
| data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"]) | |||
| # if data_type is None, test the default value of data_type | |||
| op = text.Lookup(vocab, "<unk>") if data_type is None else text.Lookup(vocab, "<unk>", data_type) | |||
| data = data.map(input_columns=["text"], operations=op) | |||
| res = [] | |||
| for d in data.create_dict_iterator(num_epochs=1): | |||
| res.append(d["text"]) | |||
| return res[0].dtype | |||
| except (ValueError, RuntimeError, TypeError) as e: | |||
| return str(e) | |||
| # test result is correct | |||
| assert test_config("w1", mstype.int8) == np.dtype("int8") | |||
| assert test_config("w2", mstype.int32) == np.dtype("int32") | |||
| assert test_config("w3", mstype.int64) == np.dtype("int64") | |||
| assert test_config("unk", mstype.float32) != np.dtype("int32") | |||
| assert test_config("unk") == np.dtype("int32") | |||
| # test exception, data_type isn't the correct type | |||
| assert "tldr is not of type (<class 'mindspore._c_expression.typing.Type'>,)" in test_config("unk", "tldr") | |||
| if __name__ == '__main__': | |||
| test_from_dict_exception() | |||
| test_from_list_tutorial() | |||
| @@ -144,3 +175,4 @@ if __name__ == '__main__': | |||
| test_from_dict_tutorial() | |||
| test_from_list() | |||
| test_from_file() | |||
| test_lookup_cast_type() | |||