From 91a6b2b0cafff27db64972d5e39c91cdd4ac2b57 Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Wed, 9 Dec 2020 17:07:55 -0500 Subject: [PATCH] Add checks to RandomData and LookUp --- mindspore/ccsrc/minddata/dataset/api/text.cc | 6 ++++++ mindspore/ccsrc/minddata/dataset/core/config_manager.h | 2 +- .../dataset/engine/datasetops/source/random_data_op.cc | 4 ++++ .../dataset/engine/ir/datasetops/source/random_node.cc | 5 +++++ .../minddata/dataset/engine/opt/post/auto_worker_pass.cc | 2 +- .../ccsrc/minddata/dataset/text/kernels/lookup_op.cc | 3 +++ mindspore/dataset/text/transforms.py | 1 + tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc | 8 +++++++- tests/ut/python/dataset/test_vocab.py | 1 + 9 files changed, 29 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/text.cc b/mindspore/ccsrc/minddata/dataset/api/text.cc index c3608b4c06..2e3abe4e48 100644 --- a/mindspore/ccsrc/minddata/dataset/api/text.cc +++ b/mindspore/ccsrc/minddata/dataset/api/text.cc @@ -320,6 +320,12 @@ Status LookupOperation::ValidateParams() { RETURN_STATUS_SYNTAX_ERROR(err_msg); } + if (!data_type_.IsNumeric()) { + std::string err_msg = "Lookup does not support a string to string mapping, data_type can only be numeric."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.h b/mindspore/ccsrc/minddata/dataset/core/config_manager.h index e3bb2f8edf..609cee8da0 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.h +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.h @@ -183,7 +183,7 @@ class ConfigManager { // E.g. 0 would corresponds to a 1:1:1 ratio of num_worker among leaf batch and map. // please refer to AutoWorkerPass for detail on what each option is. 
// @return The experimental config used by AutoNumWorker, each 1 refers to a different setup configuration - uint8_t get_auto_worker_config_() { return auto_worker_config_; } + uint8_t get_auto_worker_config() { return auto_worker_config_; } // setter function // E.g. set the value of 0 would corresponds to a 1:1:1 ratio of num_worker among leaf batch and map. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index d7801a6242..b3d8e15c56 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -147,6 +147,10 @@ void RandomDataOp::GenerateSchema() { // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will // provide the master loop that drives the logic for performing the work. Status RandomDataOp::operator()() { + CHECK_FAIL_RETURN_UNEXPECTED(total_rows_ >= num_workers_, + "RandomDataOp expects total_rows >= num_workers. total_rows=" + + std::to_string(total_rows_) + ", num_workers=" + std::to_string(num_workers_) + " ."); + // First, compute how many buffers we'll need to satisfy the total row count. // The only reason we do this is for the purpose of throttling worker count if needed. 
int64_t buffers_needed = total_rows_ / rows_per_buffer_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index 89813243e2..c63f7bcc3c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -52,6 +52,11 @@ Status RandomNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomNode", "columns_list", columns_list_)); } + // allow total_rows == 0 for now because RandomOp would generate a random row when it gets a 0 + CHECK_FAIL_RETURN_UNEXPECTED(total_rows_ == 0 || total_rows_ >= num_workers_, + "RandomNode needs total_rows >= num_workers. total_rows=" + std::to_string(total_rows_) + + ", num_workers=" + std::to_string(num_workers_) + "."); + return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc index 36699953c3..5a97faee8f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc @@ -27,7 +27,7 @@ namespace dataset { // this will become the RootNode:DatasetNode when it is turned on Status AutoWorkerPass::RunOnTree(std::shared_ptr root_ir, bool *modified) { - uint8_t config = GlobalContext::config_manager()->get_auto_worker_config_(); + uint8_t config = GlobalContext::config_manager()->get_auto_worker_config(); OpWeightPass pass(kOpWeightConfigs[config < kOpWeightConfigs.size() ? 
config : 0]); diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc index 3d19b991aa..aeaeb84e6b 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc @@ -28,6 +28,7 @@ Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptrtype() == DataType::DE_STRING, "None string tensor received."); + std::vector word_ids; word_ids.reserve(input->Size()); for (auto itr = input->begin(); itr != input->end(); itr++) { @@ -41,6 +42,8 @@ Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptrtype() != type_) { + CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), + "Lookup doesn't support string to string lookup. data_type needs to be numeric"); std::shared_ptr cast_to; RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_)); *output = cast_to; diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index df038f4ebf..f88d4d3b2f 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -500,6 +500,7 @@ if platform.system().lower() != 'windows': NormalizeForm.NFKD: cde.NormalizeForm.DE_NORMALIZE_NFKD } + class NormalizeUTF8(cde.NormalizeUTF8Op): """ Apply normalize operation on UTF-8 string tensor. diff --git a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc index 8167968ba1..5ff421fb18 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc @@ -100,7 +100,6 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasicWithPipeline) { ds1 = ds1->Concat({ds2}); EXPECT_NE(ds1, nullptr); - // Create an iterator over the result of the above dataset // This will trigger the creation of the Execution Tree and launch it. 
std::shared_ptr iter = ds1->CreateIterator(); @@ -474,3 +473,10 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) { // Expect failure: duplicate column names EXPECT_EQ(ds->CreateIterator(), nullptr); } + +TEST_F(MindDataTestPipeline, TestRandomDatasetFail) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetFail."; + // this will fail because num_workers is greater than num_rows + std::shared_ptr ds = RandomData(3)->SetNumWorkers(5); + EXPECT_EQ(ds->CreateIterator(), nullptr); +} diff --git a/tests/ut/python/dataset/test_vocab.py b/tests/ut/python/dataset/test_vocab.py index 9241b759bd..c8a46db9c4 100644 --- a/tests/ut/python/dataset/test_vocab.py +++ b/tests/ut/python/dataset/test_vocab.py @@ -166,6 +166,7 @@ def test_lookup_cast_type(): assert test_config("unk") == np.dtype("int32") # test exception, data_type isn't the correct type assert "tldr is not of type (,)" in test_config("unk", "tldr") + assert "Lookup doesn't support string to string lookup" in test_config("w1", mstype.string) if __name__ == '__main__':