diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index f8f74c86c0..ed36277d93 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -179,6 +179,26 @@ Dataset::Dataset() { rows_per_buffer_ = cfg->rows_per_buffer(); connector_que_size_ = cfg->op_connector_size(); worker_connector_size_ = cfg->worker_connector_size(); + tree_getters_ = std::make_shared(); +} + +int64_t Dataset::GetDatasetSize() { + int64_t dataset_size; + auto ds = shared_from_this(); + Status rc; + std::unique_ptr runtime_context = std::make_unique(); + rc = runtime_context->Init(); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; + return -1; + } + rc = tree_getters_->Init(ds); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; + return -1; + } + rc = tree_getters_->GetDatasetSize(&dataset_size); + return rc.IsError() ? 
-1 : dataset_size; } // Constructor to initialize the cache diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 21b212e973..f27b4d5e41 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -351,4 +351,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape & } #endif +TreeGetters::TreeGetters() { + tree_adapter_ = std::make_unique(); + dataset_size_ = -1; +} + +Status TreeGetters::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); } + +Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ == -1) { + std::shared_ptr root = std::shared_ptr(tree_adapter_->GetRoot()); + CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); + RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); + dataset_size_ = *dataset_size; + TensorRow row; + if (*dataset_size == -1) { + int64_t num_rows = 0; + RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); + while (row.size() != 0) { + num_rows++; + RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); + } + dataset_size_ = num_rows; + } + } + + *dataset_size = dataset_size_; + return Status::OK(); +} } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index b1ec235a6a..9dd66f9bcb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -152,9 +152,10 @@ class ToDevice : public TreeConsumer { /// Consumer that is used to get some pipeline information class TreeGetters : public TreeConsumer { - Status GetDatasetSize(int32_t *size) { - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented 
yet."); - } + public: + TreeGetters(); + Status Init(std::shared_ptr d) override; + Status GetDatasetSize(int64_t *size); Status GetBatchSize(int32_t *batch_size) { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); } @@ -173,6 +174,11 @@ class TreeGetters : public TreeConsumer { Status GetOutputNames(std::vector *names) { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); } + + std::string Name() override { return "TreeGetters"; } + + private: + int64_t dataset_size_; }; } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index 93ef16c525..28b2e88932 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -531,5 +531,30 @@ Status BatchOp::ComputeColMap() { return Status::OK(); } +Status BatchOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } +#ifdef ENABLE_PYTHON + if (batch_size_func_) { + *dataset_size = -1; + return Status::OK(); + } +#endif + int64_t num_rows; + RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); + if (num_rows > 0 && start_batch_size_ > 0) { + if (drop_) { + num_rows = floor(num_rows / start_batch_size_); + } else { + num_rows = ceil(num_rows / start_batch_size_); + } + } + *dataset_size = num_rows; + dataset_size_ = num_rows; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h index 1b77856250..1e6a66efa8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h @@ -219,6 +219,11 @@ class BatchOp : public 
ParallelOp { static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, const std::unordered_map &column_name_id_map); + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + protected: Status ComputeColMap() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc index 5963619de9..a8e6793c0f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -231,5 +231,13 @@ Status BucketBatchByLengthOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status BucketBatchByLengthOp::GetDatasetSize(int64_t *dataset_size) { + // We are returning -1 because we can't easily calculate GetDatasetSize. 
Returning -1 will make TreeGetters to + // iterate over the dataset and count the size + *dataset_size = dataset_size_; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h index 3fd446322b..6ac3700bae 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -112,6 +112,11 @@ class BucketBatchByLengthOp : public PipelineOp { std::string Name() const override { return kBucketBatchByLengthOp; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + // << Stream output operator overload // @notes This allows you to write the debug print info using stream operators // @param out - reference to the output stream being overloaded diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc index 5233af631a..1e934e7259 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc @@ -195,5 +195,13 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->PreRunOnNode(shared_from_base(), modified); } + +// Get Dataset size +Status ConcatOp::GetDatasetSize(int64_t *dataset_size) { + // We are returning -1 because we can't easily calculate GetDatasetSize. 
Returning -1 will make TreeGetters to + // iterate over the dataset and count the size + *dataset_size = dataset_size_; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h index 7df5e0ae6f..7639aa18d5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h @@ -111,6 +111,11 @@ class ConcatOp : public PipelineOp { /// \return Status of the node visit Status PreAccept(NodePass *p, bool *modified) override; + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: Status Verify(int32_t id, const std::unique_ptr &buf); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 9e59236d01..e656f1724f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -50,7 +50,8 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler op_num_repeats_per_epoch_(kInfiniteRepeat), op_current_repeats_(0), op_current_epochs_(0), - out_connector_(nullptr) { + out_connector_(nullptr), + dataset_size_(-1) { // The operator starts out with an invalid operator id. The only way to // get it out of invalid state is to assign the operator to an execution tree. 
} @@ -290,6 +291,17 @@ Status DatasetOp::GetNextInput(std::unique_ptr *p_buffer, int32_t wo return Status::OK(); } +// Gets the dataset size +Status DatasetOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, "Can't get the dataset size for the current tree."); + + return child_[0]->GetDatasetSize(dataset_size); +} + // Performs handling for when an eoe message is received. // The base class implementation simply flows the eoe message to output. Derived classes // may override if they need to perform special eoe handling. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index 00e2200f4c..75b50c3ba6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -179,6 +179,10 @@ class DatasetOp : public std::enable_shared_from_this { /// \return Status - The error code return Status GetNextInput(std::unique_ptr *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); + /// \brief Gets the dataset size + /// \return Status - The status code return + virtual Status GetDatasetSize(int64_t *dataset_size); + /// \brief Performs handling for when an eoe message is received. /// The base class implementation simply flows the eoe message to output. Derived classes /// may override if they need to perform special eoe handling. @@ -406,6 +410,7 @@ class DatasetOp : public std::enable_shared_from_this { std::unordered_map column_name_id_map_; // Mapping between col index and col name std::mutex column_name_map_mutex_; // For protecting shared access to the column map CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp + int64_t dataset_size_; // Size of the dataset private: /// Sets the operator id. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc index d9eb064041..d4c762f693 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc @@ -278,5 +278,14 @@ Status FilterOp::PreAccept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->PreRunOnNode(shared_from_base(), modified); } + +// Get Dataset size +Status FilterOp::GetDatasetSize(int64_t *dataset_size) { + // We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to + // iterate over the dataset and count the size + *dataset_size = dataset_size_; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h index e41011bab9..f7805bc85e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h @@ -137,6 +137,11 @@ class FilterOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return kFilterOp; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // predicate_func python callable which returns a boolean value. 
py::function predicate_func_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc index 557698a396..e8fdf1d9a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -191,5 +191,21 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base(), modified); } + +// Get Dataset size +Status RepeatOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0 || num_repeats_ == -1) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); + if (num_rows > 0 && num_repeats_ > 0) { + num_rows = num_rows * num_repeats_; + } + *dataset_size = num_rows; + dataset_size_ = num_rows; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h index 35c3bfeea7..50dab4bac0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h @@ -133,6 +133,11 @@ class RepeatOp : public PipelineOp { /// \@return Status - The error code return Status Reset() override; + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes // \param[in] eoe_op The input leaf/eoe operator to add to the list void AddToEoeList(std::shared_ptr eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc 
b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc index c85af73e2c..63fc04776b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc @@ -134,5 +134,21 @@ Status SkipOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base(), modified); } + +// Get Dataset size +Status SkipOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); + *dataset_size = 0; + if (max_skips_ >= 0 && max_skips_ < num_rows) { + *dataset_size = num_rows - max_skips_; + } + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h index 657da1fe84..84663d60eb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h @@ -80,6 +80,11 @@ class SkipOp : public PipelineOp { // @return - Status of the node visit. 
// ==== mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h ====
// Addition inside class SkipOp (public section, after Accept()):
//
//   /// \brief Base-class override for GetDatasetSize
//   /// \param[out] dataset_size the size of the dataset
//   /// \return Status of the function
//   Status GetDatasetSize(int64_t *dataset_size) override;

// ==== mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc ====
// (The hunk also adds #include <algorithm> for std::min.)
// Get Dataset size: row count comes from the CelebA attr file header, optionally narrowed
// by the eval-partition file for "train"/"valid"/"test" usages, then capped by the sampler.
Status CelebAOp::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    // Consistency fix: every sibling source op caches its computed size; the patch re-read
    // both files on every call here.
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  std::string line;
  Path folder_path(folder_path_);
  std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
  if (!attr_file.is_open()) {
    std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
  }

  std::string rows_num;
  (void)getline(attr_file, rows_num);  // first line of the attr file is the row count
  try {
    num_rows = static_cast<int64_t>(std::stoul(rows_num));
  } catch (std::invalid_argument &e) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
  } catch (std::out_of_range &e) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
  }
  if (usage_ != "all") {
    int64_t partition_num = 0;
    char usage_type;
    if (usage_ == "train") {
      usage_type = '0';
    } else if (usage_ == "valid") {
      usage_type = '1';
    } else if (usage_ == "test") {
      usage_type = '2';
    } else {
      RETURN_STATUS_UNEXPECTED("Invalid usage.");
    }
    if (!partition_file_.is_open()) {
      partition_file_.open((folder_path / "list_eval_partition.txt").toString());
    }
    if (partition_file_.is_open()) {
      while (getline(partition_file_, line)) {
        // BUG FIX: the patch indexed line.at(find(' ') + 1) without checking for npos,
        // which throws std::out_of_range on a line with no space.
        size_t pos = line.find(' ');
        if (pos != std::string::npos && pos + 1 < line.size() && line.at(pos + 1) == usage_type) {
          partition_num++;
        }
      }
      // NOTE(review): partition_file_ is a member stream left at EOF after this loop; without
      // the cache above a second call would count 0 rows. Consider a local std::ifstream.
    } else {
      std::string partition_file_name = "list_eval_partition.txt";
      RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba partition file: " + partition_file_name);
    }
    num_rows = std::min(num_rows, partition_num);
  }

  sample_size = sampler_->GetNumSamples();
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}

// ==== mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h ====
// Addition inside class CelebAOp (public section, after Name()):
//
//   /// \brief Base-class override for GetDatasetSize
//   /// \param[out] dataset_size the size of the dataset
//   /// \return Status of the function
//   Status GetDatasetSize(int64_t *dataset_size) override;
b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -507,5 +507,21 @@ Status CifarOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status CifarOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + num_rows = num_rows_; + if (num_rows_ <= 0) + RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows)); + sample_size = sampler_->GetNumSamples(); + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h index 882b688fd5..a2e15dbfbd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h @@ -175,6 +175,11 @@ class CifarOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "CifarOp"; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index 02f58a8a92..9d3ad80c7a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -15,6 +15,7 @@ */ #include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include #include #include 
#include @@ -563,5 +564,20 @@ Status ClueOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base(), modified); } + +// Get Dataset size +Status ClueOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + sample_size = num_samples_; + num_rows = num_rows_per_shard_; + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h index b1edc7264f..d2825dc1b6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -193,6 +193,11 @@ class ClueOp : public ParallelOp { // @return - Status of the node visit. Status Accept(NodePass *p, bool *modified) override; + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index e1b875165f..69ac8a28d3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -679,5 +679,36 @@ Status CocoOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status CocoOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows = 0, sample_size; + std::string task_type; + switch (task_type_) { + case TaskType::Detection: + task_type = "Detection"; + break; + case TaskType::Keypoint: + task_type = "Keypoint"; + break; + case TaskType::Panoptic: + task_type = "Panoptic"; + break; + case TaskType::Stuff: + task_type = "Stuff"; + break; + } + if (image_ids_.size() == 0) { + RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows)); + } + sample_size = sampler_->GetNumSamples(); + *dataset_size = sample_size != 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h index 7026a382d0..534dab9649 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -209,6 +209,11 @@ class CocoOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "CocoOp"; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 82ff744693..2e7b59b7ff 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -15,6 +15,7 @@ */ #include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include #include #include #include @@ -914,5 +915,20 @@ Status CsvOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base(), modified); } + +// Get Dataset size +Status CsvOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + sample_size = num_samples_; + num_rows = num_rows_per_shard_; + *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h index 9092ede899..35d95daf3b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -318,6 +318,11 @@ class CsvOp : public ParallelOp { // @return - Status of the node visit. Status Accept(NodePass *p, bool *modified) override; + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc index 8ffabb7260..c08966d939 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -453,5 +453,20 @@ Status ImageFolderOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t sample_size, num_rows, num_classes; + num_rows = num_rows_; + if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, &num_classes)); + sample_size = sampler_->GetNumSamples(); + *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h index 53e7357e1a..90b813f60b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h @@ -213,6 +213,11 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "ImageFolderOp"; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index 0c32ed96b6..bfff2d4704 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -453,5 +453,23 @@ Status ManifestOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status ManifestOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op)); + RETURN_IF_NOT_OK(op->ParseManifestFile()); + num_rows = static_cast(op->image_labelname_.size()); + sample_size = sampler_->GetNumSamples(); + *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h index 2a022868d2..ce19c06465 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h @@ -183,6 +183,11 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "ManifestOp"; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index f9ddfb68de..9f8b63be1d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -38,6 +38,7 @@ namespace mindspore { namespace dataset { + using mindrecord::kInt64Len; using mindrecord::MSRStatus; using mindrecord::Schema; @@ -476,5 +477,23 @@ Status MindRecordOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows = num_rows_, sample_size; + if (num_rows_ <= 0) { + std::shared_ptr op; + RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_)); + } + sample_size = operators_[0]->GetNumSamples(num_rows, 0); 
+ *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h index dae29f5541..cba0d9f625 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h @@ -212,6 +212,11 @@ class MindRecordOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "MindRecordOp"; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: Status GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, int32_t worker_id); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc index c3f8c62e8a..7d5923c475 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -470,5 +470,20 @@ Status MnistOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status MnistOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + num_rows = num_rows_; + if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows)); + sample_size = sampler_->GetNumSamples(); + *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h index db9a9587b0..c845ad1217 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h @@ -168,6 +168,11 @@ class MnistOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "MnistOp"; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index 381e4ed9bb..2f5d3b0e39 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -15,6 +15,8 @@ */ #include "minddata/dataset/engine/datasetops/source/random_data_op.h" + +#include #include #include #include "minddata/dataset/engine/execution_tree.h" @@ -418,5 +420,19 @@ Status RandomDataOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size = 0; + num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); + if (sampler_ != nullptr) sample_size = sampler_->GetNumSamples(); + *dataset_size = sample_size != 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h index 14421efac9..e558aded55 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h @@ -203,6 +203,11 @@ class RandomDataOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "RandomDataOp"; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: /** * The entry point code for when workers are launched diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt index 1335d987e8..86f2ffffd2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -2,20 +2,20 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES - distributed_sampler.cc - pk_sampler.cc - random_sampler.cc - sampler.cc - sequential_sampler.cc - subset_random_sampler.cc - weighted_random_sampler.cc - ) + distributed_sampler.cc + pk_sampler.cc + random_sampler.cc + sampler.cc + sequential_sampler.cc + subset_random_sampler.cc + weighted_random_sampler.cc + ) if (ENABLE_PYTHON) set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES - 
${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES} - python_sampler.cc - ) -endif() + ${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES} + python_sampler.cc + ) +endif () add_library(engine-datasetops-source-sampler OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES}) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc index e0af135a03..ebf00413b4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -129,6 +129,8 @@ Status Sampler::SetNumSamples(int64_t num_samples) { return Status::OK(); } +int64_t Sampler::GetNumSamples() { return num_samples_; } + Status Sampler::SetNumRowsInDataset(int64_t num_rows) { CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0."); num_rows_ = num_rows; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h index 268eb1256d..1aa061558c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h @@ -98,6 +98,11 @@ class Sampler { // @return status error code Status SetNumSamples(int64_t num_samples); + // getter for num samples + // @param num_samples - the number of samples to return. 
+ // @return status error code + int64_t GetNumSamples(); + // setter for num or records in the dataset // @param num_rows - the number of records // @return status error code diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 98841cf2c8..52d7e7745a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -519,5 +519,20 @@ Status TextFileOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base(), modified); } + +// Get Dataset size +Status TextFileOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + sample_size = total_rows_; + if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + num_rows = total_rows_; + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h index 740a24aecd..131d3accb9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -198,6 +198,11 @@ class TextFileOp : public ParallelOp { // @return - Status of the node visit. 
Status Accept(NodePass *p, bool *modified) override; + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index 79d39a0fd0..385aa42ebe 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -1062,5 +1062,27 @@ Status TFReaderOp::PrepareNodePostAction() { return Status::OK(); } + +// Get Dataset size +Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size; + num_rows = num_rows_; + if (num_rows_ <= 0) { + if (equal_rows_per_shard_) { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + num_rows = num_rows_per_shard_; + } else { + RETURN_IF_NOT_OK(CountTotalRows(&num_rows, dataset_files_list_)); + } + } + sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); + *dataset_size = sample_size > 0 ? 
std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h index 8a295291b4..748979f50e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h @@ -257,6 +257,11 @@ class TFReaderOp : public ParallelOp { // before providing their own implementations. Status PrepareNodePostAction() override; + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index 8e09510dd3..2acd467862 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -513,5 +513,33 @@ Status VOCOp::ComputeColMap() { } return Status::OK(); } + +// Get Dataset size +Status VOCOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows = 0, sample_size; + if (image_ids_.size() == 0) { + if (task_type_ == TaskType::Detection) { + std::shared_ptr op; + RETURN_IF_NOT_OK( + Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op)); + RETURN_IF_NOT_OK(op->ParseImageIds()); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + num_rows = static_cast(op->image_ids_.size()); + } else if (task_type_ == 
TaskType::Segmentation) { + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetDir(folder_path_).SetTask("Segmentation").SetUsage(usage_).Build(&op)); + RETURN_IF_NOT_OK(op->ParseImageIds()); + num_rows = static_cast(op->image_ids_.size()); + } + } + sample_size = sampler_->GetNumSamples(); + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h index dca7a8e793..b648068a1e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h @@ -216,6 +216,11 @@ class VOCOp : public ParallelOp, public RandomAccessOp { // @return Name of the current Op std::string Name() const override { return "VOCOp"; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc index 615dcaea41..b61613534d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc @@ -16,6 +16,7 @@ #include #include +#include #include "utils/ms_utils.h" #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/engine/data_buffer.h" @@ -131,5 +132,18 @@ Status TakeOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base(), modified); } + +// Get Dataset size +Status 
TakeOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows; + RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); + *dataset_size = std::min(static_cast(max_takes_), num_rows); + dataset_size_ = *dataset_size; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h index d055207520..740639f645 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h @@ -88,6 +88,11 @@ class TakeOp : public PipelineOp { // @return Name of the current Op std::string Name() const override { return kTakeOp; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: int32_t max_takes_; // The number of takes that the user requested int32_t take_count_; // A counter for the current number of executed takes diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc index 0ab8406153..dd9edecb59 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc @@ -14,6 +14,7 @@ * limitations under the License. 
*/ #include "minddata/dataset/engine/datasetops/zip_op.h" +#include #include #include #include "minddata/dataset/core/constants.h" @@ -251,6 +252,24 @@ Status ZipOp::Accept(NodePass *p, bool *modified) { return p->RunOnNode(shared_from_base(), modified); } +// Get Dataset size +Status ZipOp::GetDatasetSize(int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + std::vector dataset_sizes; + int64_t child_dataset_size; + for (auto child : child_) { + RETURN_IF_NOT_OK(child->GetDatasetSize(&child_dataset_size)); + dataset_sizes.push_back(child_dataset_size); + } + + *dataset_size = *std::min_element(dataset_sizes.begin(), dataset_sizes.end()); + dataset_size_ = *dataset_size; + return Status::OK(); +} + Status ZipOp::ComputeColMap() { if (column_name_id_map_.empty()) { column_name_id_map_ = {}; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h index f2cc282399..2aace463ac 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h @@ -120,6 +120,11 @@ class ZipOp : public PipelineOp { // @return Name of the current Op std::string Name() const override { return kZipOp; } + /// \brief Base-class override for GetDatasetSize + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(int64_t *dataset_size) override; + private: // Handles preprocessing of the main loop, used when starting new epoch Status prepare(TensorQTable *const table); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index 182e9f999e..c95b72b44e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -47,6 +47,9 @@ class TreeAdapter { // 2. 
GetNext will return empty row when eoe/eof is obtained Status GetNext(TensorRow *); + // This function will return the root of the execution tree. + std::weak_ptr GetRoot() { return tree_ != nullptr ? tree_->root() : nullptr; } + // This function will return the column_name_map once BuildAndPrepare() is called std::unordered_map GetColumnNameMap() const { return column_name_map_; } diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 85b08cc1a1..5a5eb9305d 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -28,6 +28,7 @@ #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" #include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/consumers/tree_consumer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/include/iterator.h" #include "minddata/dataset/include/samplers.h" @@ -49,6 +50,7 @@ class DataSchema; class Tensor; class TensorShape; class TreeAdapter; +class TreeGetters; #ifndef ENABLE_ANDROID class Vocab; #endif @@ -570,6 +572,10 @@ class Dataset : public std::enable_shared_from_this { /// \return Status Status::OK() if all the parameters are valid virtual Status ValidateParams() = 0; + /// \brief Gets the dataset size + /// \return status code + int64_t GetDatasetSize(); + /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object @@ -750,6 +756,7 @@ class Dataset : public std::enable_shared_from_this { protected: std::vector> children; std::shared_ptr parent; + std::shared_ptr tree_getters_; int32_t num_workers_; int32_t rows_per_buffer_; diff --git a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc index cd50e8633c..b7a33cea5f 100644 --- 
a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc @@ -73,6 +73,17 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10GetDatasetSize."; + + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, "all"); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 10000); +} + TEST_F(MindDataTestPipeline, TestCifar100Dataset) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset."; @@ -108,6 +119,17 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestCifar100GetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100GetDatasetSize."; + + // Create a Cifar100 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; + std::shared_ptr ds = Cifar100(folder_path, "all", RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 10); +} + TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1."; diff --git a/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc b/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc index 35320dd0c6..ed0dd8fe0c 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc @@ -162,6 +162,19 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetBasic) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestCLUEGetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEGetDatasetSize."; + + // Create a CLUEFile Dataset, with single CLUE file + std::string clue_file = datasets_root_path_ + "/testCLUE/afqmc/train.json"; + std::string task = "AFQMC"; + std::string usage = "train"; + std::shared_ptr ds = CLUE({clue_file}, task, 
usage, 2); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 2); +} + TEST_F(MindDataTestPipeline, TestCLUEDatasetCMNLI) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetCMNLI."; diff --git a/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc b/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc index 71cea0af52..f63575eb92 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc @@ -91,6 +91,18 @@ TEST_F(MindDataTestPipeline, TestCocoDefault) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestCocoGetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoGetDatasetSize."; + // Create a Coco Dataset + std::string folder_path = datasets_root_path_ + "/testCOCO/train"; + std::string annotation_file = datasets_root_path_ + "/testCOCO/annotations/train.json"; + + std::shared_ptr ds = Coco(folder_path, annotation_file); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 6); +} + TEST_F(MindDataTestPipeline, TestCocoDetection) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCocoDetection."; // Create a Coco Dataset diff --git a/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc b/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc index 825ce7099d..25593c517c 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc @@ -101,6 +101,18 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetBasic) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestCSVGetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVGetDatasetSize."; + + // Create a CSVDataset, with single CSV file + std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; + std::vector column_names = {"col1", "col2", "col3", "col4"}; + std::shared_ptr ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 3); +} + TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) { 
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetMultiFiles."; diff --git a/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc b/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc index d94a233aca..f36a8848ea 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc @@ -67,6 +67,17 @@ TEST_F(MindDataTestPipeline, TestManifestBasic) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestManifestGetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestManifestGetDatasetSize."; + + std::string file_path = datasets_root_path_ + "/testManifestData/cpp.json"; + // Create a Manifest Dataset + std::shared_ptr ds = Manifest(file_path); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 2); +} + TEST_F(MindDataTestPipeline, TestManifestDecode) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestManifestDecode."; @@ -91,7 +102,7 @@ TEST_F(MindDataTestPipeline, TestManifestDecode) { auto shape = image->shape(); MS_LOG(INFO) << "Tensor image shape size: " << shape.Size(); MS_LOG(INFO) << "Tensor image shape: " << image->shape(); - EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect + EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect iter->GetNextRow(&row); } diff --git a/tests/ut/cpp/dataset/c_api_dataset_mindrecord.cc b/tests/ut/cpp/dataset/c_api_dataset_mindrecord.cc index d8c110e066..c119d77b9b 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_mindrecord.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_mindrecord.cc @@ -71,6 +71,19 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess1) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestMindDataGetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataGetDatasetSize with string file pattern."; + + // Create a MindData Dataset + // Pass one mindrecord shard file to parse dataset info, and search for other mindrecord files with same dataset info, + // thus all records in imagenet.mindrecord0 ~ 
imagenet.mindrecord3 will be read + std::string file_path = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0"; + std::shared_ptr ds = MindData(file_path); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 20); +} + TEST_F(MindDataTestPipeline, TestMindDataSuccess2) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess2 with a vector of single mindrecord file."; diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index 5bb855a247..a999e25b16 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -368,6 +368,34 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestConcatGetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatGetDatasetSize."; + + // Create an ImageFolder Dataset + // Column names: {"image", "label"} + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Cifar10 Dataset + // Column names: {"image", "label"} + folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds2 = Cifar10(folder_path, "all", RandomSampler(false, 9)); + EXPECT_NE(ds2, nullptr); + + // Create a Project operation on ds + ds = ds->Project({"image"}); + EXPECT_NE(ds, nullptr); + ds2 = ds2->Project({"image"}); + EXPECT_NE(ds, nullptr); + + // Create a Concat operation on the ds + ds = ds->Concat({ds2}); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 19); +} + TEST_F(MindDataTestPipeline, TestConcatSuccess2) { // Test "+" operator to concat two datasets MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConcatSuccess2."; @@ -461,6 +489,27 @@ TEST_F(MindDataTestPipeline, TestImageFolderBatchAndRepeat) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestPipelineGetDatasetSize) 
{ + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPipelineGetDatasetSize."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 2; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 10); +} + TEST_F(MindDataTestPipeline, TestProjectMap) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestProjectMap."; @@ -914,6 +963,22 @@ TEST_F(MindDataTestPipeline, TestSkipDataset) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Skip operation on ds + int32_t count = 3; + ds = ds->Skip(count); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 7); +} + TEST_F(MindDataTestPipeline, TestSkipDatasetError1) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1."; @@ -966,6 +1031,21 @@ TEST_F(MindDataTestPipeline, TestTakeDatasetDefault) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestTakeGetDatasetSize) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTakeGetDatasetSize."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 7)); + EXPECT_NE(ds, nullptr); + + // Create a Take operation on ds, dafault count = -1 + ds = ds->Take(2); + EXPECT_NE(ds, nullptr); + + EXPECT_EQ(ds->GetDatasetSize(), 2); +} + TEST_F(MindDataTestPipeline, 
TestTakeDatasetError1) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTakeDatasetError1.";
@@ -1190,6 +1270,44 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) {
   iter->Stop();
 }
 
+TEST_F(MindDataTestPipeline, TestZipGetDatasetSize) {
+  // Testing the member zip() function
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipGetDatasetSize.";
+
+  // Create an ImageFolder Dataset
+  std::string folder_path = datasets_root_path_ + "/testPK/data/";
+  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 2));
+  EXPECT_NE(ds, nullptr);
+
+  // Create a Project operation on ds
+  std::vector<std::string> column_project = {"image"};
+  ds = ds->Project(column_project);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an ImageFolder Dataset
+  std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, RandomSampler(false, 3));
+  EXPECT_NE(ds1, nullptr);
+
+  // Create a Rename operation on ds (so that the 3 datasets we are going to zip have distinct column names)
+  ds1 = ds1->Rename({"image", "label"}, {"col1", "col2"});
+  EXPECT_NE(ds1, nullptr);
+
+  folder_path = datasets_root_path_ + "/testCifar10Data/";
+  std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, "all", RandomSampler(false, 5));
+  EXPECT_NE(ds2, nullptr);
+
+  // Create a Project operation on ds
+  column_project = {"label"};
+  ds2 = ds2->Project(column_project);
+  EXPECT_NE(ds2, nullptr);
+
+  // Create a Zip operation on the datasets
+  ds = ds->Zip({ds1, ds2});
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 2);
+}
+
 TEST_F(MindDataTestPipeline, TestZipSuccess2) {
   // Testing the static zip() function
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipSuccess2.";
diff --git a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc
index 9ec6d744ec..232fdbe25d 100644
--- a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc
+++ b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc
@@ -87,6 +87,19 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic1) {
   iter->Stop();
 }
 
+TEST_F(MindDataTestPipeline, TestRandomDatasetGetDatasetSize) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetGetDatasetSize.";
+
+  // Create a RandomDataset
+  std::shared_ptr<SchemaObj> schema = Schema();
+  schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2});
+  schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1});
+  std::shared_ptr<Dataset> ds = RandomData(50, schema);
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 50);
+}
+
 TEST_F(MindDataTestPipeline, TestRandomDatasetBasic2) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic2.";
diff --git a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc
index 1ee83dc799..0b7b07a4b2 100644
--- a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc
+++ b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc
@@ -96,6 +96,32 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetBasic) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }
 
+TEST_F(MindDataTestPipeline, TestTextFileGetDatasetSize) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileGetDatasetSize.";
+  // Test TextFile Dataset with single text file and many default inputs
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(987);
+  GlobalContext::config_manager()->set_num_parallel_workers(4);
+
+  // Create a TextFile Dataset, with single text file
+  // Note: 1.txt has 3 rows
+  // Use 2 samples
+  // Use defaults for other input parameters
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 2);
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 2);
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
 TEST_F(MindDataTestPipeline, TestTextFileDatasetFail1) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail1.";
diff --git a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc
index 27c7785a49..339dc5d60a 100644
--- a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc
+++ b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc
@@ -98,6 +98,36 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) {
   iter->Stop();
 }
 
+TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasicGetDatasetSize) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetBasicGetDatasetSize.";
+
+  // Create a TFRecord Dataset
+  std::string file_path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
+  std::string schema_path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json";
+  std::shared_ptr<Dataset> ds = TFRecord({file_path}, schema_path, {"image"}, 0);
+  EXPECT_NE(ds, nullptr);
+
+  // Create a Repeat operation on ds
+  int32_t repeat_num = 2;
+  ds = ds->Repeat(repeat_num);
+  EXPECT_NE(ds, nullptr);
+
+  // Create objects for the tensor ops
+  std::shared_ptr<TensorOperation> random_horizontal_flip_op = vision::RandomHorizontalFlip(0.5);
+  EXPECT_NE(random_horizontal_flip_op, nullptr);
+
+  // Create a Map operation on ds
+  ds = ds->Map({random_horizontal_flip_op}, {}, {}, {"image"});
+  EXPECT_NE(ds, nullptr);
+
+  // Create a Batch operation on ds
+  int32_t batch_size = 1;
+  ds = ds->Batch(batch_size);
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 6);
+}
+
 TEST_F(MindDataTestPipeline, TestTFRecordDatasetShuffle) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTFRecordDatasetShuffle.";
   // This case is to verify if the list of datafiles are sorted in lexicographical order.
diff --git a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc
index f09363f306..48f0dd557b 100644
--- a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc
+++ b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc
@@ -86,6 +86,22 @@ TEST_F(MindDataTestPipeline, TestVOCClassIndex) {
   iter->Stop();
 }
 
+TEST_F(MindDataTestPipeline, TestVOCGetDatasetSize) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCGetDatasetSize.";
+
+  // Create a VOC Dataset
+  std::string folder_path = datasets_root_path_ + "/testVOC2012_2";
+  std::map<std::string, int32_t> class_index;
+  class_index["car"] = 0;
+  class_index["cat"] = 1;
+  class_index["train"] = 9;
+
+  std::shared_ptr<Dataset> ds = VOC(folder_path, "Detection", "train", class_index, false, SequentialSampler(0, 6));
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 6);
+}
+
 TEST_F(MindDataTestPipeline, TestVOCDetection) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVOCDetection.";
diff --git a/tests/ut/cpp/dataset/c_api_datasets_test.cc b/tests/ut/cpp/dataset/c_api_datasets_test.cc
index 5e6e588852..4d7af7c8f0 100644
--- a/tests/ut/cpp/dataset/c_api_datasets_test.cc
+++ b/tests/ut/cpp/dataset/c_api_datasets_test.cc
@@ -125,6 +125,17 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) {
   iter->Stop();
 }
 
+TEST_F(MindDataTestPipeline, TestCelebAGetDatasetSize) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebAGetDatasetSize.";
+
+  // Create a CelebA Dataset
+  std::string folder_path = datasets_root_path_ + "/testCelebAData/";
+  std::shared_ptr<Dataset> ds = CelebA(folder_path, "valid");
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 1);
+}
+
 TEST_F(MindDataTestPipeline, TestCelebAException) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCelebAException.";
@@ -179,6 +190,17 @@ TEST_F(MindDataTestPipeline, TestImageFolderFailWithWrongExtension) {
   iter->Stop();
 }
 
+TEST_F(MindDataTestPipeline, TestImageFolderGetDatasetSize) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderGetDatasetSize.";
+
+  // Create an ImageFolder Dataset
+  std::string folder_path = datasets_root_path_ + "/testPK/data/";
+  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true);
+  EXPECT_NE(ds, nullptr);
+
+  EXPECT_EQ(ds->GetDatasetSize(), 44);
+}
+
 TEST_F(MindDataTestPipeline, TestImageFolderFailWithNullSampler) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderFailWithNullSampler.";
@@ -199,6 +221,16 @@ TEST_F(MindDataTestPipeline, TestImageFolderFailWithWrongSampler) {
   EXPECT_EQ(ds, nullptr);
 }
 
+TEST_F(MindDataTestPipeline, TestMnistGetDatasetSize) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistGetDatasetSize.";
+
+  // Create a Mnist Dataset
+  std::string folder_path = datasets_root_path_ + "/testMnistData/";
+  std::shared_ptr<Dataset> ds = Mnist(folder_path, "all", RandomSampler(false, 20));
+  EXPECT_NE(ds, nullptr);
+  EXPECT_EQ(ds->GetDatasetSize(), 20);
+}
+
 TEST_F(MindDataTestPipeline, TestMnistFailWithWrongDatasetDir) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFailWithWrongDatasetDir.";