Browse Source

!9249 Optimizing GetDatasetSize

From: @mahdirahmanihanzaki
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
df44e1339e
100 changed files with 722 additions and 633 deletions
  1. +4
    -3
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +1
    -20
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc
  3. +5
    -1
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc
  4. +13
    -6
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/iterator_bindings.cc
  5. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.cc
  6. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h
  7. +41
    -23
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  8. +30
    -2
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h
  9. +0
    -24
      mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
  10. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h
  11. +0
    -7
      mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc
  12. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h
  13. +0
    -7
      mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
  14. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
  15. +0
    -18
      mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
  16. +0
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
  17. +0
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc
  18. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h
  19. +0
    -15
      mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
  20. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
  21. +0
    -15
      mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc
  22. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h
  23. +0
    -58
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
  24. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h
  25. +0
    -15
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc
  26. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h
  27. +0
    -14
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc
  28. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h
  29. +0
    -33
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc
  30. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h
  31. +0
    -14
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
  32. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h
  33. +0
    -6
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
  34. +0
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h
  35. +0
    -18
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
  36. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h
  37. +3
    -26
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
  38. +10
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h
  39. +0
    -17
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
  40. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h
  41. +0
    -14
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc
  42. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h
  43. +0
    -18
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc
  44. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h
  45. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  46. +5
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
  47. +0
    -14
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc
  48. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h
  49. +0
    -36
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc
  50. +0
    -10
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h
  51. +2
    -38
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc
  52. +3
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h
  53. +0
    -12
      mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc
  54. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h
  55. +0
    -18
      mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc
  56. +0
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h
  57. +28
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc
  58. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h
  59. +0
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc
  60. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h
  61. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
  62. +25
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  63. +13
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  64. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h
  65. +17
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc
  66. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h
  67. +17
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc
  68. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h
  69. +62
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc
  70. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h
  71. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc
  72. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h
  73. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc
  74. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h
  75. +16
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc
  76. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h
  77. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc
  78. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h
  79. +16
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc
  80. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h
  81. +0
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
  82. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
  83. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  84. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h
  85. +16
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc
  86. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h
  87. +23
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc
  88. +10
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h
  89. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc
  90. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h
  91. +23
    -5
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc
  92. +10
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h
  93. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc
  94. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h
  95. +36
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
  96. +14
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h
  97. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc
  98. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h
  99. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc
  100. +9
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h

+ 4
- 3
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -199,12 +199,13 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
// Constructor
Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }

int64_t Dataset::GetDatasetSize() {
int64_t Dataset::GetDatasetSize(bool estimate) {
int64_t dataset_size;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
RETURN_SECOND_IF_ERROR(size_getter->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(size_getter->GetDatasetSize(&dataset_size, estimate), -1);
return dataset_size;
}



+ 1
- 20
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc View File

@@ -106,19 +106,7 @@ PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) {
}));

PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) {
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes",
[](const std::string &file, const py::dict &dict, const std::string &usage) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
return py::make_tuple(count, num_classes);
})
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict,
const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
return output_class_indexing;
});
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp");
}));
PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
@@ -173,13 +161,6 @@ PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) {

PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) {
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples) {
int64_t count = 0;
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
return count;
})
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict) {
std::map<std::string, int32_t> output_class_indexing;


+ 5
- 1
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc View File

@@ -184,7 +184,11 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
THROW_IF_ERROR(gen->ValidateParams());
return gen;
}));
}))
.def("SetGeneratorDatasetSize", [](std::shared_ptr<GeneratorNode> self, int64_t sz) {
self->SetGeneratorDatasetSize(sz);
return self;
});
}));

PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {


+ 13
- 6
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/iterator_bindings.cc View File

@@ -93,12 +93,6 @@ PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) {
THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing));
return output_class_indexing;
})
.def("GetDatasetSize",
[](PythonTreeGetters &self) {
int64_t dataset_size;
THROW_IF_ERROR(self.GetDatasetSize(&dataset_size));
return dataset_size;
})
.def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; });
}));

@@ -164,5 +158,18 @@ PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) {
.def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); });
}));

PYBIND_REGISTER(PythonDatasetSizeGetter, 1, ([](const py::module *m) {
(void)py::class_<PythonDatasetSizeGetter, TreeConsumer, std::shared_ptr<PythonDatasetSizeGetter>>(
*m, "DatasetSizeGetters")
.def(py::init<>())
.def("Init", [](PythonDatasetSizeGetter &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetDatasetSize", [](PythonDatasetSizeGetter &self, bool estimate) {
int64_t size;
THROW_IF_ERROR(self.GetDatasetSize(&size, estimate));
return size;
});
}));

} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.cc View File

@@ -65,4 +65,8 @@ Status PythonTreeGetters::GetRow(TensorRow *r) {
py::gil_scoped_release gil_release;
return TreeGetters::GetRow(r);
}
Status PythonDatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) {
py::gil_scoped_release gil_release;
return DatasetSizeGetter::GetRow(tree_adapter, r);
}
} // namespace mindspore::dataset

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h View File

@@ -60,5 +60,9 @@ class PythonTreeGetters : public TreeGetters {
public:
Status GetRow(TensorRow *r) override;
};
class PythonDatasetSizeGetter : public DatasetSizeGetter {
public:
Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) override;
};
} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_

+ 41
- 23
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc View File

@@ -451,29 +451,6 @@ Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {

Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); }

Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ == -1) {
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize)));
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
if (*dataset_size == -1) { // run through the tree and get everything
TensorRow row;
RETURN_IF_NOT_OK(GetRow(&row));
int64_t row_cnt = 0;
while (!row.empty()) {
++row_cnt;
RETURN_IF_NOT_OK(GetRow(&row));
}
*dataset_size = row_cnt;
}
dataset_size_ = *dataset_size; // save the previous result
}

*dataset_size = dataset_size_;
return Status::OK();
}

Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
*types = first_row_type_;
@@ -573,5 +550,46 @@ Status BuildVocabConsumer::Start() {
CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE.");
return Status::OK();
}
Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) {
if (dataset_size_ == -1) {
RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size));
dataset_size_ = *size; // save the previous result
}

*size = dataset_size_;
return Status::OK();
}
Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
root_ = std::move(d);
return Status::OK();
}
Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>();
tree_adapters_.push_back(tree_adapter);
tree_adapter->SetPrePassOverride([](OptPass pre) {
pre.push_back(
std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(GetterPass::GetterType::kDatasetSize)));
return pre;
});
RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1));
TensorRow row;
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
int64_t row_cnt = 0;
while (!row.empty()) {
++row_cnt;
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
}
*dataset_size = row_cnt;
return Status::OK();
}
Status DatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row) {
return tree_adapter->GetNext(row);
}
Status DatasetSizeGetter::Terminate() {
for (const auto &tree : tree_adapters_) {
RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop());
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 30
- 2
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h View File

@@ -177,7 +177,6 @@ class TreeGetters : public TreeConsumer {
~TreeGetters() = default;
Status Init(std::shared_ptr<DatasetNode> d) override;

Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
Status GetBatchSize(int64_t *batch_size);
@@ -186,7 +185,7 @@ class TreeGetters : public TreeConsumer {
Status GetColumnNames(std::vector<std::string> *output);
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
std::string Name() override { return "TreeGetters"; }
virtual Status GetRow(TensorRow *r);
virtual Status GetRow(TensorRow *row);

private:
Status GetFirstRowShapeAndType();
@@ -202,6 +201,35 @@ class TreeGetters : public TreeConsumer {
Status InternalInit();
};

/// Consumer that is used to get some pipeline information
class DatasetSizeGetter : public TreeConsumer, public std::enable_shared_from_this<DatasetSizeGetter> {
public:
DatasetSizeGetter() : dataset_size_(-1) {}
~DatasetSizeGetter() = default;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status Terminate() override;

/// \brief Function to get the dataset size
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *size, bool estimate = false);

virtual Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row);
std::string Name() override { return "DatasetSizeGetter"; }

/// \brief Gets the dataset size by iterating over the entire dataset on a sub tree starting from ir_node
/// param[in] ir_node The node that marks the top most of the sub tree on which we want to iterate
/// \return Status - The status code return
Status DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size);

private:
std::shared_ptr<DatasetNode> root_;
std::vector<std::shared_ptr<TreeAdapter>> tree_adapters_;
int64_t dataset_size_;
};

class BuildVocabConsumer : public TreeConsumer {
public:
/// BuildVocabConsumer Constructor which will call the base class default constructor.


+ 0
- 24
mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc View File

@@ -537,30 +537,6 @@ Status BatchOp::ComputeColMap() {
return Status::OK();
}

Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
#ifdef ENABLE_PYTHON
if (batch_size_func_) {
*dataset_size = -1;
return Status::OK();
}
#endif
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
if (num_rows > 0 && start_batch_size_ > 0) {
if (drop_) {
num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * start_batch_size_)));
} else {
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * start_batch_size_)));
}
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
int64_t BatchOp::GetTreeBatchSize() {
#ifdef ENABLE_PYTHON
if (batch_size_func_) {


+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h View File

@@ -225,11 +225,6 @@ class BatchOp : public ParallelOp {
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

int64_t GetTreeBatchSize() override;

protected:


+ 0
- 7
mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc View File

@@ -232,12 +232,5 @@ Status BucketBatchByLengthOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status BucketBatchByLengthOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h View File

@@ -112,11 +112,6 @@ class BucketBatchByLengthOp : public PipelineOp {

std::string Name() const override { return kBucketBatchByLengthOp; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded


+ 0
- 7
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc View File

@@ -196,12 +196,5 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
}

// Get Dataset size
Status ConcatOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h View File

@@ -111,11 +111,6 @@ class ConcatOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);



+ 0
- 18
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc View File

@@ -294,24 +294,6 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
return Status::OK();
}

// Gets the dataset size
Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
if (child_.size() == 1) {
return child_[0]->GetDatasetSize(dataset_size);
} else if (child_.size() > 1) {
// It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
// This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
// always be in front of the child_ structure, so we get the dataset size from the last child.
return child_[child_.size() - 1]->GetDatasetSize(dataset_size);
} else {
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}
}

// Gets the number of classes
Status DatasetOp::GetNumClasses(int64_t *num_classes) {
if (child_.size() == 1) {


+ 0
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h View File

@@ -180,10 +180,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The error code return
Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);

/// \brief Gets the dataset size
/// \return Status - The status code return
virtual Status GetDatasetSize(int64_t *dataset_size);

/// \brief Gets the batch size
/// \return Status - The status code return
virtual int64_t GetTreeBatchSize();


+ 0
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc View File

@@ -258,13 +258,5 @@ Status FilterOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<FilterOp>(), modified);
}

// Get Dataset size
Status FilterOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h View File

@@ -137,11 +137,6 @@ class FilterOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return kFilterOp; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
// predicate_func python callable which returns a boolean value.
std::shared_ptr<TensorOp> predicate_func_;


+ 0
- 15
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc View File

@@ -187,21 +187,6 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
}

// Get Dataset size
Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
if (num_rows > 0 && num_repeats_ > 0) {
num_rows = num_rows * num_repeats_;
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; }
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h View File

@@ -133,11 +133,6 @@ class RepeatOp : public PipelineOp {
/// \@return Status - The error code return
Status Reset() override;

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

int64_t GetTreeRepeatCount() override;

// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes


+ 0
- 15
mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc View File

@@ -136,20 +136,5 @@ Status SkipOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<SkipOp>(), modified);
}

// Get Dataset size
Status SkipOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
*dataset_size = 0;
if (max_skips_ >= 0 && max_skips_ < num_rows) {
*dataset_size = num_rows - max_skips_;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h View File

@@ -86,11 +86,6 @@ class SkipOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

// Op name getter
// @return Name of the current Op
std::string Name() const override { return kSkipOp; }


+ 0
- 58
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc View File

@@ -452,63 +452,5 @@ Status CelebAOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status CelebAOp::GetDatasetSize(int64_t *dataset_size) {
int64_t num_rows, sample_size;
std::string line;
Path folder_path(folder_path_);
std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
if (!attr_file.is_open()) {
std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
}

std::string rows_num;
(void)getline(attr_file, rows_num);
try {
num_rows = static_cast<int64_t>(std::stoul(rows_num)); // First line is rows number in attr file
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
}
if (usage_ != "all") {
int64_t partition_num = 0;
char usage_type;
if (usage_ == "train") {
usage_type = '0';
} else {
if (usage_ == "valid") {
usage_type = '1';
} else {
if (usage_ == "test")
usage_type = '2';
else
RETURN_STATUS_UNEXPECTED("Invalid usage.");
}
}
if (!partition_file_.is_open()) {
partition_file_.open((folder_path / "list_eval_partition.txt").toString());
}
if (partition_file_.is_open()) {
while (getline(partition_file_, line)) {
int start = line.find(' ');
if (line.at(start + 1) == usage_type) {
partition_num++;
}
}
} else {
std::string partition_file_name = "list_eval_partition.txt";
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba partition file: " + partition_file_name);
}
num_rows = std::min(num_rows, partition_num);
}

sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h View File

@@ -179,11 +179,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CelebAOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
// Called first when function is called
// @return


+ 0
- 15
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc View File

@@ -508,20 +508,5 @@ Status CifarOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status CifarOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0)
RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows));
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h View File

@@ -175,11 +175,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CifarOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return


+ 0
- 14
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc View File

@@ -565,19 +565,5 @@ Status ClueOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<ClueOp>(), modified);
}

// Get Dataset size
Status ClueOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
sample_size = num_samples_;
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h View File

@@ -197,11 +197,6 @@ class ClueOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.


+ 0
- 33
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc View File

@@ -681,39 +681,6 @@ Status CocoOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
std::string task_type;
switch (task_type_) {
case TaskType::Detection:
task_type = "Detection";
break;
case TaskType::Keypoint:
task_type = "Keypoint";
break;
case TaskType::Panoptic:
task_type = "Panoptic";
break;
case TaskType::Stuff:
task_type = "Stuff";
break;
}
if (image_ids_.size() == 0) {
RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
} else {
num_rows = image_ids_.size();
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

Status CocoOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
if ((*output_class_indexing).empty()) {
if ((task_type_ != TaskType::Detection) && (task_type_ != TaskType::Panoptic)) {


+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h View File

@@ -213,11 +213,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CocoOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

/// \brief Gets the class indexing
/// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;


+ 0
- 14
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc View File

@@ -916,19 +916,5 @@ Status CsvOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<CsvOp>(), modified);
}

// Get Dataset size
Status CsvOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
sample_size = num_samples_;
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h View File

@@ -318,11 +318,6 @@ class CsvOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.


+ 0
- 6
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc View File

@@ -274,11 +274,5 @@ Status GeneratorOp::ComputeColMap() {
}
return Status::OK();
}
Status GeneratorOp::GetDatasetSize(int64_t *dataset_size) { // Get Dataset size
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h View File

@@ -136,8 +136,6 @@ class GeneratorOp : public PipelineOp {

Status Init();

Status GetDatasetSize(int64_t *dataset_size) override;

private:
py::function generator_function_;
std::vector<std::string> column_names_;


+ 0
- 18
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc View File

@@ -465,24 +465,6 @@ Status ImageFolderOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t sample_size, num_rows;
num_rows = num_rows_;
if (num_rows_ <= 0) {
// GetDatasetSize will not be impacted by class_index_
RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, nullptr, {}));
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

// Get number of classes
Status ImageFolderOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {


+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h View File

@@ -217,11 +217,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "ImageFolderOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function


+ 3
- 26
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc View File

@@ -396,16 +396,9 @@ Status ManifestOp::CountDatasetInfo() {
return Status::OK();
}

#ifdef ENABLE_PYTHON
Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
int64_t *count, int64_t *numClasses) {
Status ManifestOp::CountTotalRows(const std::string &file, const std::map<std::string, int32_t> &map,
const std::string &usage, int64_t *count, int64_t *numClasses) {
// the logic of counting the number of samples is copied from ParseManifestFile()
std::map<std::string, int32_t> map;
for (auto p : dict) {
(void)map.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
py::reinterpret_borrow<py::int_>(p.second)));
}

std::shared_ptr<ManifestOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
@@ -415,6 +408,7 @@ Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict,
return Status::OK();
}

#ifdef ENABLE_PYTHON
Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
@@ -459,23 +453,6 @@ Status ManifestOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status ManifestOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
num_rows = static_cast<int64_t>(op->image_labelname_.size());
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

// Get number of classes
Status ManifestOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {


+ 10
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h View File

@@ -164,10 +164,17 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;

#ifdef ENABLE_PYTHON
static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count,
int64_t *numClasses);
/// \brief Counts the total number of rows in Manifest
/// \param[in] file Dataset file path
/// \param[in] input_class_indexing Input map of class index
/// \param[in] usage Dataset usage
/// \param[out] count Number of rows counted
/// \param[out] numClasses Number of classes counted
/// \return Status of the function
static Status CountTotalRows(const std::string &file, const std::map<std::string, int32_t> &map,
const std::string &usage, int64_t *count, int64_t *numClasses);

#ifdef ENABLE_PYTHON
// Get str-to-int mapping from label name to index
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing);
@@ -183,11 +190,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "ManifestOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function


+ 0
- 17
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc View File

@@ -474,22 +474,5 @@ Status MindRecordOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = num_rows_;
if (num_rows_ <= 0) {
// The last operator is parent sampler
std::shared_ptr<ShardOperator> op = operators_.back();
RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
}
*dataset_size = num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h View File

@@ -212,11 +212,6 @@ class MindRecordOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "MindRecordOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);



+ 0
- 14
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc View File

@@ -471,19 +471,5 @@ Status MnistOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status MnistOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows));
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h View File

@@ -168,11 +168,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "MnistOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return


+ 0
- 18
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc View File

@@ -421,23 +421,5 @@ Status RandomDataOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
if (sampler_ != nullptr) {
int64_t sample_size;
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
} else {
*dataset_size = num_rows;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h View File

@@ -203,11 +203,6 @@ class RandomDataOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "RandomDataOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
/**
* The entry point code for when workers are launched


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc View File

@@ -162,11 +162,11 @@ Status DistributedSamplerRT::ResetSampler() {
}

int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t childs = num_rows;
int64_t child_num_rows = num_rows;
if (!child_.empty()) {
childs = child_[0]->CalculateNumSamples(num_rows);
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
}
int64_t num_samples = (num_samples_ > 0) ? std::min(childs, num_samples_) : childs;
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
return std::ceil(num_samples * 1.0 / num_devices_);
}



+ 5
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h View File

@@ -63,6 +63,11 @@ class DistributedSamplerRT : public SamplerRT {

int64_t GetDeviceNum() { return num_devices_; }

/// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers
/// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows,
/// then num_samples_ is not returned at all.
/// \param[in] num_rows The total number of rows in the dataset
/// \return int64_t Calculated number of samples
int64_t CalculateNumSamples(int64_t num_rows) override;

void Print(std::ostream &out, bool show_all) const override;


+ 0
- 14
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc View File

@@ -520,19 +520,5 @@ Status TextFileOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
}

// Get Dataset size
Status TextFileOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
sample_size = total_rows_;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h View File

@@ -198,11 +198,6 @@ class TextFileOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.


+ 0
- 36
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc View File

@@ -1067,41 +1067,5 @@ Status TFReaderOp::PrepareNodePostAction() {
return Status::OK();
}

// Get the file list of the specific shard ID
Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) {
if (!shard_filenames->empty()) {
RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n");
}
for (int index = 0; index < dataset_files_list_.size(); index++) {
if (index % num_devices_ == device_id_) {
shard_filenames->push_back(dataset_files_list_.at(index));
}
}
return Status::OK();
}

// Get Dataset size
Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0) {
if (equal_rows_per_shard_) {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = num_rows_per_shard_;
} else {
std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list));
}
}
sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 10
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h View File

@@ -257,11 +257,6 @@ class TFReaderOp : public ParallelOp {
// before providing their own implementations.
Status PrepareNodePostAction() override;

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

static bool ValidateFirstRowCrc(const std::string &filename);

private:
@@ -400,11 +395,6 @@ class TFReaderOp : public ParallelOp {
// @return - Status
Status ComputeColMap() override;

// Private function for computing the file list of the specific shard ID. This is because in distributed scenario,
// data will be divided into shards by row when equal_rows_per_shard is true, but by file in the opposite case.
// @return - Status - the status code returned.
Status GetShardFileList(std::vector<std::string> *shard_filenames);

int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;


+ 2
- 38
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc View File

@@ -447,16 +447,9 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) {
return Status::OK();
}

#ifdef ENABLE_PYTHON
Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t *count) {
const std::map<std::string, int32_t> &input_class_indexing, int64_t *count) {
if (task_type == "Detection") {
std::map<std::string, int32_t> input_class_indexing;
for (auto p : dict) {
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
py::reinterpret_borrow<py::int_>(p.second)));
}

std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op));
@@ -473,6 +466,7 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
return Status::OK();
}

#ifdef ENABLE_PYTHON
Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
@@ -516,36 +510,6 @@ Status VOCOp::ComputeColMap() {
return Status::OK();
}

// Get Dataset size
Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
if (image_ids_.size() == 0) {
if (task_type_ == TaskType::Detection) {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
num_rows = static_cast<int64_t>(op->image_ids_.size());
} else if (task_type_ == TaskType::Segmentation) {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(Builder().SetDir(folder_path_).SetTask("Segmentation").SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
num_rows = static_cast<int64_t>(op->image_ids_.size());
}
} else {
num_rows = image_ids_.size();
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

Status VOCOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
if ((*output_class_indexing).empty()) {
if (task_type_ != TaskType::Detection) {


+ 3
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h View File

@@ -187,15 +187,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;

#ifdef ENABLE_PYTHON
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
// @param const py::dict &dict - input dict of class index
// @param const std::map<std::string, int32_t> input_class_indexing - input map of class index
// @param int64_t *count - output rows number of VOCDataset
static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t *count);
const std::map<std::string, int32_t> &input_class_indexing, int64_t *count);

#ifdef ENABLE_PYTHON
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
@@ -216,11 +216,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "VOCOp"; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

// /// \brief Gets the class indexing
// /// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;


+ 0
- 12
mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc View File

@@ -139,17 +139,5 @@ Status TakeOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<TakeOp>(), modified);
}

// Get Dataset size
Status TakeOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
*dataset_size = std::min(static_cast<int64_t>(max_takes_), num_rows);
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h View File

@@ -94,11 +94,6 @@ class TakeOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kTakeOp; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
int32_t max_takes_; // The number of takes that the user requested
int32_t take_count_; // A counter for the current number of executed takes


+ 0
- 18
mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc View File

@@ -248,24 +248,6 @@ Status ZipOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<ZipOp>(), modified);
}

// Get Dataset size
Status ZipOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
std::vector<int32_t> dataset_sizes;
int64_t child_dataset_size;
for (auto child : child_) {
RETURN_IF_NOT_OK(child->GetDatasetSize(&child_dataset_size));
dataset_sizes.push_back(child_dataset_size);
}

*dataset_size = *std::min_element(dataset_sizes.begin(), dataset_sizes.end());
dataset_size_ = *dataset_size;
return Status::OK();
}

Status ZipOp::ComputeColMap() {
if (column_name_id_map_.empty()) {
column_name_id_map_ = {};


+ 0
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h View File

@@ -120,11 +120,6 @@ class ZipOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kZipOp; }

/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;

private:
// Handles preprocessing of the main loop, used when starting new epoch
Status prepare(TensorQTable *const table);


+ 28
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc View File

@@ -114,5 +114,33 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
return node_ops;
}

// Get Dataset size
Status BatchNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
#ifdef ENABLE_PYTHON
if (batch_size_func_) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
dataset_size_ = *dataset_size;
return Status::OK();
}
#endif
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
if (num_rows > 0 && batch_size_ > 0) {
if (drop_remainder_) {
num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * batch_size_)));
} else {
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * batch_size_)));
}
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h View File

@@ -64,6 +64,15 @@ class BatchNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
int32_t batch_size_;
bool drop_remainder_;


+ 0
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc View File

@@ -127,6 +127,5 @@ Status BucketBatchByLengthNode::ValidateParams() {

return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h View File

@@ -60,6 +60,8 @@ class BucketBatchByLengthNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

bool IsSizeDefined() override { return false; };

private:
std::vector<std::string> column_names_;
std::vector<int32_t> bucket_boundaries_;


+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h View File

@@ -58,6 +58,8 @@ class ConcatNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

bool IsSizeDefined() override { return false; }

private:
std::shared_ptr<SamplerObj> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;


+ 25
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -342,6 +342,31 @@ Status DatasetNode::GetShardId(int32_t *shard_id) {
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
}
}

// Gets the dataset size
Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
if (!IsSizeDefined()) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
dataset_size_ = *dataset_size;
return Status::OK();
}
if (children_.size() == 1) {
return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size);
} else if (children_.size() > 1) {
// It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
// This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
// always be in front of the child_ structure, so we get the dataset size from the last child.
return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size);
} else {
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}
}

// Visitor accepting method for NodePass
Status SourceNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor


+ 13
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -25,6 +25,7 @@
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"

namespace mindspore {
namespace dataset {
@@ -32,6 +33,7 @@ namespace dataset {
class Dataset;
class SamplerObj;
class NodePass;
class DatasetSizeGetter;

#define RETURN_EMPTY_IF_ERROR(_s) \
do { \
@@ -169,6 +171,14 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status Status::OK() if get shard id successfully
virtual Status GetShardId(int32_t *shard_id);

/// \brief Gets the dataset size
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \return Status - The status code return
virtual Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size);

/// \brief Getter function for child nodes
/// \return Child nodes
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }
@@ -219,10 +229,13 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \notes Remove me after changing return val of Build()
Status BuildStatus() { return build_status; }

virtual bool IsSizeDefined() { return true; }

protected:
std::vector<std::shared_ptr<DatasetNode>> children_;
DatasetNode *parent_;
std::shared_ptr<DatasetCache> cache_;
int64_t dataset_size_ = -1;
int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t connector_que_size_;


+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h View File

@@ -55,6 +55,8 @@ class FilterNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

bool IsSizeDefined() override { return false; };

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified


+ 17
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc View File

@@ -56,6 +56,23 @@ Status RepeatNode::ValidateParams() {
return Status::OK();
}

// Get Dataset size
Status RepeatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
if (num_rows > 0 && repeat_count_ > 0) {
num_rows = num_rows * repeat_count_;
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}

// Visitor accepting method for NodePass
Status RepeatNode::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor


+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h View File

@@ -56,6 +56,15 @@ class RepeatNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

/// \brief Base-class override for accepting NodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified


+ 17
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc View File

@@ -56,5 +56,22 @@ Status SkipNode::ValidateParams() {
return Status::OK();
}

// Get Dataset size
Status SkipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
*dataset_size = 0;
if (skip_count_ >= 0 && skip_count_ < num_rows) {
*dataset_size = num_rows - skip_count_;
}
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h View File

@@ -54,6 +54,15 @@ class SkipNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
int32_t skip_count_;
};


+ 62
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc View File

@@ -21,6 +21,7 @@
#include <string>
#include <utility>
#include <vector>
#include <algorithm>

#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/util/status.h"
@@ -87,5 +88,66 @@ Status CelebANode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
int64_t num_rows, sample_size;
std::ifstream partition_file;
std::string line;
Path folder_path(dataset_dir_);
std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
if (!attr_file.is_open()) {
std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
}

std::string rows_num;
(void)getline(attr_file, rows_num);
try {
num_rows = static_cast<int64_t>(std::stoul(rows_num)); // First line is rows number in attr file
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
}
if (usage_ != "all") {
int64_t partition_num = 0;
char usage_type;
if (usage_ == "train") {
usage_type = '0';
} else {
if (usage_ == "valid") {
usage_type = '1';
} else {
if (usage_ == "test")
usage_type = '2';
else
RETURN_STATUS_UNEXPECTED("Invalid usage.");
}
}
if (!partition_file.is_open()) {
partition_file.open((folder_path / "list_eval_partition.txt").toString());
}
if (partition_file.is_open()) {
while (getline(partition_file, line)) {
int start = line.find(' ');
if (line.at(start + 1) == usage_type) {
partition_num++;
}
}
} else {
std::string partition_file_name = "list_eval_partition.txt";
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open CelebA partition file: " + partition_file_name);
}
num_rows = std::min(num_rows, partition_num);
}

sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h View File

@@ -61,6 +61,15 @@ class CelebANode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::string dataset_dir_;
std::string usage_;


+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc View File

@@ -83,5 +83,20 @@ Status Cifar100Node::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h View File

@@ -59,6 +59,15 @@ class Cifar100Node : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::string dataset_dir_;
std::string usage_;


+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc View File

@@ -81,5 +81,20 @@ Status Cifar10Node::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h View File

@@ -59,6 +59,15 @@ class Cifar10Node : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::string dataset_dir_;
std::string usage_;


+ 16
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc View File

@@ -241,5 +241,21 @@ Status CLUENode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status CLUENode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
sample_size = num_samples_;
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h View File

@@ -61,6 +61,15 @@ class CLUENode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
/// \brief Split string based on a character delimiter
/// \return A string vector


+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc View File

@@ -134,5 +134,20 @@ Status CocoNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h View File

@@ -59,6 +59,15 @@ class CocoNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::string dataset_dir_;
std::string annotation_file_;


+ 16
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc View File

@@ -153,5 +153,21 @@ Status CSVNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status CSVNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows));
sample_size = num_samples_;
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h View File

@@ -82,6 +82,15 @@ class CSVNode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::vector<std::string> dataset_files_;
char field_delim_;


+ 0
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc View File

@@ -93,6 +93,5 @@ Status GeneratorNode::GetShardId(int32_t *shard_id) {
*shard_id = 0;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h View File

@@ -64,6 +64,13 @@ class GeneratorNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Setter for DatasetSize in GeneratorNode
/// \param[in] sz dataset size to set
/// \return void
void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; }

bool IsSizeDefined() override { return false; }

private:
py::function generator_function_;
std::vector<std::string> column_names_;


+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc View File

@@ -89,5 +89,20 @@ Status ImageFolderNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t sample_size, num_rows;
RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {}));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h View File

@@ -65,6 +65,15 @@ class ImageFolderNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::string dataset_dir_;
bool decode_;


+ 16
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc View File

@@ -111,5 +111,21 @@ Status ManifestNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
int64_t num_classes; // dummy variable
RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h View File

@@ -60,6 +60,15 @@ class ManifestNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::string dataset_file_;
std::string usage_;


+ 23
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc View File

@@ -152,7 +152,6 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

std::vector<std::shared_ptr<ShardOperator>> operators_;
build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()

@@ -184,5 +183,28 @@ Status MindDataNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
std::vector<std::shared_ptr<ShardOperator>> operators;
RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators, num_padded_));

if (search_for_pattern_) {
dataset_files_ = {dataset_file_};
}

// The last operator is parent sampler
std::shared_ptr<ShardOperator> op = operators.back();
RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_));
*dataset_size = num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 10
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h View File

@@ -74,6 +74,15 @@ class MindDataNode : public MappableSourceNode {
/// \note Pybind will use this function to set sample_bytes into MindDataNode
void SetSampleBytes(std::map<std::string, std::string> *sample_bytes);

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::string dataset_file_; // search_for_pattern_ will be true in this mode
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode
@@ -83,6 +92,7 @@ class MindDataNode : public MappableSourceNode {
nlohmann::json padded_sample_;
std::map<std::string, std::string> sample_bytes_; // enable in python
int64_t num_padded_;
std::vector<std::shared_ptr<ShardOperator>> operators_;
};

} // namespace dataset


+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc View File

@@ -75,5 +75,20 @@ Status MnistNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h View File

@@ -59,6 +59,15 @@ class MnistNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::string dataset_dir_;
std::string usage_;


+ 23
- 5
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc View File

@@ -86,17 +86,16 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
schema_file_path = schema_path_;
}

std::unique_ptr<DataSchema> data_schema;
std::vector<std::string> columns_to_load;
if (columns_list_.size() > 0) {
columns_to_load = columns_list_;
}
if (!schema_file_path.empty() || !schema_json_string.empty()) {
data_schema = std::make_unique<DataSchema>();
data_schema_ = std::make_unique<DataSchema>();
if (!schema_file_path.empty()) {
data_schema->LoadSchemaFile(schema_file_path, columns_to_load);
data_schema_->LoadSchemaFile(schema_file_path, columns_to_load);
} else if (!schema_json_string.empty()) {
data_schema->LoadSchemaString(schema_json_string, columns_to_load);
data_schema_->LoadSchemaString(schema_json_string, columns_to_load);
}
}

@@ -109,7 +108,7 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {

std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema), std::move(sampler_->Build()));
std::move(data_schema_), std::move(sampler_->Build()));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);

@@ -125,5 +124,24 @@ Status RandomNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
if (sampler_ != nullptr) {
int64_t sample_size;
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
} else {
*dataset_size = num_rows;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 10
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h View File

@@ -79,6 +79,15 @@ class RandomNode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
/// \brief A quick inline for producing a random number between (and including) min/max
/// \param[in] min minimum number that can be generated.
@@ -92,6 +101,7 @@ class RandomNode : public NonMappableSourceNode {
std::vector<std::string> columns_list_;
std::shared_ptr<SamplerObj> sampler_;
std::mt19937 rand_gen_;
std::unique_ptr<DataSchema> data_schema_;
};

} // namespace dataset


+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc View File

@@ -122,5 +122,20 @@ Status TextFileNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status TextFileNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size = num_samples_;
RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h View File

@@ -61,6 +61,15 @@ class TextFileNode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
std::vector<std::string> dataset_files_;
int32_t num_samples_;


+ 36
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc View File

@@ -169,5 +169,41 @@ Status TFRecordNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status TFRecordNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
if (!shard_equal_rows_) {
// Data will be sharded by file
std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, 8, estimate));
} else {
// Data will be sharded by row
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, 8, estimate));
num_rows = static_cast<int64_t>(ceil(num_rows / (num_shards_ * 1.0)));
}
*dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}

// Get the file list of the specific shard ID
Status TFRecordNode::GetShardFileList(std::vector<std::string> *shard_filenames) {
if (!shard_filenames->empty()) {
RETURN_STATUS_UNEXPECTED("The initial file list must be empty.");
}
for (int index = 0; index < dataset_files_.size(); index++) {
if (index % num_shards_ == shard_id_) {
shard_filenames->push_back(dataset_files_.at(index));
}
}
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 14
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h View File

@@ -88,6 +88,20 @@ class TFRecordNode : public NonMappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

/// \brief Get the file list of the specific shard ID
/// \param[out] shard_filenames the list of filenames for that specific shard ID
/// \return Status of the function
Status GetShardFileList(std::vector<std::string> *shard_filenames);

private:
std::vector<std::string> dataset_files_;
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string


+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc View File

@@ -128,5 +128,20 @@ Status VOCNode::GetShardId(int32_t *shard_id) {
return Status::OK();
}

// Get Dataset size
Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h View File

@@ -61,6 +61,15 @@ class VOCNode : public MappableSourceNode {
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
const std::string kColumnImage = "image";
const std::string kColumnTarget = "target";


+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc View File

@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include <algorithm>

#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/util/status.h"
@@ -56,5 +57,19 @@ Status TakeNode::ValidateParams() {
return Status::OK();
}

// Get Dataset size
Status TakeNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
*dataset_size = std::min(static_cast<int64_t>(take_count_), num_rows);
dataset_size_ = *dataset_size;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 9
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h View File

@@ -54,6 +54,15 @@ class TakeNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;

private:
int32_t take_count_;
};


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save