Browse Source

Changes from python-api branch:

Fix bug in batch getDatasetSize
Add getDatasetSize support for generator_op.cc
Fix header guards of some files
tags/v1.1.0
hesham 5 years ago
parent
commit
9e868d7633
6 changed files with 18 additions and 9 deletions
  1. +3
    -3
      mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h
  2. +3
    -3
      mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h
  3. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
  4. +2
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
  5. +6
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
  6. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h

+ 3
- 3
mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h View File

@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_

#include <iostream>
#include <map>
@@ -260,4 +260,4 @@ class DEPipeline {
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_

+ 3
- 3
mindspore/ccsrc/minddata/dataset/api/python/pybind_register.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef API_PYBIND_API_H_
#define API_PYBIND_API_H_
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_

#include <map>
#include <string>
@@ -78,4 +78,4 @@ class PybindDefineRegisterer {
#endif
} // namespace dataset
} // namespace mindspore
#endif // API_PYBIND_API_H_
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_

+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc View File

@@ -546,9 +546,9 @@ Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
if (num_rows > 0 && start_batch_size_ > 0) {
if (drop_) {
num_rows = floor(num_rows / start_batch_size_);
num_rows = floor(num_rows / (1.0 * start_batch_size_));
} else {
num_rows = ceil(num_rows / start_batch_size_);
num_rows = ceil(num_rows / (1.0 * start_batch_size_));
}
}
*dataset_size = num_rows;


+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc View File

@@ -312,7 +312,8 @@ Status DatasetOp::GetNumClasses(int64_t *num_classes) {
if (!child_.empty()) {
return child_[0]->GetNumClasses(num_classes);
} else {
RETURN_STATUS_UNEXPECTED("Can't get the dataset size for the current tree.");
*num_classes = -1;
RETURN_STATUS_UNEXPECTED("Can't get the number of classes for the current tree.");
}
}



+ 6
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc View File

@@ -274,5 +274,11 @@ Status GeneratorOp::ComputeColMap() {
}
return Status::OK();
}
Status GeneratorOp::GetDatasetSize(int64_t *dataset_size) { // Get Dataset size
// Reports the size stored in the member dataset_size_ through the out-parameter; no computation is done here.
// NOTE(review): an earlier comment said this returns -1 so that TreeGetters iterates over the dataset and counts
// the rows. The code does not return a literal -1; presumably dataset_size_ holds -1 when the size is unknown --
// confirm where dataset_size_ is initialized.
// NOTE(review): dataset_size is dereferenced without a null check -- confirm callers always pass a valid pointer.
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h View File

@@ -136,6 +136,8 @@ class GeneratorOp : public PipelineOp {

Status Init();

Status GetDatasetSize(int64_t *dataset_size) override;

private:
py::function generator_function_;
std::vector<std::string> column_names_;


Loading…
Cancel
Save