From 9e868d7633bf9d1e673e69d52480d9b4e3efe047 Mon Sep 17 00:00:00 2001 From: hesham Date: Mon, 16 Nov 2020 15:55:11 -0500 Subject: [PATCH] Changes from python-api branch: Fix bug in batch getDatasetSize Add getDatasetSize support for generator_op.cc Fix header guards of some files --- mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h | 6 +++--- .../ccsrc/minddata/dataset/api/python/pybind_register.h | 6 +++--- .../ccsrc/minddata/dataset/engine/datasetops/batch_op.cc | 4 ++-- .../ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc | 3 ++- .../dataset/engine/datasetops/source/generator_op.cc | 6 ++++++ .../dataset/engine/datasetops/source/generator_op.h | 2 ++ 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h index 7002237e58..afe93b1a99 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ #include #include @@ -260,4 +260,4 @@ class DEPipeline { }; } // namespace dataset } // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h b/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h index 8717a2844c..35f79fd166 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h +++ b/mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef API_PYBIND_API_H_ -#define API_PYBIND_API_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_ #include #include @@ -78,4 +78,4 @@ class PybindDefineRegisterer { #endif } // namespace dataset } // namespace mindspore -#endif // API_PYBIND_API_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index 7b10f9bfc0..9689f75ee7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -546,9 +546,9 @@ Status BatchOp::GetDatasetSize(int64_t *dataset_size) { RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); if (num_rows > 0 && start_batch_size_ > 0) { if (drop_) { - num_rows = floor(num_rows / start_batch_size_); + num_rows = floor(num_rows / (1.0 * start_batch_size_)); } else { - num_rows = ceil(num_rows / start_batch_size_); + num_rows = ceil(num_rows / (1.0 * start_batch_size_)); } } *dataset_size = num_rows; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 905f9ead34..a0d21dae20 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -312,7 +312,8 @@ Status DatasetOp::GetNumClasses(int64_t *num_classes) { if (!child_.empty()) { return child_[0]->GetNumClasses(num_classes); } else { - RETURN_STATUS_UNEXPECTED("Can't get the dataset size for the current tree."); + *num_classes = -1; + RETURN_STATUS_UNEXPECTED("Can't get the number of classes for the current tree."); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 4bf2205744..18734c333c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -274,5 +274,11 @@ Status GeneratorOp::ComputeColMap() { } return Status::OK(); } +Status GeneratorOp::GetDatasetSize(int64_t *dataset_size) { // Get Dataset size + // We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to + // iterate over the dataset and count the size + *dataset_size = dataset_size_; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h index ff451b0929..35d313733a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h @@ -136,6 +136,8 @@ class GeneratorOp : public PipelineOp { Status Init(); + Status GetDatasetSize(int64_t *dataset_size) override; + private: py::function generator_function_; std::vector column_names_;