From: @mahdirahmanihanzaki (tags/v1.1.0)
@@ -199,12 +199,13 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
// Constructor
Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize() {
int64_t Dataset::GetDatasetSize(bool estimate) {
int64_t dataset_size;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
RETURN_SECOND_IF_ERROR(size_getter->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(size_getter->GetDatasetSize(&dataset_size, estimate), -1);
return dataset_size;
}
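For orientation, a minimal sketch of how the reworked entry point might be driven from the C++ API. The ImageFolder factory and include path are assumptions for illustration; only GetDatasetSize(bool) is taken from the diff above.

#include <iostream>
#include "minddata/dataset/include/datasets.h"  // assumed public header

int main() {
  using namespace mindspore::dataset;
  // Hypothetical pipeline; any leaf dataset would do here.
  std::shared_ptr<Dataset> ds = ImageFolder("/path/to/images");
  // New in this change: the estimate flag. false = exact size (may require a
  // dry run over the tree); true = let ops that support it approximate.
  std::cout << "exact:  " << ds->GetDatasetSize(false) << "\n";
  std::cout << "approx: " << ds->GetDatasetSize(true) << "\n";
  return 0;
}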
@@ -106,19 +106,7 @@ PYBIND_REGISTER(ImageFolderOp, 1, ([](const py::module *m) {
}));
PYBIND_REGISTER(ManifestOp, 1, ([](const py::module *m) {
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes",
[](const std::string &file, const py::dict &dict, const std::string &usage) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
return py::make_tuple(count, num_classes);
})
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict,
const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
return output_class_indexing;
});
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp");
}));
PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
@@ -173,13 +161,6 @@ PYBIND_REGISTER(TFReaderOp, 1, ([](const py::module *m) {
PYBIND_REGISTER(VOCOp, 1, ([](const py::module *m) {
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples) {
int64_t count = 0;
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
return count;
})
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict) {
std::map<std::string, int32_t> output_class_indexing;
@@ -184,7 +184,11 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
THROW_IF_ERROR(gen->ValidateParams());
return gen;
}));
}))
.def("SetGeneratorDatasetSize", [](std::shared_ptr<GeneratorNode> self, int64_t sz) {
self->SetGeneratorDatasetSize(sz);
return self;
});
}));
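Note that the SetGeneratorDatasetSize binding returns self rather than void, which is what lets the Python layer chain the call onto node construction. A standalone pybind11 toy showing the same return-self pattern (module and class names are invented):

#include <cstdint>
#include <memory>
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct ToyNode {
  int64_t size_ = -1;
};

PYBIND11_MODULE(toy, m) {
  py::class_<ToyNode, std::shared_ptr<ToyNode>>(m, "ToyNode")
      .def(py::init<>())
      // Returning the shared_ptr instead of void enables fluent chaining
      // on the Python side: ToyNode().set_size(10)
      .def("set_size", [](std::shared_ptr<ToyNode> self, int64_t sz) {
        self->size_ = sz;
        return self;
      });
}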
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
@@ -93,12 +93,6 @@ PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) {
THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing));
return output_class_indexing;
})
.def("GetDatasetSize",
[](PythonTreeGetters &self) {
int64_t dataset_size;
THROW_IF_ERROR(self.GetDatasetSize(&dataset_size));
return dataset_size;
})
.def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; });
}));
@@ -164,5 +158,18 @@ PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) {
.def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); });
}));
PYBIND_REGISTER(PythonDatasetSizeGetter, 1, ([](const py::module *m) {
(void)py::class_<PythonDatasetSizeGetter, TreeConsumer, std::shared_ptr<PythonDatasetSizeGetter>>(
*m, "DatasetSizeGetters")
.def(py::init<>())
.def("Init", [](PythonDatasetSizeGetter &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetDatasetSize", [](PythonDatasetSizeGetter &self, bool estimate) {
int64_t size;
THROW_IF_ERROR(self.GetDatasetSize(&size, estimate));
return size;
});
}));
} // namespace dataset
} // namespace mindspore
@@ -65,4 +65,8 @@ Status PythonTreeGetters::GetRow(TensorRow *r) {
py::gil_scoped_release gil_release;
return TreeGetters::GetRow(r);
}
Status PythonDatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) {
py::gil_scoped_release gil_release;
return DatasetSizeGetter::GetRow(tree_adapter, r);
}
} // namespace mindspore::dataset
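Both Python-side overrides above release the GIL before pulling rows: the execution tree's worker threads may need to acquire it to run Python callbacks, so holding it during iteration risks a deadlock. A generic pybind11 illustration of the pattern (toy function, not the MindSpore consumer):

#include <cstdint>
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Long-running C++ loop that other threads may interleave with Python work:
// drop the GIL for its duration; it is re-acquired when `release` goes out of scope.
int64_t CountTo(int64_t n) {
  py::gil_scoped_release release;
  int64_t count = 0;
  for (int64_t i = 0; i < n; ++i) ++count;  // stand-in for tree iteration
  return count;
}

PYBIND11_MODULE(toy_gil, m) { m.def("count_to", &CountTo); }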
@@ -60,5 +60,9 @@ class PythonTreeGetters : public TreeGetters {
public:
Status GetRow(TensorRow *r) override;
};
class PythonDatasetSizeGetter : public DatasetSizeGetter {
public:
Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *r) override;
};
} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_
@@ -451,29 +451,6 @@ Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); }
Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ == -1) {
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize)));
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
if (*dataset_size == -1) { // run through the tree and get everything
TensorRow row;
RETURN_IF_NOT_OK(GetRow(&row));
int64_t row_cnt = 0;
while (!row.empty()) {
++row_cnt;
RETURN_IF_NOT_OK(GetRow(&row));
}
*dataset_size = row_cnt;
}
dataset_size_ = *dataset_size; // save the previous result
}
*dataset_size = dataset_size_;
return Status::OK();
}
Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
*types = first_row_type_;
@@ -573,5 +550,46 @@ Status BuildVocabConsumer::Start() {
CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE.");
return Status::OK();
}
Status DatasetSizeGetter::GetDatasetSize(int64_t *size, bool estimate) {
if (dataset_size_ == -1) {
RETURN_IF_NOT_OK(root_->GetDatasetSize(shared_from_this(), estimate, size));
dataset_size_ = *size; // cache the result for subsequent calls
}
*size = dataset_size_;
return Status::OK();
}
Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
root_ = std::move(d);
return Status::OK();
}
Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>();
tree_adapters_.push_back(tree_adapter);
tree_adapter->SetPrePassOverride([](OptPass pre) {
pre.push_back(std::make_unique<GetterPass>(GetterPass::GetterType::kDatasetSize));
return pre;
});
RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1));
TensorRow row;
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
int64_t row_cnt = 0;
while (!row.empty()) {
++row_cnt;
RETURN_IF_NOT_OK(GetRow(tree_adapter, &row));
}
*dataset_size = row_cnt;
return Status::OK();
}
Status DatasetSizeGetter::GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row) {
return tree_adapter->GetNext(row);
}
Status DatasetSizeGetter::Terminate() {
for (const auto &tree : tree_adapters_) {
RETURN_IF_NOT_OK(tree->AllTasks()->ServiceStop());
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -177,7 +177,6 @@ class TreeGetters : public TreeConsumer {
~TreeGetters() = default;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
Status GetBatchSize(int64_t *batch_size);
@@ -186,7 +185,7 @@ class TreeGetters : public TreeConsumer {
Status GetColumnNames(std::vector<std::string> *output);
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
std::string Name() override { return "TreeGetters"; }
virtual Status GetRow(TensorRow *r);
virtual Status GetRow(TensorRow *row);
private:
Status GetFirstRowShapeAndType();
@@ -202,6 +201,35 @@ class TreeGetters : public TreeConsumer {
Status InternalInit();
};
/// Consumer that is used to get the dataset size
class DatasetSizeGetter : public TreeConsumer, public std::enable_shared_from_this<DatasetSizeGetter> {
public:
DatasetSizeGetter() : dataset_size_(-1) {}
~DatasetSizeGetter() = default;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status Terminate() override;
/// \brief Function to get the dataset size
/// \param[in] estimate This is only supported by some of the ops, and it's used to speed up the process of getting
/// the dataset size at the expense of accuracy.
/// \param[out] size The size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *size, bool estimate = false);
virtual Status GetRow(const std::shared_ptr<TreeAdapter> &tree_adapter, TensorRow *row);
std::string Name() override { return "DatasetSizeGetter"; }
/// \brief Gets the dataset size by iterating over the entire dataset on a subtree starting from ir_node
/// \param[in] ir_node The node at the top of the subtree over which we want to iterate
/// \param[out] dataset_size The number of rows counted
/// \return Status - The status code returned
Status DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size);
private:
std::shared_ptr<DatasetNode> root_;
std::vector<std::shared_ptr<TreeAdapter>> tree_adapters_;
int64_t dataset_size_;
};
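Taken together, the intended lifecycle of the new consumer looks roughly like this (a sketch based only on the declarations above; QuerySize is a hypothetical helper and error handling is minimal):

// Init() only records the IR root; trees are compiled lazily, one TreeAdapter
// per DryRun() call, and Terminate() stops the tasks of every tree spawned.
Status QuerySize(const std::shared_ptr<DatasetNode> &ir_root, int64_t *size) {
  auto getter = std::make_shared<DatasetSizeGetter>();
  RETURN_IF_NOT_OK(getter->Init(ir_root));
  RETURN_IF_NOT_OK(getter->GetDatasetSize(size, /*estimate=*/false));
  return getter->Terminate();
}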
class BuildVocabConsumer : public TreeConsumer {
public:
/// BuildVocabConsumer Constructor which will call the base class default constructor.
@@ -537,30 +537,6 @@ Status BatchOp::ComputeColMap() {
return Status::OK();
}
Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
#ifdef ENABLE_PYTHON
if (batch_size_func_) {
*dataset_size = -1;
return Status::OK();
}
#endif
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
if (num_rows > 0 && start_batch_size_ > 0) {
if (drop_) {
num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * start_batch_size_)));
} else {
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * start_batch_size_)));
}
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
int64_t BatchOp::GetTreeBatchSize() {
#ifdef ENABLE_PYTHON
if (batch_size_func_) {
@@ -225,11 +225,6 @@ class BatchOp : public ParallelOp {
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
int64_t GetTreeBatchSize() override;
protected:
@@ -232,12 +232,5 @@ Status BucketBatchByLengthOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status BucketBatchByLengthOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -112,11 +112,6 @@ class BucketBatchByLengthOp : public PipelineOp {
std::string Name() const override { return kBucketBatchByLengthOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded
@@ -196,12 +196,5 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
}
// Get Dataset size
Status ConcatOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -111,11 +111,6 @@ class ConcatOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);
@@ -294,24 +294,6 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
return Status::OK();
}
// Gets the dataset size
Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
if (child_.size() == 1) {
return child_[0]->GetDatasetSize(dataset_size);
} else if (child_.size() > 1) {
// It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
// This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
// always be in front of the child_ structure, so we get the dataset size from the last child.
return child_[child_.size() - 1]->GetDatasetSize(dataset_size);
} else {
RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
}
}
// Gets the number of classes
Status DatasetOp::GetNumClasses(int64_t *num_classes) {
if (child_.size() == 1) {
@@ -180,10 +180,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The error code return
Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);
/// \brief Gets the dataset size
/// \return Status - The status code return
virtual Status GetDatasetSize(int64_t *dataset_size);
/// \brief Gets the batch size
/// \return Status - The status code return
virtual int64_t GetTreeBatchSize();
@@ -258,13 +258,5 @@ Status FilterOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<FilterOp>(), modified);
}
// Get Dataset size
Status FilterOp::GetDatasetSize(int64_t *dataset_size) {
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -137,11 +137,6 @@ class FilterOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return kFilterOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// predicate_func python callable which returns a boolean value.
std::shared_ptr<TensorOp> predicate_func_;
@@ -187,21 +187,6 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
}
// Get Dataset size
Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
if (num_rows > 0 && num_repeats_ > 0) {
num_rows = num_rows * num_repeats_;
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; }
} // namespace dataset
} // namespace mindspore
@@ -133,11 +133,6 @@ class RepeatOp : public PipelineOp {
/// \@return Status - The error code return
Status Reset() override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
int64_t GetTreeRepeatCount() override;
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
@@ -136,20 +136,5 @@ Status SkipOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<SkipOp>(), modified);
}
// Get Dataset size
Status SkipOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
*dataset_size = 0;
if (max_skips_ >= 0 && max_skips_ < num_rows) {
*dataset_size = num_rows - max_skips_;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -86,11 +86,6 @@ class SkipOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return kSkipOp; }
@@ -452,63 +452,5 @@ Status CelebAOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status CelebAOp::GetDatasetSize(int64_t *dataset_size) {
int64_t num_rows, sample_size;
std::string line;
Path folder_path(folder_path_);
std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
if (!attr_file.is_open()) {
std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
}
std::string rows_num;
(void)getline(attr_file, rows_num);
try {
num_rows = static_cast<int64_t>(std::stoul(rows_num)); // First line is rows number in attr file
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
}
if (usage_ != "all") {
int64_t partition_num = 0;
char usage_type;
if (usage_ == "train") {
usage_type = '0';
} else {
if (usage_ == "valid") {
usage_type = '1';
} else {
if (usage_ == "test")
usage_type = '2';
else
RETURN_STATUS_UNEXPECTED("Invalid usage.");
}
}
if (!partition_file_.is_open()) {
partition_file_.open((folder_path / "list_eval_partition.txt").toString());
}
if (partition_file_.is_open()) {
while (getline(partition_file_, line)) {
int start = line.find(' ');
if (line.at(start + 1) == usage_type) {
partition_num++;
}
}
} else {
std::string partition_file_name = "list_eval_partition.txt";
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba partition file: " + partition_file_name);
}
num_rows = std::min(num_rows, partition_num);
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -179,11 +179,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CelebAOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Called first when function is called
// @return
@@ -508,20 +508,5 @@ Status CifarOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status CifarOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0)
RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &num_rows));
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -175,11 +175,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CifarOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
@@ -565,19 +565,5 @@ Status ClueOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<ClueOp>(), modified);
}
// Get Dataset size
Status ClueOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
sample_size = num_samples_;
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -197,11 +197,6 @@ class ClueOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
@@ -681,39 +681,6 @@ Status CocoOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
std::string task_type;
switch (task_type_) {
case TaskType::Detection:
task_type = "Detection";
break;
case TaskType::Keypoint:
task_type = "Keypoint";
break;
case TaskType::Panoptic:
task_type = "Panoptic";
break;
case TaskType::Stuff:
task_type = "Stuff";
break;
}
if (image_ids_.size() == 0) {
RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
} else {
num_rows = image_ids_.size();
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status CocoOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
if ((*output_class_indexing).empty()) {
if ((task_type_ != TaskType::Detection) && (task_type_ != TaskType::Panoptic)) {
@@ -213,11 +213,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CocoOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Gets the class indexing
/// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
@@ -916,19 +916,5 @@ Status CsvOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<CsvOp>(), modified);
}
// Get Dataset size
Status CsvOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
sample_size = num_samples_;
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -318,11 +318,6 @@ class CsvOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
@@ -274,11 +274,5 @@ Status GeneratorOp::ComputeColMap() {
}
return Status::OK();
}
Status GeneratorOp::GetDatasetSize(int64_t *dataset_size) { // Get Dataset size
// We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to
// iterate over the dataset and count the size
*dataset_size = dataset_size_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -136,8 +136,6 @@ class GeneratorOp : public PipelineOp {
Status Init();
Status GetDatasetSize(int64_t *dataset_size) override;
private:
py::function generator_function_;
std::vector<std::string> column_names_;
@@ -465,24 +465,6 @@ Status ImageFolderOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t sample_size, num_rows;
num_rows = num_rows_;
if (num_rows_ <= 0) {
// GetDatasetSize will not be impacted by class_index_
RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, nullptr, {}));
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
// Get number of classes
Status ImageFolderOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {
@@ -217,11 +217,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "ImageFolderOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function
@@ -396,16 +396,9 @@ Status ManifestOp::CountDatasetInfo() {
return Status::OK();
}
#ifdef ENABLE_PYTHON
Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
int64_t *count, int64_t *numClasses) {
Status ManifestOp::CountTotalRows(const std::string &file, const std::map<std::string, int32_t> &map,
const std::string &usage, int64_t *count, int64_t *numClasses) {
// the logic of counting the number of samples is copied from ParseManifestFile()
std::map<std::string, int32_t> map;
for (auto p : dict) {
(void)map.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
py::reinterpret_borrow<py::int_>(p.second)));
}
std::shared_ptr<ManifestOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
@@ -415,6 +408,7 @@ Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict,
return Status::OK();
}
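With py::dict replaced by std::map in the signature, row counting no longer requires Python at the call site; the dict-to-map conversion moves out to whatever layer still holds Python objects. A hypothetical pure-C++ call (paths and labels invented); VOCOp::CountTotalRows below receives the same treatment:

// Assumes the minddata headers are available in the enclosing translation unit.
std::map<std::string, int32_t> class_index = {{"cat", 0}, {"dog", 1}};
int64_t count = 0, num_classes = 0;
Status rc = ManifestOp::CountTotalRows("/data/train.manifest", class_index, "train",
                                       &count, &num_classes);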
#ifdef ENABLE_PYTHON
Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
@@ -459,23 +453,6 @@ Status ManifestOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status ManifestOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
num_rows = static_cast<int64_t>(op->image_labelname_.size());
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
// Get number of classes
Status ManifestOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {
@@ -164,10 +164,17 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;
#ifdef ENABLE_PYTHON
static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count,
int64_t *numClasses);
/// \brief Counts the total number of rows in the Manifest file
/// \param[in] file Dataset file path
/// \param[in] map Input map of class index
/// \param[in] usage Dataset usage
/// \param[out] count Number of rows counted
/// \param[out] numClasses Number of classes counted
/// \return Status of the function
static Status CountTotalRows(const std::string &file, const std::map<std::string, int32_t> &map,
const std::string &usage, int64_t *count, int64_t *numClasses);
#ifdef ENABLE_PYTHON
// Get str-to-int mapping from label name to index
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing);
@@ -183,11 +190,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "ManifestOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function
@@ -474,22 +474,5 @@ Status MindRecordOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = num_rows_;
if (num_rows_ <= 0) {
// The last operator is parent sampler
std::shared_ptr<ShardOperator> op = operators_.back();
RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
}
*dataset_size = num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -212,11 +212,6 @@ class MindRecordOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "MindRecordOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
@@ -471,19 +471,5 @@ Status MnistOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status MnistOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0) RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, &num_rows));
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -168,11 +168,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "MnistOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
@@ -421,23 +421,5 @@ Status RandomDataOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status RandomDataOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
if (sampler_ != nullptr) {
int64_t sample_size;
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
} else {
*dataset_size = num_rows;
}
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -203,11 +203,6 @@ class RandomDataOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "RandomDataOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
/**
* The entry point code for when workers are launched
@@ -162,11 +162,11 @@ Status DistributedSamplerRT::ResetSampler() {
}
int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t childs = num_rows;
int64_t child_num_rows = num_rows;
if (!child_.empty()) {
childs = child_[0]->CalculateNumSamples(num_rows);
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
}
int64_t num_samples = (num_samples_ > 0) ? std::min(childs, num_samples_) : childs;
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
return std::ceil(num_samples * 1.0 / num_devices_);
}
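To sanity-check the arithmetic above with illustrative numbers: if the child reports 10 rows, num_samples_ is 0, and num_devices_ is 4, the result is ceil(10 / 4) = 3 rows per device; with num_samples_ = 6 it becomes ceil(min(10, 6) / 4) = 2.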
@@ -63,6 +63,11 @@ class DistributedSamplerRT : public SamplerRT {
int64_t GetDeviceNum() { return num_devices_; }
/// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers
/// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's greater than num_rows,
/// then num_samples_ is not returned at all.
/// \param[in] num_rows The total number of rows in the dataset
/// \return int64_t Calculated number of samples
int64_t CalculateNumSamples(int64_t num_rows) override;
void Print(std::ostream &out, bool show_all) const override;
@@ -520,19 +520,5 @@ Status TextFileOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
}
// Get Dataset size
Status TextFileOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
sample_size = total_rows_;
if (num_rows_per_shard_ <= 0) RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = num_rows_per_shard_;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -198,11 +198,6 @@ class TextFileOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
@@ -1067,41 +1067,5 @@ Status TFReaderOp::PrepareNodePostAction() {
return Status::OK();
}
// Get the file list of the specific shard ID
Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) {
if (!shard_filenames->empty()) {
RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n");
}
for (int index = 0; index < dataset_files_list_.size(); index++) {
if (index % num_devices_ == device_id_) {
shard_filenames->push_back(dataset_files_list_.at(index));
}
}
return Status::OK();
}
// Get Dataset size
Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
num_rows = num_rows_;
if (num_rows_ <= 0) {
if (equal_rows_per_shard_) {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = num_rows_per_shard_;
} else {
std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list));
}
}
sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -257,11 +257,6 @@ class TFReaderOp : public ParallelOp {
// before providing their own implementations.
Status PrepareNodePostAction() override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
static bool ValidateFirstRowCrc(const std::string &filename);
private:
@@ -400,11 +395,6 @@ class TFReaderOp : public ParallelOp {
// @return - Status
Status ComputeColMap() override;
// Private function for computing the file list of the specific shard ID. This is because in distributed scenario,
// data will be divided into shards by row when equal_rows_per_shard is true, but by file in the opposite case.
// @return - Status - the status code returned.
Status GetShardFileList(std::vector<std::string> *shard_filenames);
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
@@ -447,16 +447,9 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) {
return Status::OK();
}
#ifdef ENABLE_PYTHON
Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t *count) {
const std::map<std::string, int32_t> &input_class_indexing, int64_t *count) {
if (task_type == "Detection") {
std::map<std::string, int32_t> input_class_indexing;
for (auto p : dict) {
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
py::reinterpret_borrow<py::int_>(p.second)));
}
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op));
@@ -473,6 +466,7 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
return Status::OK();
}
#ifdef ENABLE_PYTHON
Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
@@ -516,36 +510,6 @@ Status VOCOp::ComputeColMap() {
return Status::OK();
}
// Get Dataset size
Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = 0, sample_size;
if (image_ids_.size() == 0) {
if (task_type_ == TaskType::Detection) {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(folder_path_).SetTask("Detection").SetUsage(usage_).SetClassIndex(class_index_).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
num_rows = static_cast<int64_t>(op->image_ids_.size());
} else if (task_type_ == TaskType::Segmentation) {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(Builder().SetDir(folder_path_).SetTask("Segmentation").SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
num_rows = static_cast<int64_t>(op->image_ids_.size());
}
} else {
num_rows = image_ids_.size();
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status VOCOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
if ((*output_class_indexing).empty()) {
if (task_type_ != TaskType::Detection) {
@@ -187,15 +187,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;
#ifdef ENABLE_PYTHON
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
// @param const py::dict &dict - input dict of class index
// @param const std::map<std::string, int32_t> &input_class_indexing - input map of class index
// @param int64_t *count - output rows number of VOCDataset
static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t *count);
const std::map<std::string, int32_t> &input_class_indexing, int64_t *count);
#ifdef ENABLE_PYTHON
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
@@ -216,11 +216,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "VOCOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// /// \brief Gets the class indexing
// /// \return Status - The status code return
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
@@ -139,17 +139,5 @@ Status TakeOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<TakeOp>(), modified);
}
// Get Dataset size
Status TakeOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows;
RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
*dataset_size = std::min(static_cast<int64_t>(max_takes_), num_rows);
dataset_size_ = *dataset_size;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -94,11 +94,6 @@ class TakeOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kTakeOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
int32_t max_takes_; // The number of takes that the user requested
int32_t take_count_; // A counter for the current number of executed takes
@@ -248,24 +248,6 @@ Status ZipOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<ZipOp>(), modified);
}
// Get Dataset size
Status ZipOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
std::vector<int32_t> dataset_sizes;
int64_t child_dataset_size;
for (auto child : child_) {
RETURN_IF_NOT_OK(child->GetDatasetSize(&child_dataset_size));
dataset_sizes.push_back(child_dataset_size);
}
*dataset_size = *std::min_element(dataset_sizes.begin(), dataset_sizes.end());
dataset_size_ = *dataset_size;
return Status::OK();
}
Status ZipOp::ComputeColMap() {
if (column_name_id_map_.empty()) {
column_name_id_map_ = {};
@@ -120,11 +120,6 @@ class ZipOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kZipOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Handles preprocessing of the main loop, used when starting new epoch
Status prepare(TensorQTable *const table);
@@ -114,5 +114,33 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
return node_ops;
}
// Get Dataset size
Status BatchNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
#ifdef ENABLE_PYTHON
if (batch_size_func_) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
dataset_size_ = *dataset_size;
return Status::OK();
}
#endif
int64_t num_rows;
RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows));
if (num_rows > 0 && batch_size_ > 0) {
if (drop_remainder_) {
num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * batch_size_)));
} else {
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * batch_size_)));
}
}
*dataset_size = num_rows;
dataset_size_ = num_rows;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
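Worked example for the rounding above (illustrative numbers): with num_rows = 10 from the child and batch_size_ = 3, drop_remainder_ = true yields floor(10 / 3) = 3 batches, while drop_remainder_ = false yields ceil(10 / 3) = 4, the last batch carrying the single leftover row.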
| @@ -64,6 +64,15 @@ class BatchNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| int32_t batch_size_; | |||
| bool drop_remainder_; | |||
| @@ -127,6 +127,5 @@ Status BucketBatchByLengthNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -60,6 +60,8 @@ class BucketBatchByLengthNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| bool IsSizeDefined() override { return false; }; | |||
| private: | |||
| std::vector<std::string> column_names_; | |||
| std::vector<int32_t> bucket_boundaries_; | |||
| @@ -58,6 +58,8 @@ class ConcatNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| bool IsSizeDefined() override { return false; } | |||
| private: | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | |||
| @@ -342,6 +342,31 @@ Status DatasetNode::GetShardId(int32_t *shard_id) { | |||
| RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n"); | |||
| } | |||
| } | |||
| // Gets the dataset size | |||
| Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| if (!IsSizeDefined()) { | |||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size)); | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| if (children_.size() == 1) { | |||
| return children_[0]->GetDatasetSize(size_getter, estimate, dataset_size); | |||
| } else if (children_.size() > 1) { | |||
| // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case. | |||
| // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will | |||
| // always be in front of the child_ structure, so we get the dataset size from the last child. | |||
| return children_[children_.size() - 1]->GetDatasetSize(size_getter, estimate, dataset_size); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); | |||
| } | |||
| } | |||
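The fallthrough above resolves a node's size in a fixed order: memoized value first, dry run for size-undefined ops, otherwise delegate to a child (the last one, since injected cache lookup/merge branches come first). A simplified model under those assumptions (the types below are illustrative, not the real MindData classes):

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    struct Node {
      int64_t cached_size = -1;  // mirrors dataset_size_
      bool size_defined = true;  // mirrors IsSizeDefined()
      std::vector<Node *> children;

      int64_t Size() {
        if (cached_size > 0) return cached_size;            // 1) memoized value
        if (!size_defined) return cached_size = DryRun();   // 2) run the pipeline
        if (!children.empty())                              // 3) delegate to the
          return cached_size = children.back()->Size();     //    last child
        throw std::logic_error("leaf node must override");  // 4) leaves override
      }
      int64_t DryRun() { return 0; }  // stand-in for DatasetSizeGetter::DryRun
    };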
| // Visitor accepting method for NodePass | |||
| Status SourceNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| @@ -25,6 +25,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/include/datasets.h" | |||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -32,6 +33,7 @@ namespace dataset { | |||
| class Dataset; | |||
| class SamplerObj; | |||
| class NodePass; | |||
| class DatasetSizeGetter; | |||
| #define RETURN_EMPTY_IF_ERROR(_s) \ | |||
| do { \ | |||
| @@ -169,6 +171,14 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| virtual Status GetShardId(int32_t *shard_id); | |||
| /// \brief Gets the dataset size | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| virtual Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size); | |||
| /// \brief Getter function for child nodes | |||
| /// \return Child nodes | |||
| const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; } | |||
| @@ -219,10 +229,13 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \notes Remove me after changing return val of Build() | |||
| Status BuildStatus() { return build_status; } | |||
| virtual bool IsSizeDefined() { return true; } | |||
| protected: | |||
| std::vector<std::shared_ptr<DatasetNode>> children_; | |||
| DatasetNode *parent_; | |||
| std::shared_ptr<DatasetCache> cache_; | |||
| int64_t dataset_size_ = -1; | |||
| int32_t num_workers_; | |||
| int32_t rows_per_buffer_; | |||
| int32_t connector_que_size_; | |||
| @@ -55,6 +55,8 @@ class FilterNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| bool IsSizeDefined() override { return false; } | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| @@ -56,6 +56,23 @@ Status RepeatNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status RepeatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows; | |||
| RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows)); | |||
| if (num_rows > 0 && repeat_count_ > 0) { | |||
| num_rows = num_rows * repeat_count_; | |||
| } | |||
| *dataset_size = num_rows; | |||
| dataset_size_ = num_rows; | |||
| return Status::OK(); | |||
| } | |||
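A worked example of the repeat arithmetic, as a small sketch (RepeatSize is hypothetical):

    #include <cstdint>

    // A child reporting 10 rows wrapped in repeat(3) reports 30 rows; an
    // unknown child size (<= 0) passes through unchanged.
    int64_t RepeatSize(int64_t child_rows, int32_t repeat_count) {
      return (child_rows > 0 && repeat_count > 0) ? child_rows * repeat_count
                                                  : child_rows;
    }
    // RepeatSize(10, 3) == 30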
| // Visitor accepting method for NodePass | |||
| Status RepeatNode::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| @@ -56,6 +56,15 @@ class RepeatNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Base-class override for accepting NodePass visitor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| @@ -56,5 +56,22 @@ Status SkipNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status SkipNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows; | |||
| RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows)); | |||
| *dataset_size = 0; | |||
| if (skip_count_ >= 0 && skip_count_ < num_rows) { | |||
| *dataset_size = num_rows - skip_count_; | |||
| } | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
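The skip arithmetic clamps at zero rather than going negative. A sketch of the same rule (SkipSize is hypothetical):

    #include <cstdint>

    // Skipping past the end of the child yields 0 rows, never a negative size.
    int64_t SkipSize(int64_t child_rows, int32_t skip_count) {
      return (skip_count >= 0 && skip_count < child_rows) ? child_rows - skip_count : 0;
    }
    // SkipSize(10, 4) == 6;  SkipSize(10, 12) == 0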
| @@ -54,6 +54,15 @@ class SkipNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| int32_t skip_count_; | |||
| }; | |||
| @@ -21,6 +21,7 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| @@ -87,5 +88,66 @@ Status CelebANode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| std::ifstream partition_file; | |||
| std::string line; | |||
| Path folder_path(dataset_dir_); | |||
| std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString()); | |||
| if (!attr_file.is_open()) { | |||
| std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString(); | |||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name); | |||
| } | |||
| std::string rows_num; | |||
| (void)getline(attr_file, rows_num); | |||
| try { | |||
| num_rows = static_cast<int64_t>(std::stoul(rows_num)); // the first line of the attr file is the row count | |||
| } catch (std::invalid_argument &e) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num); | |||
| } catch (std::out_of_range &e) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num); | |||
| } | |||
| if (usage_ != "all") { | |||
| int64_t partition_num = 0; | |||
| char usage_type; | |||
| if (usage_ == "train") { | |||
| usage_type = '0'; | |||
| } else { | |||
| if (usage_ == "valid") { | |||
| usage_type = '1'; | |||
| } else { | |||
| if (usage_ == "test") | |||
| usage_type = '2'; | |||
| else | |||
| RETURN_STATUS_UNEXPECTED("Invalid usage."); | |||
| } | |||
| } | |||
| if (!partition_file.is_open()) { | |||
| partition_file.open((folder_path / "list_eval_partition.txt").toString()); | |||
| } | |||
| if (partition_file.is_open()) { | |||
| while (getline(partition_file, line)) { | |||
| std::size_t pos = line.find(' '); | |||
| if (pos != std::string::npos && pos + 1 < line.size() && line.at(pos + 1) == usage_type) { | |||
| partition_num++; | |||
| } | |||
| } | |||
| } else { | |||
| std::string partition_file_name = "list_eval_partition.txt"; | |||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to open CelebA partition file: " + partition_file_name); | |||
| } | |||
| num_rows = std::min(num_rows, partition_num); | |||
| } | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
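For context, a line of the standard CelebA list_eval_partition.txt looks like "000001.jpg 0", where the flag after the space is 0 = train, 1 = valid, 2 = test; this is the mapping behind usage_type above. A standalone sketch of the scan under that format assumption:

    #include <cstddef>
    #include <cstdint>
    #include <fstream>
    #include <string>

    // Counts lines whose partition flag matches usage_type ('0'/'1'/'2').
    int64_t CountPartition(const std::string &path, char usage_type) {
      std::ifstream in(path);
      std::string line;
      int64_t count = 0;
      while (std::getline(in, line)) {
        std::size_t pos = line.find(' ');
        if (pos != std::string::npos && pos + 1 < line.size() &&
            line[pos + 1] == usage_type) {
          ++count;
        }
      }
      return count;
    }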
| @@ -61,6 +61,15 @@ class CelebANode : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| @@ -83,5 +83,20 @@ Status Cifar100Node::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows)); | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -59,6 +59,15 @@ class Cifar100Node : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| @@ -81,5 +81,20 @@ Status Cifar10Node::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
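Every mappable source in this patch (CIFAR, MNIST, ImageFolder, Manifest, Coco, VOC, CelebA) follows the same two-step pattern: count the raw rows, then let the sampler decide how many of them will be produced. A hedged sketch of that contract; the toy sampler below is illustrative, and the real behaviour depends on the concrete SamplerRT:

    #include <algorithm>
    #include <cstdint>

    // Illustrates the CalculateNumSamples contract assumed above: cap a
    // user-requested num_samples by the rows actually available.
    struct ToySampler {
      int64_t num_samples;  // 0 means "take everything"
      int64_t CalculateNumSamples(int64_t num_rows) const {
        return num_samples > 0 ? std::min(num_samples, num_rows) : num_rows;
      }
    };
    // ToySampler{5000}.CalculateNumSamples(50000) == 5000
    // ToySampler{0}.CalculateNumSamples(50000)    == 50000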
| @@ -59,6 +59,15 @@ class Cifar10Node : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| @@ -241,5 +241,21 @@ Status CLUENode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status CLUENode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(dataset_files_, &num_rows)); | |||
| sample_size = num_samples_; | |||
| num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_))); | |||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
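The per-shard arithmetic above recurs verbatim in CSVNode, TextFileNode and the row-sharded TFRecordNode path below: split rows across shards with ceiling division (so every shard reports the same size), then cap by num_samples when one was requested. A worked sketch (ShardedSize is hypothetical):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    int64_t ShardedSize(int64_t total_rows, int32_t num_shards, int64_t num_samples) {
      int64_t per_shard = static_cast<int64_t>(std::ceil(total_rows / (1.0 * num_shards)));
      return num_samples > 0 ? std::min(per_shard, num_samples) : per_shard;
    }
    // ShardedSize(10, 3, 0) == 4  (ceil, so no shard under-reports)
    // ShardedSize(10, 3, 2) == 2  (capped by the requested sample count)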
| @@ -61,6 +61,15 @@ class CLUENode : public NonMappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| /// \brief Split string based on a character delimiter | |||
| /// \return A string vector | |||
| @@ -134,5 +134,20 @@ Status CocoNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows = 0, sample_size; | |||
| RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows)); | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -59,6 +59,15 @@ class CocoNode : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string annotation_file_; | |||
| @@ -153,5 +153,21 @@ Status CSVNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status CSVNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows)); | |||
| sample_size = num_samples_; | |||
| num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_))); | |||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -82,6 +82,15 @@ class CSVNode : public NonMappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::vector<std::string> dataset_files_; | |||
| char field_delim_; | |||
| @@ -93,6 +93,5 @@ Status GeneratorNode::GetShardId(int32_t *shard_id) { | |||
| *shard_id = 0; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -64,6 +64,13 @@ class GeneratorNode : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Setter for DatasetSize in GeneratorNode | |||
| /// \param[in] sz dataset size to set | |||
| /// \return void | |||
| void SetGeneratorDatasetSize(int64_t sz) { dataset_size_ = sz; } | |||
| bool IsSizeDefined() override { return false; } | |||
| private: | |||
| py::function generator_function_; | |||
| std::vector<std::string> column_names_; | |||
| @@ -89,5 +89,20 @@ Status ImageFolderNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t sample_size, num_rows; | |||
| RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {})); | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -65,6 +65,15 @@ class ImageFolderNode : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| bool decode_; | |||
| @@ -111,5 +111,21 @@ Status ManifestNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| int64_t num_classes; // dummy variable | |||
| RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes)); | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -60,6 +60,15 @@ class ManifestNode : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::string dataset_file_; | |||
| std::string usage_; | |||
| @@ -152,7 +152,6 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| std::vector<std::shared_ptr<ShardOperator>> operators_; | |||
| build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| @@ -184,5 +183,28 @@ Status MindDataNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows; | |||
| std::vector<std::shared_ptr<ShardOperator>> operators; | |||
| RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators, num_padded_)); | |||
| if (search_for_pattern_) { | |||
| dataset_files_ = {dataset_file_}; | |||
| } | |||
| // The last operator is the parent sampler | |||
| std::shared_ptr<ShardOperator> op = operators.back(); | |||
| RETURN_IF_NOT_OK(MindRecordOp::CountTotalRows(dataset_files_, search_for_pattern_, op, &num_rows, num_padded_)); | |||
| *dataset_size = num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -74,6 +74,15 @@ class MindDataNode : public MappableSourceNode { | |||
| /// \note Pybind will use this function to set sample_bytes into MindDataNode | |||
| void SetSampleBytes(std::map<std::string, std::string> *sample_bytes); | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::string dataset_file_; // search_for_pattern_ will be true in this mode | |||
| std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode | |||
| @@ -83,6 +92,7 @@ class MindDataNode : public MappableSourceNode { | |||
| nlohmann::json padded_sample_; | |||
| std::map<std::string, std::string> sample_bytes_; // enable in python | |||
| int64_t num_padded_; | |||
| std::vector<std::shared_ptr<ShardOperator>> operators_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -75,5 +75,20 @@ Status MnistNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -59,6 +59,15 @@ class MnistNode : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| @@ -86,17 +86,16 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() { | |||
| schema_file_path = schema_path_; | |||
| } | |||
| std::unique_ptr<DataSchema> data_schema; | |||
| std::vector<std::string> columns_to_load; | |||
| if (columns_list_.size() > 0) { | |||
| columns_to_load = columns_list_; | |||
| } | |||
| if (!schema_file_path.empty() || !schema_json_string.empty()) { | |||
| data_schema = std::make_unique<DataSchema>(); | |||
| data_schema_ = std::make_unique<DataSchema>(); | |||
| if (!schema_file_path.empty()) { | |||
| data_schema->LoadSchemaFile(schema_file_path, columns_to_load); | |||
| data_schema_->LoadSchemaFile(schema_file_path, columns_to_load); | |||
| } else if (!schema_json_string.empty()) { | |||
| data_schema->LoadSchemaString(schema_json_string, columns_to_load); | |||
| data_schema_->LoadSchemaString(schema_json_string, columns_to_load); | |||
| } | |||
| } | |||
| @@ -109,7 +108,7 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() { | |||
| std::shared_ptr<RandomDataOp> op; | |||
| op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, | |||
| std::move(data_schema), std::move(sampler_->Build())); | |||
| std::move(data_schema_), std::move(sampler_->Build())); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| @@ -125,5 +124,24 @@ Status RandomNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows; | |||
| num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); | |||
| if (sampler_ != nullptr) { | |||
| int64_t sample_size; | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| } else { | |||
| *dataset_size = num_rows; | |||
| } | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -79,6 +79,15 @@ class RandomNode : public NonMappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| /// \brief A quick inline for producing a random number between (and including) min/max | |||
| /// \param[in] min minimum number that can be generated. | |||
| @@ -92,6 +101,7 @@ class RandomNode : public NonMappableSourceNode { | |||
| std::vector<std::string> columns_list_; | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| std::mt19937 rand_gen_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -122,5 +122,20 @@ Status TextFileNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status TextFileNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size = num_samples_; | |||
| RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(dataset_files_, &num_rows)); | |||
| num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_))); | |||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -61,6 +61,15 @@ class TextFileNode : public NonMappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| std::vector<std::string> dataset_files_; | |||
| int32_t num_samples_; | |||
| @@ -169,5 +169,41 @@ Status TFRecordNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status TFRecordNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows; | |||
| if (!shard_equal_rows_) { | |||
| // Data will be sharded by file | |||
| std::vector<std::string> shard_file_list; | |||
| RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list)); | |||
| RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, 8, estimate)); | |||
| } else { | |||
| // Data will be sharded by row | |||
| RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, 8, estimate)); | |||
| num_rows = static_cast<int64_t>(ceil(num_rows / (num_shards_ * 1.0))); | |||
| } | |||
| *dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| // Get the file list of the specific shard ID | |||
| Status TFRecordNode::GetShardFileList(std::vector<std::string> *shard_filenames) { | |||
| if (!shard_filenames->empty()) { | |||
| RETURN_STATUS_UNEXPECTED("The initial file list must be empty."); | |||
| } | |||
| for (size_t index = 0; index < dataset_files_.size(); index++) { | |||
| if (index % num_shards_ == shard_id_) { | |||
| shard_filenames->push_back(dataset_files_.at(index)); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
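GetShardFileList assigns files round-robin: file i belongs to shard (i % num_shards), so with 5 files and 2 shards, shard 0 gets files 0, 2, 4 and shard 1 gets files 1, 3. A standalone sketch of the same assignment (ShardFiles is hypothetical):

    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <vector>

    std::vector<std::string> ShardFiles(const std::vector<std::string> &files,
                                        int32_t num_shards, int32_t shard_id) {
      std::vector<std::string> picked;
      for (std::size_t i = 0; i < files.size(); ++i) {
        if (static_cast<int32_t>(i % num_shards) == shard_id) picked.push_back(files[i]);
      }
      return picked;
    }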
| @@ -88,6 +88,20 @@ class TFRecordNode : public NonMappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Get the file list of the specific shard ID | |||
| /// \param[out] shard_filenames the list of filenames for that specific shard ID | |||
| /// \return Status of the function | |||
| Status GetShardFileList(std::vector<std::string> *shard_filenames); | |||
| private: | |||
| std::vector<std::string> dataset_files_; | |||
| std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string | |||
| @@ -128,5 +128,20 @@ Status VOCNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows = 0, sample_size; | |||
| RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows)); | |||
| sample_size = sampler_->Build()->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -61,6 +61,15 @@ class VOCNode : public MappableSourceNode { | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| const std::string kColumnImage = "image"; | |||
| const std::string kColumnTarget = "target"; | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/datasetops/take_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| @@ -56,5 +57,19 @@ Status TakeNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status TakeNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows; | |||
| RETURN_IF_NOT_OK(children_[0]->GetDatasetSize(size_getter, estimate, &num_rows)); | |||
| *dataset_size = std::min(static_cast<int64_t>(take_count_), num_rows); | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
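The take arithmetic is the mirror image of skip: the reported size is the smaller of the requested count and what the child can supply. A one-line sketch (TakeSize is hypothetical):

    #include <algorithm>
    #include <cstdint>

    int64_t TakeSize(int64_t child_rows, int32_t take_count) {
      return std::min(static_cast<int64_t>(take_count), child_rows);
    }
    // TakeSize(3, 5) == 3;  TakeSize(100, 5) == 5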
| @@ -54,6 +54,15 @@ class TakeNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| private: | |||
| int32_t take_count_; | |||
| }; | |||