@@ -21,7 +21,6 @@
namespace mindspore {
namespace dataset {
namespace api {
// Config operations for setting and getting the configuration.
namespace config {
@@ -104,6 +103,5 @@ bool load(std::string file) {
}
} // namespace config
} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -21,36 +21,14 @@
#include <utility>
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#endif
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
// Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
// IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
@@ -99,7 +77,6 @@
namespace mindspore {
namespace dataset {
namespace api {
// Function to create the iterator, which will build and launch the execution tree.
std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
@@ -317,7 +294,7 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) {
return schema->init() ? schema : nullptr;
}
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// FUNCTIONS TO CREATE DATASETS FOR LEAF CLASSES
// (In alphabetical order)
// Function to create an AlbumDataset.
@@ -466,7 +443,7 @@ std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::strin
}
#endif
// Function to create a ZipNode.
// Function to create a ZipDataset.
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
auto ds = std::make_shared<ZipDataset>(datasets);
return ds;
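// A minimal usage sketch, assuming "ds1" and "ds2" are hypothetical pipelines built
// with the leaf factory functions above and carry non-overlapping column names:
std::shared_ptr<ZipDataset> ZipExample(std::shared_ptr<Dataset> ds1, std::shared_ptr<Dataset> ds2) {
  return Zip({ds1, ds2});  // rows from both inputs are merged column-wise
}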
@@ -639,7 +616,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
return nullptr;
}
@@ -647,15 +624,15 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
BuildVocabConsumer *bv_consumer = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
return nullptr;
}
runtime_context->AssignConsumer(std::move(consumer));
// Run the tree here to start building the vocab
// Run the tree here to start building the SentencePieceVocab
rc = bv_consumer->Start();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc;
return nullptr;
}
return vocab;
@@ -671,7 +648,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
return nullptr;
}
@@ -679,7 +656,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
BuildVocabConsumer *bv_consumer = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
return nullptr;
}
runtime_context->AssignConsumer(std::move(consumer));
@@ -687,11 +664,14 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
// Run the tree here to start building the vocab
rc = bv_consumer->Start();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc;
return nullptr;
}
return vocab;
}
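// A minimal sketch of the calling contract shown above: BuildVocab returns a usable
// Vocab, or nullptr after one of the "BuildVocab: ..." errors is logged. The full
// parameter list is elided in this hunk, so it is left as a placeholder here.
void BuildVocabExample(std::shared_ptr<Dataset> ds) {
  std::shared_ptr<Vocab> vocab = ds->BuildVocab(/* columns and vocab options as declared above */);
  if (vocab == nullptr) {
    return;  // failure details appear only in MS_LOG(ERROR); no Status reaches the caller
  }
  // use vocab ...
}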
std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
}
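// Usage sketch for Batch, assuming "ds" is a hypothetical pipeline produced by one of
// the factory functions in this file:
std::shared_ptr<BatchDataset> BatchExample(std::shared_ptr<Dataset> ds) {
  return ds->Batch(32, /*drop_remainder=*/true);  // groups of 32 rows; the final partial batch is dropped
}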
#endif
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
@@ -856,162 +836,6 @@ bool SchemaObj::from_json(nlohmann::json json_obj) {
// OTHER FUNCTIONS
// Helper function to compute a default shuffle size
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int64_t *shuffle_size) {
const int64_t average_files_multiplier = 4;
const int64_t shuffle_max = 10000;
int64_t avg_rows_per_file = 0;
// Adjust the num rows per shard if sharding was given
if (num_devices > 0) {
if (num_rows % num_devices == 0) {
num_rows = num_rows / num_devices;
} else {
num_rows = (num_rows / num_devices) + 1;
}
}
// Cap based on the total rows directive. Some ops do not have this and give a value of 0.
if (total_rows > 0) {
num_rows = std::min(num_rows, total_rows);
}
// get the average per file
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
avg_rows_per_file = num_rows / num_files;
*shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
return Status::OK();
}
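// Worked example with hypothetical values: num_files = 10, num_devices = 4,
// num_rows = 120000, total_rows = 0 (no cap):
//   rows per shard:    120000 / 4 = 30000 (divides evenly, so no +1 adjustment)
//   avg_rows_per_file: 30000 / 10 = 3000
//   *shuffle_size:     std::max(3000 * 4, 10000) = 12000
// Note that shuffle_max acts as a lower bound here, not an upper bound, despite its name.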
// Helper function to inject a shuffle operator on top of the current operator being built
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) {
std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr;
int64_t shuffle_size = 0;
RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
// Add the shuffle op
*shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer);
return Status::OK();
}
// Helper function to validate dataset directory parameter
Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
if (dataset_dir.empty()) {
std::string err_msg = dataset_name + ": dataset_dir is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
Path dir(dataset_dir);
if (!dir.IsDirectory()) {
std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (access(dataset_dir.c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
// Helper function to validate dataset files parameter
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
if (dataset_files.empty()) {
std::string err_msg = dataset_name + ": dataset_files is not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (auto f : dataset_files) {
Path dataset_file(f);
if (!dataset_file.Exists()) {
std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (access(dataset_file.toString().c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
}
// Helper function to validate dataset num_shards and shard_id parameters
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
if (num_shards <= 0) {
std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (shard_id < 0 || shard_id >= num_shards) {
| std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) + | |||
| ", num_shards: " + std::to_string(num_shards); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
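// Example of the contract above, with hypothetical arguments: for num_shards = 4 the
// valid shard_id range is [0, 3].
Status ShardParamsExample() {
  RETURN_IF_NOT_OK(ValidateDatasetShardParams("Cifar10Dataset", /*num_shards=*/4, /*shard_id=*/3));  // OK
  return ValidateDatasetShardParams("Cifar10Dataset", /*num_shards=*/4, /*shard_id=*/4);  // syntax error status
}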
// Helper function to validate dataset sampler parameter
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
if (sampler == nullptr) {
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings) {
if (valid_strings.find(str) == valid_strings.end()) {
std::string mode;
mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode,
[](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
// Helper function to validate dataset input/output column parameter
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
const std::vector<std::string> &columns) {
if (columns.empty()) {
| std::string err_msg = dataset_name + ":" + column_param + " should not be empty string"; | |||
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (uint32_t i = 0; i < columns.size(); ++i) {
if (columns[i].empty()) {
std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
std::set<std::string> columns_set(columns.begin(), columns.end());
if (columns_set.size() != columns.size()) {
| std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others"; | |||
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
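// The rejection cases above, illustrated with hypothetical arguments:
//   ValidateDatasetColumnParam("MapDataset", "input_columns", {})                 -> error: empty list
//   ValidateDatasetColumnParam("MapDataset", "input_columns", {"image", ""})      -> error: empty name
//   ValidateDatasetColumnParam("MapDataset", "input_columns", {"image", "image"}) -> error: duplicate name
//   ValidateDatasetColumnParam("MapDataset", "input_columns", {"image", "label"}) -> Status::OK()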
#ifndef ENABLE_ANDROID
std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
@@ -1153,22 +977,5 @@ TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files,
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#endif
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
if (shuffle) {
if (num_shards > 1) {
// If shuffle enabled, sharding enabled, use distributed random sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle enabled, sharding disabled, use random sampler
return RandomSampler(num_samples >= 0, num_samples);
}
if (num_shards > 1) {
// If shuffle disabled, sharding enabled, use distributed sequential sampler
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
}
// If shuffle disabled, sharding disabled, use sequential sampler
return SequentialSampler(0, num_samples);
}
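// Decision table for SelectSampler, with a hypothetical call for illustration:
//   shuffle = true,  num_shards > 1  -> DistributedSampler (shuffled shards)
//   shuffle = true,  num_shards <= 1 -> RandomSampler
//   shuffle = false, num_shards > 1  -> DistributedSampler (sequential shards)
//   shuffle = false, num_shards <= 1 -> SequentialSampler
// e.g. SelectSampler(/*num_samples=*/100, /*shuffle=*/true, /*num_shards=*/4, /*shard_id=*/0)
// returns a DistributedSampler covering shard 0 of 4.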
} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -26,7 +26,6 @@
namespace mindspore {
namespace dataset {
namespace api {
Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}
@@ -54,6 +53,5 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MS
return std::make_shared<tensor::DETensor>(std::move(de_output));
}
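// A minimal eager-execution sketch, assuming "op" is any TensorOperation built
// elsewhere (e.g. one of the vision operations later in this section):
std::shared_ptr<tensor::MSTensor> ExecuteExample(std::shared_ptr<TensorOperation> op,
                                                 std::shared_ptr<tensor::MSTensor> input) {
  Execute transform(std::move(op));
  return transform(std::move(input));  // applies the op; error/null checks omitted for brevity
}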
} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -20,7 +20,6 @@
namespace mindspore {
namespace dataset {
namespace api {
// Get the next row from the data pipeline.
bool Iterator::GetNextRow(TensorMap *row) {
@@ -45,19 +44,18 @@ bool Iterator::GetNextRow(TensorVec *row) {
}
// Shut down the data pipeline.
void Iterator::Stop() { runtime_context->Terminate(); }
void Iterator::Stop() { runtime_context_->Terminate(); }
// Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
runtime_context = std::make_unique<RuntimeContext>();
RETURN_IF_NOT_OK(runtime_context->Init());
runtime_context_ = std::make_unique<RuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get();
RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
runtime_context->AssignConsumer(std::move(consumer));
runtime_context_->AssignConsumer(std::move(consumer));
return Status::OK();
}
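// End-to-end sketch of the wiring above, assuming "ds" is a fully built pipeline and
// that an empty row signals end of data (an assumption; termination details are not
// shown in this hunk):
void IterateExample(std::shared_ptr<Dataset> ds) {
  std::shared_ptr<Iterator> itr = ds->CreateIterator({});  // {} assumed to keep all columns
  TensorMap row;
  while (itr->GetNextRow(&row) && !row.empty()) {
    // consume the tensors in "row" here
  }
  itr->Stop();  // tears down the RuntimeContext created in BuildAndLaunchTree
}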
} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -27,59 +27,59 @@
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) {
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
(void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
.def("set_num_rows",
[](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
[](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples",
[](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
[](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices",
[](Sampler &self) {
[](SamplerRT &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
})
.def("add_child", [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) {
.def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
THROW_IF_ERROR(self->AddChild(child));
});
}));
PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) {
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(
PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
}));
PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) {
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
.def(py::init<int64_t, int64_t, bool>());
}));
PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) {
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
.def(py::init<int64_t, py::object>());
}));
PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) {
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
.def(py::init<int64_t, bool, bool>());
}));
PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) {
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m,
"SequentialSampler")
PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
*m, "SequentialSampler")
.def(py::init<int64_t, int64_t>());
}));
PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) {
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(
PYBIND_REGISTER(SubsetRandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<SubsetRandomSamplerRT, SamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
*m, "SubsetRandomSampler")
.def(py::init<int64_t, std::vector<int64_t>>());
}));
PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(
PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
(void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
*m, "WeightedRandomSampler")
.def(py::init<int64_t, std::vector<double>, bool>());
}));
@@ -1140,7 +1140,7 @@ Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp
if (!value.is_none()) {
if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
}
if (key == "children_flag_and_nums") {
@@ -1164,7 +1164,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
// Required arguments
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
if (!args["dataset_files"].is_none()) {
@@ -1210,7 +1210,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -1234,7 +1234,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
} else if (cache_client) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
@@ -1308,7 +1308,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "extensions") {
(void)builder->SetExtensions(ToStringSet(value));
@@ -1363,7 +1363,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "class_indexing") {
(void)builder->SetClassIndex(ToStringMap(value));
@@ -1416,7 +1416,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
@@ -1478,7 +1478,7 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
@@ -1529,7 +1529,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
@@ -1583,7 +1583,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
@@ -1618,7 +1618,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
// Required arguments
RandomDataOp::Builder builder;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;
if (args["total_rows"].is_none()) {
@@ -1646,7 +1646,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -1670,7 +1670,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
} else if (cache_client) {
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder.SetSampler(std::move(sampler));
}
@@ -1715,7 +1715,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
@@ -1768,7 +1768,7 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
@@ -1806,7 +1806,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
// Required arguments
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
if (!args["dataset_files"].is_none()) {
@@ -1840,7 +1840,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -1855,7 +1855,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
@@ -1991,7 +1991,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
@@ -2036,7 +2036,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -2051,7 +2051,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
@@ -2116,7 +2116,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
std::shared_ptr<SamplerRT> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
if (!args["dataset_files"].is_none()) {
@@ -2173,7 +2173,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
sampler = create().cast<std::shared_ptr<SamplerRT>>();
}
}
}
@@ -2188,7 +2188,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
@@ -35,7 +35,6 @@
namespace mindspore {
namespace dataset {
namespace api {
#define RETURN_NULL_IF_ERROR(_s) \
do { \
@@ -151,10 +150,10 @@ bool DistributedSamplerObj::ValidateParams() {
return true;
}
std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
return sampler;
}
@@ -184,9 +183,9 @@ bool PKSamplerObj::ValidateParams() {
return true;
}
std::shared_ptr<Sampler> PKSamplerObj::Build() {
std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_);
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
return sampler;
}
@@ -218,10 +217,10 @@ bool RandomSamplerObj::ValidateParams() {
return true;
}
std::shared_ptr<Sampler> RandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
// runtime sampler object
bool reshuffle_each_epoch = true;
auto sampler = std::make_shared<dataset::RandomSampler>(num_samples_, replacement_, reshuffle_each_epoch);
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
return sampler;
}
@@ -255,9 +254,9 @@ bool SequentialSamplerObj::ValidateParams() {
return true;
}
std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSampler>(num_samples_, start_index_);
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
return sampler;
}
@@ -284,9 +283,9 @@ bool SubsetRandomSamplerObj::ValidateParams() {
return true;
}
std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSampler>(num_samples_, indices_);
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
return sampler;
}
@@ -330,11 +329,10 @@ bool WeightedRandomSamplerObj::ValidateParams() {
return true;
}
std::shared_ptr<Sampler> WeightedRandomSamplerObj::Build() {
auto sampler = std::make_shared<dataset::WeightedRandomSampler>(num_samples_, weights_, replacement_);
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
return sampler;
}
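// The two-phase pattern shared by all *SamplerObj classes above: validate the IR-side
// parameters, then Build() the runtime SamplerRT. A minimal sketch:
std::shared_ptr<SamplerRT> BuildSamplerExample(std::shared_ptr<SamplerObj> obj) {
  if (obj == nullptr || !obj->ValidateParams()) {
    return nullptr;  // invalid parameters (ValidateParams is assumed to report the reason)
  }
  return obj->Build();  // e.g. SequentialSamplerObj -> dataset::SequentialSamplerRT
}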
} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -22,7 +22,6 @@
namespace mindspore {
namespace dataset {
namespace api {
// Transform operations for text.
namespace text {
@@ -130,6 +129,5 @@ std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
}
} // namespace text
} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -22,7 +22,6 @@
namespace mindspore {
namespace dataset {
namespace api {
TensorOperation::TensorOperation() {}
@@ -94,6 +93,5 @@ Status TypeCastOperation::ValidateParams() {
std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }
} // namespace transforms
} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -65,7 +65,6 @@
namespace mindspore {
namespace dataset {
namespace api {
// Transform operations for computer vision.
namespace vision {
@@ -1702,6 +1701,5 @@ std::shared_ptr<TensorOp> UniformAugOperation::Build() {
#endif
} // namespace vision
} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -34,11 +34,11 @@ namespace mindspore::dataset {
// TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
Status TreeConsumer::Init(std::shared_ptr<api::DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }
// IteratorConsumer
Status IteratorConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}
@@ -74,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
}
// ToDevice
Status ToDevice::Init(std::shared_ptr<api::DatasetNode> d) {
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}
@@ -385,8 +385,8 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal
tree_adapter_ = std::make_unique<TreeAdapter>();
}
Status TreeGetters::Init(std::shared_ptr<api::DatasetNode> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d));
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d), 1);
if (!s.IsError()) {
init_flag_ = true;
}
@@ -464,7 +464,7 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
return Status::OK();
}
Status BuildVocabConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), 1);
}
Status BuildVocabConsumer::Start() {
@@ -29,10 +29,7 @@
namespace mindspore::dataset {
// Forward declare
class TreeAdapter;
namespace api {
class DatasetNode;
}
/// A base class for tree consumers that fetch rows from the tree pipeline
class TreeConsumer {
@@ -42,7 +39,7 @@ class TreeConsumer {
/// Initializes the consumer; this involves constructing and preparing the tree.
/// \param d The dataset node that represents the root of the IR tree.
/// \return Status error code.
virtual Status Init(std::shared_ptr<api::DatasetNode> d);
virtual Status Init(std::shared_ptr<DatasetNode> d);
Status Terminate();
@@ -61,7 +58,7 @@ class IteratorConsumer : public TreeConsumer {
/// \param num_epochs number of epochs. Defaults to -1 (infinite epochs).
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;
/// Returns the next row in a vector format
/// \param[out] out std::vector of Tensors
@@ -133,7 +130,7 @@ class ToDevice : public TreeConsumer {
explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;
/// Send the data to device
/// \return Status error code
@@ -162,7 +159,7 @@ class ToDevice : public TreeConsumer {
class TreeGetters : public TreeConsumer {
public:
TreeGetters();
Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
@@ -185,10 +182,9 @@ class BuildVocabConsumer : public TreeConsumer {
/// BuildVocabConsumer Constructor which will call the base class default constructor.
BuildVocabConsumer() = default;
Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status Init(std::shared_ptr<DatasetNode> d) override;
/// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
/// would be written to disk)
/// Start consuming
/// \return Status error code
Status Start();
@@ -46,7 +46,7 @@ Status CacheBase::Reset() {
return Status::OK();
}
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
row_cnt_(0),
num_cache_miss_(0),
@@ -46,7 +46,7 @@ class CacheBase : public ParallelOp {
/// \param cache_client CacheClient for communication to the CacheServer
/// \param sampler Sampler which is mandatory
CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor
~CacheBase();
@@ -87,7 +87,7 @@ Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
leaf_op_wp_.Set();
return Status::OK();
}
Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
Status CacheLookupOp::InitSampler() { return SamplerRT::InitSampler(); }
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
std::vector<row_id_type> cache_miss;
@@ -28,7 +28,7 @@ namespace dataset {
/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset.
/// \note For non-mappable dataset, please see CacheOp
/// \see CacheOp
class CacheLookupOp : public CacheBase, public Sampler {
class CacheLookupOp : public CacheBase, public SamplerRT {
public:
class Builder {
public:
@@ -62,7 +62,7 @@ class CacheLookupOp : public CacheBase, public Sampler {
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
@@ -77,7 +77,7 @@ class CacheLookupOp : public CacheBase, public Sampler {
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;
// Check if the required parameters are set by the builder.
// \return Status The error code return
@@ -87,8 +87,8 @@ class CacheLookupOp : public CacheBase, public Sampler {
/// \note It takes the same argument as the base class.
/// \see CacheBase
CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), SamplerRT(*(sampler.get())) {}
~CacheLookupOp() = default;
// As a parallel op, we override these two functions
Status operator()() override;
@@ -46,7 +46,7 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
}
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler)
: ParallelOp(numWorkers, opConnectorSize, sampler),
num_cleaners_(numCleaners),
cache_client_(std::move(cache_client)),
@@ -110,7 +110,7 @@ class CacheMergeOp : public ParallelOp {
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
@@ -133,7 +133,7 @@ class CacheMergeOp : public ParallelOp {
int32_t build_op_connector_size_;
int32_t build_num_cleaners_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;
/// Check if the required parameters are set by the builder.
/// \return Status The error code return
@@ -147,7 +147,7 @@ class CacheMergeOp : public ParallelOp {
/// \param cache_client CacheClient to communicate with the Cache server
/// \param sampler as a derived class of ParallelOp
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler);
~CacheMergeOp();
void Print(std::ostream &out, bool show_all) const override;
std::string Name() const override { return kCacheMergeOp; }
@@ -68,7 +68,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
// Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)),
num_guys_in_(0),
phase_(Phase::kBuildPhase) {}
@@ -81,7 +81,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
@@ -96,7 +96,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
std::shared_ptr<SamplerRT> build_sampler_;
/// \brief Check if the required parameters are set by the builder.
/// \return Status The error code return
@@ -108,7 +108,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \param num_workers The number of worker threads.
/// \param op_connector_size The size of each queue in the connector.
CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);
// Destructor
~CacheOp();
@@ -36,7 +36,7 @@ ConcatOp::Builder::Builder() {
// The builder "build" method creates the final object.
Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<DistributedSampler>(0, 1, 0, false);
builder_sampler_ = std::make_shared<DistributedSamplerRT>(0, 1, 0, false);
}
*ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_,
children_start_end_index_);
@@ -44,7 +44,7 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
}
// Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
std::vector<std::pair<int, int>> children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index)
: PipelineOp(op_connector_size),
@@ -80,7 +80,7 @@ Status ConcatOp::operator()() {
bool is_not_mappable = true;
int num_shard = 1;
int shard_index = 0;
std::shared_ptr<DistributedSampler> distribute_sampler = std::dynamic_pointer_cast<DistributedSampler>(sampler_);
std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_);
if (distribute_sampler != nullptr) {
num_shard = distribute_sampler->GetDeviceNum();
shard_index = distribute_sampler->GetDeviceID();
@@ -44,7 +44,7 @@ class ConcatOp : public PipelineOp {
// The builder "build" method creates the final object.
// @return shared_ptr to the new ConcatOp object
Status Build(std::shared_ptr<ConcatOp> *);
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
@@ -61,7 +61,7 @@ class ConcatOp : public PipelineOp {
private:
int32_t builder_op_connector_size_;
std::shared_ptr<Sampler> builder_sampler_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_;
};
@@ -70,7 +70,7 @@ class ConcatOp : public PipelineOp {
// @note The builder class should be used to call it
// @param op_connector_size - connector size
explicit ConcatOp(int32_t op_connector_size);
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
std::vector<std::pair<int, int>> children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index);
@@ -123,7 +123,7 @@ class ConcatOp : public PipelineOp {
std::unordered_map<std::string, int32_t> column_name_id_; // Mapping between col index and col name
std::vector<DataType> data_type_;
std::vector<dsize_t> data_rank_;
std::shared_ptr<Sampler> sampler_;
std::shared_ptr<SamplerRT> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_;
};
@@ -40,7 +40,7 @@
namespace mindspore {
namespace dataset {
// Constructor
DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
: oc_queue_size_(op_connector_size),
sampler_(sampler),
operator_id_(kInvalidOperatorId),
@@ -409,7 +409,7 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
}
// Getter for the sampler, and it also removes the sampler from the op
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler) {
*sampler = sampler_; // It's okay if sampler_ points to nullptr
sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler
return Status::OK();
@@ -62,7 +62,7 @@ class DataBuffer;
class NodePass;
class Sampler;
class SamplerRT;
/// \brief The base class DatasetOp is the main tree node. It is an abstract class, so
/// the actual implementation of the operators will be derived from here.
@@ -80,7 +80,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// Constructor
/// \param op_connector_size - The size for the output connector of this operator.
/// \param sampler - The sampler for the op
explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);
explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler);
/// Destructor
virtual ~DatasetOp() { tree_ = nullptr; }
@@ -347,12 +347,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// Getter for the sampler
/// \return Shared pointer to the sampler (may return nullptr)
std::shared_ptr<Sampler> sampler() { return sampler_; }
std::shared_ptr<SamplerRT> sampler() { return sampler_; }
/// \brief Getter for the sampler, and it also removes the sampler from the op
/// \param[out] sampler A pointer to the output sampler that was removed
/// \return Status error code
Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler);
#ifndef ENABLE_ANDROID
// Computes a CRC value for the operator
@@ -368,7 +368,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
}
/// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; }
/// \brief Checks if this is a leaf node (0 children)
/// \return boolean returns true if it's a leaf
@@ -409,7 +409,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
std::shared_ptr<SamplerRT> sampler_; // Some leaf ops might have a sampler
int32_t oc_queue_size_; // Capacity for each out_connector_
int32_t operator_id_; // Generated id for the node
ExecutionTree *tree_; // Back pointer to our tree.
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler) | |||
| ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler) | |||
| : DatasetOp(op_connector_size, sampler), | |||
| num_workers_(num_workers), | |||
| num_producers_(num_workers), | |||
| @@ -41,7 +41,7 @@ class ParallelOp : public DatasetOp { | |||
| // @param num_workers | |||
| // @param op_connector_size - size of the output connector for this operator | |||
| // @param sampler - The sampler for the op | |||
| ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr); | |||
| ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr); | |||
| // Destructor | |||
| ~ParallelOp() = default; | |||
| @@ -20,7 +20,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler) | |||
| PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler) | |||
| : DatasetOp(op_connector_size, sampler) {} | |||
| // A print method typically used for debugging | |||
| @@ -34,7 +34,7 @@ class PipelineOp : public DatasetOp { | |||
| // @param op_connector_size - size of the output connector | |||
| // @return Builder setter method returns reference to the builder. | |||
| // @param sampler - The sampler for the op | |||
| explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr); | |||
| explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr); | |||
| // Destructor | |||
| ~PipelineOp() = default; | |||
| @@ -42,7 +42,7 @@ Status AlbumOp::Builder::Build(std::shared_ptr<AlbumOp> *ptr) { | |||
| if (builder_sampler_ == nullptr) { | |||
| const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data | |||
| const int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| @@ -73,7 +73,7 @@ Status AlbumOp::Builder::SanityCheck() { | |||
| AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, | |||
| const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler) | |||
| std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_wkrs, queue_size, std::move(sampler)), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| folder_path_(file_dir), | |||
| @@ -100,7 +100,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||
| /// \brief Setter method | |||
| /// \param[in] sampler | |||
| /// \return Builder setter method returns reference to the builder | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -147,7 +147,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t builder_rows_per_buffer_; | |||
| int32_t builder_op_connector_size_; | |||
| std::set<std::string> builder_extensions_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| }; | |||
| @@ -161,7 +161,8 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||
| /// \param[in] data_schema - schema of dataset | |||
| /// \param[in] sampler - sampler tells AlbumOp what to read | |||
| AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, | |||
| const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<SamplerRT> sampler); | |||
| /// \brief Destructor. | |||
| ~AlbumOp() = default; | |||
| @@ -46,7 +46,7 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) { | |||
| if (builder_sampler_ == nullptr) { | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| @@ -79,7 +79,7 @@ Status CelebAOp::Builder::SanityCheck() { | |||
| CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, | |||
| bool decode, const std::string &usage, const std::set<std::string> &exts, | |||
| std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler) | |||
| std::unique_ptr<DataSchema> schema, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| folder_path_(dir), | |||
| @@ -95,7 +95,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| // Setter method | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -131,7 +131,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| int32_t builder_rows_per_buffer_; | |||
| int32_t builder_op_connector_size_; | |||
| std::set<std::string> builder_extensions_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| std::string builder_usage_; | |||
| }; | |||
| @@ -144,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
// @param std::shared_ptr<SamplerRT> sampler - sampler tells CelebAOp what to read | |||
| CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, | |||
| const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema, | |||
| std::shared_ptr<Sampler> sampler); | |||
| std::shared_ptr<SamplerRT> sampler); | |||
| ~CelebAOp() override = default; | |||
| @@ -50,7 +50,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) { | |||
| if (sampler_ == nullptr) { | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| } | |||
| schema_ = std::make_unique<DataSchema>(); | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| @@ -88,7 +88,7 @@ Status CifarOp::Builder::SanityCheck() { | |||
| CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, | |||
| const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler) | |||
| std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_works, queue_size, std::move(sampler)), | |||
| cifar_type_(type), | |||
| usage_(usage), | |||
| @@ -75,7 +75,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| // Setter method | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -123,7 +123,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t num_workers_; | |||
| int32_t rows_per_buffer_; | |||
| int32_t op_connect_size_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| std::shared_ptr<SamplerRT> sampler_; | |||
| std::unique_ptr<DataSchema> schema_; | |||
| CifarType cifar_type_; | |||
| }; | |||
| @@ -138,7 +138,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
// @param std::shared_ptr<SamplerRT> sampler - sampler tells CifarOp what to read | |||
| CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, | |||
| const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler); | |||
| std::shared_ptr<SamplerRT> sampler); | |||
| // Destructor. | |||
| ~CifarOp() = default; | |||
| @@ -94,7 +94,7 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim | |||
| ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler) | |||
| bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| num_rows_per_shard_(0), | |||
| @@ -125,7 +125,7 @@ class ClueOp : public ParallelOp { | |||
| // Setter method | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -141,13 +141,13 @@ class ClueOp : public ParallelOp { | |||
| std::vector<std::string> builder_clue_files_list_; | |||
| bool builder_shuffle_files_; | |||
| std::map<std::string, std::string> builder_cols_to_keyword_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| }; | |||
| // Constructor of ClueOp | |||
| ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler); | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler); | |||
| // Default destructor | |||
| ~ClueOp() = default; | |||
| @@ -60,7 +60,7 @@ Status CocoOp::Builder::Build(std::shared_ptr<CocoOp> *ptr) { | |||
| if (builder_sampler_ == nullptr) { | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| @@ -123,7 +123,7 @@ Status CocoOp::Builder::SanityCheck() { | |||
| CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, | |||
| int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | |||
| decode_(decode), | |||
| row_cnt_(0), | |||
| @@ -119,7 +119,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| // Setter method. | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -149,7 +149,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int32_t builder_rows_per_buffer_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| }; | |||
| @@ -166,7 +166,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
// @param std::shared_ptr<SamplerRT> sampler - sampler tells CocoOp what to read | |||
| CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, | |||
| int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||
| // Destructor | |||
| ~CocoOp() = default; | |||
| @@ -77,7 +77,7 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | |||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, | |||
| const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | |||
| int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | |||
| int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler) | |||
| int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| csv_files_list_(std::move(csv_files_list)), | |||
| field_delim_(field_delim), | |||
| @@ -243,7 +243,7 @@ class CsvOp : public ParallelOp { | |||
| // Setter method | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -261,7 +261,7 @@ class CsvOp : public ParallelOp { | |||
| char builder_field_delim_; | |||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_; | |||
| std::vector<std::string> builder_column_name_list_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| }; | |||
| // Constructor of CsvOp | |||
| @@ -271,7 +271,7 @@ class CsvOp : public ParallelOp { | |||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name, | |||
| int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, | |||
| std::shared_ptr<Sampler> sampler); | |||
| std::shared_ptr<SamplerRT> sampler); | |||
| // Default destructor | |||
| ~CsvOp() = default; | |||
| @@ -38,7 +38,7 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) { | |||
| if (builder_sampler_ == nullptr) { | |||
| const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data | |||
| const int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| @@ -68,7 +68,7 @@ Status ImageFolderOp::Builder::SanityCheck() { | |||
| ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, | |||
| bool recursive, bool do_decode, const std::set<std::string> &exts, | |||
| const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler) | |||
| std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_wkrs, queue_size, std::move(sampler)), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| folder_path_(file_dir), | |||
| @@ -113,7 +113,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| // Setter method | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -151,7 +151,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t builder_rows_per_buffer_; | |||
| int32_t builder_op_connector_size_; | |||
| std::set<std::string> builder_extensions_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| std::map<std::string, int32_t> builder_labels_to_read_; | |||
| }; | |||
| @@ -165,7 +165,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
// @param std::shared_ptr<SamplerRT> sampler - sampler tells ImageFolderOp what to read | |||
| ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, | |||
| bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map, | |||
| std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler); | |||
| std::unique_ptr<DataSchema>, std::shared_ptr<SamplerRT> sampler); | |||
| // Destructor. | |||
| ~ImageFolderOp() = default; | |||
| @@ -43,7 +43,7 @@ Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) { | |||
| if (builder_sampler_ == nullptr) { | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK( | |||
| @@ -67,7 +67,7 @@ Status ManifestOp::Builder::SanityCheck() { | |||
| ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | |||
| const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler, std::string usage) | |||
| std::shared_ptr<SamplerRT> sampler, std::string usage) | |||
| : ParallelOp(num_works, queue_size, std::move(sampler)), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| io_block_pushed_(0), | |||
| @@ -88,7 +88,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| // Setter method | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -119,7 +119,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| Status Build(std::shared_ptr<ManifestOp> *op); | |||
| private: | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| bool builder_decode_; | |||
| std::string builder_file_; | |||
| @@ -139,7 +139,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
// @param std::shared_ptr<SamplerRT> sampler - sampler tells ManifestOp what to read | |||
| ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | |||
| const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler, std::string usage); | |||
| std::shared_ptr<SamplerRT> sampler, std::string usage); | |||
| // Destructor. | |||
| ~ManifestOp() = default; | |||
| @@ -45,7 +45,7 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) { | |||
| if (builder_sampler_ == nullptr) { | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK( | |||
| @@ -75,7 +75,7 @@ Status MnistOp::Builder::SanityCheck() { | |||
| } | |||
| MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, | |||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | |||
| usage_(usage), | |||
| buf_cnt_(0), | |||
| @@ -78,7 +78,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| // Setter method | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -113,7 +113,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_rows_per_buffer_; | |||
| int32_t builder_op_connector_size_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| }; | |||
| @@ -126,7 +126,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset | |||
// @param std::shared_ptr<SamplerRT> sampler - sampler tells MnistOp what to read | |||
| MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, | |||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||
| // Destructor. | |||
| ~MnistOp() = default; | |||
| @@ -65,7 +65,7 @@ Status RandomDataOp::Builder::SanityCheck() const { | |||
| // Constructor for RandomDataOp | |||
| RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| buffer_id_(0), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| @@ -120,7 +120,7 @@ class RandomDataOp : public ParallelOp { | |||
| // Setter method | |||
// @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -133,7 +133,7 @@ class RandomDataOp : public ParallelOp { | |||
| Status SanityCheck() const; | |||
| std::unique_ptr<DataSchema> builder_data_schema_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int64_t builder_rows_per_buffer_; | |||
| @@ -152,7 +152,7 @@ class RandomDataOp : public ParallelOp { | |||
* @param sampler - The sampler for the op | |||
| */ | |||
| RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||
| /** | |||
| * Destructor | |||
| @@ -23,9 +23,9 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||
| uint32_t seed, int64_t offset, bool even_dist) | |||
| : Sampler(num_samples, std::numeric_limits<int64_t>::max()), | |||
| DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||
| uint32_t seed, int64_t offset, bool even_dist) | |||
| : SamplerRT(num_samples, std::numeric_limits<int64_t>::max()), | |||
| cnt_(0), | |||
| seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed), | |||
| device_id_(dev_id), | |||
| @@ -35,7 +35,7 @@ DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int | |||
| offset_(offset), | |||
| non_empty_(true) {} | |||
| Status DistributedSampler::InitSampler() { | |||
| Status DistributedSamplerRT::InitSampler() { | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| @@ -74,7 +74,7 @@ Status DistributedSampler::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status DistributedSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (cnt_ > samples_per_buffer_) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "Number of samples(cnt) that have already been filled in to buffer should be less than or " | |||
| @@ -143,7 +143,7 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer | |||
| return Status::OK(); | |||
| } | |||
| Status DistributedSampler::ResetSampler() { | |||
| Status DistributedSamplerRT::ResetSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); | |||
| cnt_ = 0; | |||
| @@ -160,10 +160,10 @@ Status DistributedSampler::ResetSampler() { | |||
| return Status::OK(); | |||
| } | |||
| void DistributedSampler::Print(std::ostream &out, bool show_all) const { | |||
| void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nSampler: DistributedSampler"; | |||
| if (show_all) { | |||
| Sampler::Print(out, show_all); | |||
| SamplerRT::Print(out, show_all); | |||
| out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ | |||
| << "\nshuffle: " << shuffle_; | |||
| } | |||
| @@ -25,7 +25,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DistributedSampler : public Sampler { | |||
| class DistributedSamplerRT : public SamplerRT { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] num_samples The total number of rows in the dataset | |||
| @@ -40,11 +40,12 @@ class DistributedSampler : public Sampler { | |||
| /// This option is not exposed in the python API. Current behavior is that the remainder will always | |||
/// be handled by the first n shards, n being the corresponding device id. Note that when offset is set, | |||
/// even_dist is forcibly set to false so that the remaining rows can be distributed in the ConcatDataset scenario. | |||
| DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||
| uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true); | |||
| DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||
| uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, | |||
| bool even_dist = true); | |||
| /// \brief default destructor | |||
| ~DistributedSampler() = default; | |||
| ~DistributedSamplerRT() = default; | |||
| /// \param std::unique_ptr<DataBuffer> * pBuffer | |||
| /// \param int32_t workerId | |||
| @@ -20,14 +20,14 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), | |||
| PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) | |||
| : SamplerRT(num_samples, samples_per_buffer), | |||
| shuffle_(shuffle), | |||
| seed_(GetSeed()), | |||
| next_id_(0), | |||
| samples_per_class_(val) {} | |||
| Status PKSampler::InitSampler() { | |||
| Status PKSamplerRT::InitSampler() { | |||
| labels_.reserve(label_to_ids_.size()); | |||
| for (const auto &pair : label_to_ids_) { | |||
if (!pair.second.empty()) { | |||
| @@ -61,7 +61,7 @@ Status PKSampler::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status PKSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (next_id_ > num_samples_ || num_samples_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_)); | |||
| } else if (next_id_ == num_samples_) { | |||
| @@ -96,7 +96,7 @@ Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| return Status::OK(); | |||
| } | |||
| Status PKSampler::ResetSampler() { | |||
| Status PKSamplerRT::ResetSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | |||
| next_id_ = 0; | |||
| rnd_.seed(seed_++); | |||
| @@ -108,18 +108,18 @@ Status PKSampler::ResetSampler() { | |||
| return Status::OK(); | |||
| } | |||
| Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| RETURN_UNEXPECTED_IF_NULL(op); | |||
| RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); | |||
| RETURN_IF_NOT_OK(InitSampler()); | |||
| return Status::OK(); | |||
| } | |||
| void PKSampler::Print(std::ostream &out, bool show_all) const { | |||
| void PKSamplerRT::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nSampler: PKSampler"; | |||
| if (show_all) { | |||
| // Call the super class for displaying any common detailed info | |||
| Sampler::Print(out, show_all); | |||
| SamplerRT::Print(out, show_all); | |||
| // Then add our own info if any | |||
| } | |||
| } | |||
| @@ -26,17 +26,17 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class PKSampler : public Sampler { // NOT YET FINISHED | |||
| class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||
| public: | |||
| // @param num_samples - the number of samples to draw. value of 0 means to take the full amount | |||
// @param int64_t val - the number of samples to draw from each class | |||
| // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 | |||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| explicit PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // default destructor | |||
| ~PKSampler() = default; | |||
| ~PKSamplerRT() = default; | |||
// @param std::unique_ptr<DataBuffer> pBuffer | |||
| // @param int32_t workerId | |||
| @@ -20,10 +20,10 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||
| PythonSamplerRT::PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | |||
| : SamplerRT(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||
| Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (need_to_reset_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| @@ -64,7 +64,7 @@ Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| return Status::OK(); | |||
| } | |||
| Status PythonSampler::InitSampler() { | |||
| Status PythonSamplerRT::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| @@ -86,7 +86,7 @@ Status PythonSampler::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| Status PythonSampler::ResetSampler() { | |||
| Status PythonSamplerRT::ResetSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); | |||
| need_to_reset_ = false; | |||
| py::gil_scoped_acquire gil_acquire; | |||
| @@ -106,11 +106,11 @@ Status PythonSampler::ResetSampler() { | |||
| return Status::OK(); | |||
| } | |||
| void PythonSampler::Print(std::ostream &out, bool show_all) const { | |||
| void PythonSamplerRT::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nSampler: PythonSampler"; | |||
| if (show_all) { | |||
| // Call the super class for displaying any common detailed info | |||
| Sampler::Print(out, show_all); | |||
| SamplerRT::Print(out, show_all); | |||
| // Then add our own info if any | |||
| } | |||
| } | |||
| @@ -23,18 +23,18 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class PythonSampler : public Sampler { | |||
| class PythonSamplerRT : public SamplerRT { | |||
| public: | |||
| // Constructor | |||
| // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the | |||
| // data from the dataset. | |||
| // @param py_sampler_instance - the python instance of the sampler | |||
| // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| explicit PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| ~PythonSampler() = default; | |||
| ~PythonSamplerRT() = default; | |||
| // Initialize the sampler. | |||
| // @return Status | |||
| @@ -22,16 +22,16 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||
| int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), | |||
| RandomSamplerRT::RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||
| int64_t samples_per_buffer) | |||
| : SamplerRT(num_samples, samples_per_buffer), | |||
| seed_(GetSeed()), | |||
| replacement_(replacement), | |||
| next_id_(0), | |||
| reshuffle_each_epoch_(reshuffle_each_epoch), | |||
| dist(nullptr) {} | |||
| Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (next_id_ > num_samples_) { | |||
| RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); | |||
| } else if (next_id_ == num_samples_) { | |||
| @@ -68,7 +68,7 @@ Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| return Status::OK(); | |||
| } | |||
| Status RandomSampler::InitSampler() { | |||
| Status RandomSamplerRT::InitSampler() { | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| @@ -94,7 +94,7 @@ Status RandomSampler::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| Status RandomSampler::ResetSampler() { | |||
| Status RandomSamplerRT::ResetSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | |||
| next_id_ = 0; | |||
| @@ -115,11 +115,11 @@ Status RandomSampler::ResetSampler() { | |||
| return Status::OK(); | |||
| } | |||
| void RandomSampler::Print(std::ostream &out, bool show_all) const { | |||
| void RandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nSampler: RandomSampler"; | |||
| if (show_all) { | |||
| // Call the super class for displaying any common detailed info | |||
| Sampler::Print(out, show_all); | |||
| SamplerRT::Print(out, show_all); | |||
| // Then add our own info if any | |||
| } | |||
| } | |||
| @@ -24,18 +24,18 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class RandomSampler : public Sampler { | |||
| class RandomSamplerRT : public SamplerRT { | |||
| public: | |||
| // Constructor | |||
// @param int64_t num_samples - the number of samples to draw | |||
// @param bool replacement - whether to put the id back after it has been sampled | |||
// @param reshuffle_each_epoch - T/F to reshuffle the ids after each epoch | |||
| // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| explicit RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| ~RandomSampler() = default; | |||
| ~RandomSamplerRT() = default; | |||
| // Op calls this to get next Buffer that contains all the sampleIds | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | |||
| @@ -32,13 +32,13 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||
| return Status::OK(); | |||
| } | |||
| Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) | |||
| SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) | |||
| : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} | |||
| Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| std::shared_ptr<Sampler> child_sampler; | |||
| Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| std::shared_ptr<SamplerRT> child_sampler; | |||
| if (HasChildSampler()) { | |||
| child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]); | |||
| child_sampler = std::dynamic_pointer_cast<SamplerRT>(child_[0]); | |||
| if (!child_sampler) { | |||
| std::string err_msg("Cannot handshake, child is not a sampler object."); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| @@ -64,7 +64,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| return Status::OK(); | |||
| } | |||
| Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) { | |||
| Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) { | |||
| if (num_elements == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid data, num of elements cannot be 0."); | |||
| } | |||
| @@ -77,7 +77,7 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t | |||
| return Status::OK(); | |||
| } | |||
| void Sampler::Print(std::ostream &out, bool show_all) const { | |||
| void SamplerRT::Print(std::ostream &out, bool show_all) const { | |||
| // Sampler printing is usually only called in the show_all mode. | |||
| // Derived classes will display the name, then call back to this base | |||
| // for common info. | |||
| @@ -88,7 +88,7 @@ void Sampler::Print(std::ostream &out, bool show_all) const { | |||
| } | |||
| #ifdef ENABLE_PYTHON | |||
| Status Sampler::GetAllIdsThenReset(py::array *data) { | |||
| Status SamplerRT::GetAllIdsThenReset(py::array *data) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| TensorRow sample_row; | |||
| @@ -123,27 +123,27 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||
| } | |||
| #endif | |||
| Status Sampler::SetNumSamples(int64_t num_samples) { | |||
| Status SamplerRT::SetNumSamples(int64_t num_samples) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0."); | |||
| num_samples_ = num_samples; | |||
| return Status::OK(); | |||
| } | |||
| int64_t Sampler::GetNumSamples() { return num_samples_; } | |||
| int64_t SamplerRT::GetNumSamples() { return num_samples_; } | |||
| Status Sampler::SetNumRowsInDataset(int64_t num_rows) { | |||
| Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0."); | |||
| num_rows_ = num_rows; | |||
| return Status::OK(); | |||
| } | |||
| Status Sampler::AddChild(std::shared_ptr<Sampler> child) { | |||
| Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) { | |||
| if (child == nullptr) { | |||
| return Status::OK(); | |||
| } | |||
| // Only samplers can be added, not any other DatasetOp. | |||
| std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child); | |||
| std::shared_ptr<SamplerRT> sampler = std::dynamic_pointer_cast<SamplerRT>(child); | |||
| if (!sampler) { | |||
| std::string err_msg("Cannot add child, child is not a sampler object."); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| @@ -160,9 +160,9 @@ Status Sampler::AddChild(std::shared_ptr<Sampler> child) { | |||
| return Status::OK(); | |||
| } | |||
| bool Sampler::HasChildSampler() { return !child_.empty(); } | |||
| bool SamplerRT::HasChildSampler() { return !child_.empty(); } | |||
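// Editor's note: a hedged sketch of the AddChild() contract, using constructors shown | |||
// elsewhere in this diff: the parent maps each id it generates through its child's output | |||
// (see GetAssociatedChildId below), e.g. sharding on top of a random permutation. | |||
auto child = std::make_shared<RandomSamplerRT>(/*num_samples=*/0, /*replacement=*/false, | |||
                                               /*reshuffle_each_epoch=*/true); | |||
auto parent = std::make_shared<DistributedSamplerRT>(/*num_samples=*/0, /*num_dev=*/8, | |||
                                                     /*dev_id=*/0, /*shuffle=*/false); | |||
RETURN_IF_NOT_OK(parent->AddChild(child)); | |||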
| Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | |||
| Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | |||
| if (child_ids_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); | |||
| } | |||
| @@ -51,21 +51,21 @@ class RandomAccessOp { | |||
| protected: | |||
| // The amount of rows in the dataset itself. This is the before-sampling value, the | |||
| // total count of rows. A sampler may choose to sample less than this amount. | |||
| int64_t num_rows_; | |||
| int64_t num_rows_ = -1; | |||
| }; | |||
| class Sampler { | |||
| class SamplerRT { | |||
| public: | |||
| // Constructor | |||
| // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | |||
| // indicates that the sampler should produce the complete set of ids. | |||
| // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); | |||
| explicit SamplerRT(int64_t num_samples, int64_t samples_per_buffer); | |||
| Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {} | |||
| SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {} | |||
| // default destructor | |||
| ~Sampler() = default; | |||
| ~SamplerRT() = default; | |||
| // Get a list of sample ids. | |||
| // @note It is Sampler responsibility to make sure that the id is not out of bound. | |||
| @@ -111,7 +111,7 @@ class Sampler { | |||
| // Adds a sampler to become our child. | |||
// @param std::shared_ptr<SamplerRT> - The sampler to add as a child. | |||
| // @return - The error code returned. | |||
| Status AddChild(std::shared_ptr<Sampler> child); | |||
| Status AddChild(std::shared_ptr<SamplerRT> child); | |||
// A helper function to create an int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | |||
| // @param std::shared_ptr<Tensor>* sampleIds | |||
| @@ -129,7 +129,7 @@ class Sampler { | |||
| // @param out - reference to the output stream being overloaded | |||
// @param sampler - reference to the sampler to print | |||
| // @return - the output stream must be returned | |||
| friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { | |||
| friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) { | |||
| sampler.Print(out, false); | |||
| return out; | |||
| } | |||
| @@ -158,7 +158,7 @@ class Sampler { | |||
| int64_t samples_per_buffer_; | |||
| std::unique_ptr<ColDescriptor> col_desc_; | |||
| std::vector<std::shared_ptr<Sampler>> child_; // Child nodes | |||
| std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes | |||
| std::unique_ptr<DataBuffer> child_ids_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -20,10 +20,10 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | |||
| SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | |||
| : SamplerRT(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | |||
| Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (id_count_ > num_samples_) { | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | |||
| } else if (id_count_ == num_samples_) { | |||
| @@ -62,7 +62,7 @@ Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) | |||
| return Status::OK(); | |||
| } | |||
| Status SequentialSampler::InitSampler() { | |||
| Status SequentialSamplerRT::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, | |||
| "Invalid parameter, start_index must be greater than or equal to 0, but got " + | |||
| std::to_string(start_index_) + ".\n"); | |||
| @@ -85,7 +85,7 @@ Status SequentialSampler::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| Status SequentialSampler::ResetSampler() { | |||
| Status SequentialSamplerRT::ResetSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); | |||
| current_id_ = start_index_; | |||
| id_count_ = 0; | |||
| @@ -97,11 +97,11 @@ Status SequentialSampler::ResetSampler() { | |||
| return Status::OK(); | |||
| } | |||
| void SequentialSampler::Print(std::ostream &out, bool show_all) const { | |||
| void SequentialSamplerRT::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nSampler: SequentialSampler"; | |||
| if (show_all) { | |||
| // Call the super class for displaying any common detailed info | |||
| Sampler::Print(out, show_all); | |||
| SamplerRT::Print(out, show_all); | |||
| // Then add our own info | |||
| out << "\nStart index: " << start_index_; | |||
| } | |||
| @@ -23,18 +23,18 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class SequentialSampler : public Sampler { | |||
| class SequentialSamplerRT : public SamplerRT { | |||
| public: | |||
| // Constructor | |||
| // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the | |||
| // full amount of ids from the dataset | |||
| // @param start_index - The starting index value | |||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit SequentialSampler(int64_t num_samples, int64_t start_index, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| explicit SequentialSamplerRT(int64_t num_samples, int64_t start_index, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| ~SequentialSampler() = default; | |||
| ~SequentialSamplerRT() = default; | |||
| // init sampler, called by python | |||
| Status InitSampler() override; | |||
| @@ -27,12 +27,12 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor. | |||
| SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices, | |||
| int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} | |||
| SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices, | |||
| int64_t samples_per_buffer) | |||
| : SamplerRT(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} | |||
// Initialize this Sampler. | |||
| Status SubsetRandomSampler::InitSampler() { | |||
| Status SubsetRandomSamplerRT::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); | |||
| @@ -56,7 +56,7 @@ Status SubsetRandomSampler::InitSampler() { | |||
| } | |||
| // Reset the internal variable to the initial state. | |||
| Status SubsetRandomSampler::ResetSampler() { | |||
| Status SubsetRandomSamplerRT::ResetSampler() { | |||
| // Reset the internal counters. | |||
| sample_id_ = 0; | |||
| buffer_id_ = 0; | |||
| @@ -73,7 +73,7 @@ Status SubsetRandomSampler::ResetSampler() { | |||
| } | |||
| // Get the sample ids. | |||
| Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status SubsetRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| // All samples have been drawn | |||
| if (sample_id_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | |||
| @@ -120,11 +120,11 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe | |||
| return Status::OK(); | |||
| } | |||
| void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { | |||
| void SubsetRandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nSampler: SubsetRandomSampler"; | |||
| if (show_all) { | |||
| // Call the super class for displaying any common detailed info | |||
| Sampler::Print(out, show_all); | |||
| SamplerRT::Print(out, show_all); | |||
| // Then add our own info if any | |||
| } | |||
| } | |||
| @@ -25,18 +25,18 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Randomly samples elements from a given list of indices, without replacement. | |||
| class SubsetRandomSampler : public Sampler { | |||
| class SubsetRandomSamplerRT : public SamplerRT { | |||
| public: | |||
| // Constructor. | |||
| // @param num_samples The number of samples to draw. 0 for the full amount. | |||
| // @param indices List of indices from where we will randomly draw samples. | |||
| // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | |||
| // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | |||
| explicit SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices, | |||
| std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| explicit SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices, | |||
| std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| ~SubsetRandomSampler() = default; | |||
| ~SubsetRandomSamplerRT() = default; | |||
| // Initialize the sampler. | |||
| // @return Status | |||
| @@ -27,16 +27,16 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor. | |||
| WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement, | |||
| int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), | |||
| WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std::vector<double> &weights, | |||
| bool replacement, int64_t samples_per_buffer) | |||
| : SamplerRT(num_samples, samples_per_buffer), | |||
| weights_(weights), | |||
| replacement_(replacement), | |||
| sample_id_(0), | |||
| buffer_id_(0) {} | |||
// Initialize this Sampler. | |||
| Status WeightedRandomSampler::InitSampler() { | |||
| Status WeightedRandomSamplerRT::InitSampler() { | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| @@ -78,7 +78,7 @@ Status WeightedRandomSampler::InitSampler() { | |||
| } | |||
// Initialize the computation for generating weighted random numbers without replacement using the one-pass method. | |||
| void WeightedRandomSampler::InitOnePassSampling() { | |||
| void WeightedRandomSamplerRT::InitOnePassSampling() { | |||
| exp_dist_->reset(); | |||
| onepass_ids_.clear(); | |||
| std::vector<std::pair<double, int64_t>> val_idx; | |||
| @@ -94,7 +94,7 @@ void WeightedRandomSampler::InitOnePassSampling() { | |||
| } | |||
| // Reset the internal variable to the initial state and reshuffle the indices. | |||
| Status WeightedRandomSampler::ResetSampler() { | |||
| Status WeightedRandomSamplerRT::ResetSampler() { | |||
| sample_id_ = 0; | |||
| buffer_id_ = 0; | |||
| rand_gen_.seed(GetSeed()); | |||
| @@ -112,7 +112,7 @@ Status WeightedRandomSampler::ResetSampler() { | |||
| } | |||
| // Get the sample ids. | |||
| Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (weights_.size() > static_cast<size_t>(num_rows_)) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "Invalid parameter, size of sample weights must be less than or equal to num of data, " | |||
| @@ -180,11 +180,11 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf | |||
| return Status::OK(); | |||
| } | |||
| void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const { | |||
| void WeightedRandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nSampler: WeightedRandomSampler"; | |||
| if (show_all) { | |||
| // Call the super class for displaying any common detailed info | |||
| Sampler::Print(out, show_all); | |||
| SamplerRT::Print(out, show_all); | |||
| // Then add our own info if any | |||
| } | |||
| } | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights). | |||
| class WeightedRandomSampler : public Sampler { | |||
| class WeightedRandomSamplerRT : public SamplerRT { | |||
| public: | |||
| // Constructor. | |||
| // @param num_samples Number of samples to be drawn. | |||
| @@ -34,11 +34,11 @@ class WeightedRandomSampler : public Sampler { | |||
| // @param replacement Determine if samples are drawn with/without replacement. | |||
| // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | |||
| // When samples_per_buffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | |||
| WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| WeightedRandomSamplerRT(int64_t num_samples, const std::vector<double> &weights, bool replacement, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| ~WeightedRandomSampler() = default; | |||
| ~WeightedRandomSamplerRT() = default; | |||
| // Initialize the sampler. | |||
| // @param op (Not used in this sampler) | |||
| @@ -84,7 +84,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, | |||
| std::shared_ptr<Sampler> sampler) | |||
| std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| device_id_(device_id), | |||
| num_devices_(num_device), | |||
| @@ -115,7 +115,7 @@ class TextFileOp : public ParallelOp { | |||
| // Setter method | |||
| // @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -131,7 +131,7 @@ class TextFileOp : public ParallelOp { | |||
| std::vector<std::string> builder_text_files_list_; | |||
| bool builder_shuffle_files_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| }; | |||
| // Constructor of TextFileOp | |||
| @@ -148,7 +148,7 @@ class TextFileOp : public ParallelOp { | |||
| // @param sampler - allows a sampler. Only valid if a cache exists in ascendant tree nodes | |||
| TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler); | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler); | |||
| // Default destructor | |||
| ~TextFileOp() = default; | |||
| @@ -58,7 +58,7 @@ TFReaderOp::Builder::Builder() | |||
| builder_data_schema_ = std::make_unique<DataSchema>(); | |||
| } | |||
| bool ValidateFirstRowCrc(const std::string &filename) { | |||
| bool TFReaderOp::ValidateFirstRowCrc(const std::string &filename) { | |||
| std::ifstream reader; | |||
| reader.open(filename); | |||
| if (!reader) { | |||
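`ValidateFirstRowCrc` becomes a static member of `TFReaderOp` here (the matching header change appears further down). For context, a rough sketch of what checking the first record involves, assuming the standard TFRecord layout of an 8-byte record length followed by a 4-byte masked CRC32C of those length bytes; `MaskedCrc32c` is a hypothetical helper, a little-endian host is assumed, and MindSpore's actual implementation may differ.

```cpp
#include <cstdint>
#include <cstring>
#include <fstream>
#include <string>

// Returns true if the file starts with a plausible TFRecord header:
// [ uint64 length | uint32 masked_crc32c(length) | data ... ].
bool LooksLikeTfRecord(const std::string &filename,
                       uint32_t (*MaskedCrc32c)(const char *data, size_t n)) {
  std::ifstream reader(filename, std::ios::binary);
  if (!reader) {
    return false;
  }
  char header[12];  // 8-byte length + 4-byte CRC of the length field
  if (!reader.read(header, sizeof(header))) {
    return false;
  }
  uint32_t stored_crc = 0;
  std::memcpy(&stored_crc, header + 8, sizeof(stored_crc));  // little-endian assumed
  return stored_crc == MaskedCrc32c(header, 8);
}
```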
| @@ -134,7 +134,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 | |||
| int64_t total_num_rows, std::vector<std::string> dataset_files_list, | |||
| std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, | |||
| std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device, | |||
| int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler) | |||
| int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| device_id_(device_id), | |||
| num_devices_(num_device), | |||
| @@ -156,14 +156,14 @@ class TFReaderOp : public ParallelOp { | |||
| // Setter method | |||
| // @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| private: | |||
| std::unique_ptr<DataSchema> builder_data_schema_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| int32_t builder_device_id_; | |||
| int32_t builder_num_devices_; | |||
| int32_t builder_num_workers_; | |||
| @@ -193,7 +193,7 @@ class TFReaderOp : public ParallelOp { | |||
| TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, | |||
| std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema, | |||
| int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files, | |||
| int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler); | |||
| int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler); | |||
| // Default destructor | |||
| ~TFReaderOp() = default; | |||
| @@ -262,6 +262,8 @@ class TFReaderOp : public ParallelOp { | |||
| /// \return Status of the function | |||
| Status GetDatasetSize(int64_t *dataset_size) override; | |||
| static bool ValidateFirstRowCrc(const std::string &filename); | |||
| private: | |||
| // The entry point for when workers are launched. | |||
| // @param worker_id - the id of the worker that is executing this function. | |||
| @@ -62,7 +62,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) { | |||
| if (builder_sampler_ == nullptr) { | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| if (builder_task_type_ == TaskType::Segmentation) { | |||
| @@ -102,7 +102,8 @@ Status VOCOp::Builder::SanityCheck() { | |||
| VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | |||
| const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | |||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | |||
| decode_(decode), | |||
| row_cnt_(0), | |||
| @@ -118,7 +118,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| // Setter method. | |||
| // @param std::shared_ptr<SamplerRT> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| @@ -148,7 +148,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int32_t builder_rows_per_buffer_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| std::map<std::string, int32_t> builder_labels_to_read_; | |||
| }; | |||
| @@ -166,7 +166,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| // @param std::shared_ptr<SamplerRT> sampler - sampler tells VOCOp what to read | |||
| VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | |||
| const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | |||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||
| // Destructor | |||
| ~VOCOp() = default; | |||
| @@ -21,7 +21,7 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| namespace mindspore::dataset::api { | |||
| namespace mindspore::dataset { | |||
| class DatasetCache { | |||
| public: | |||
| @@ -29,6 +29,6 @@ class DatasetCache { | |||
| virtual Status ValidateParams() = 0; | |||
| virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0; | |||
| }; | |||
| } // namespace mindspore::dataset::api | |||
| } // namespace mindspore::dataset | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_ | |||
| @@ -18,7 +18,7 @@ | |||
| #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| namespace mindspore::dataset::api { | |||
| namespace mindspore::dataset { | |||
| /// Method to initialize the DatasetCache by creating an instance of a CacheClient | |||
| /// \return Status Error code | |||
| @@ -41,4 +41,4 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindspore::dataset::api | |||
| } // namespace mindspore::dataset | |||
| @@ -24,7 +24,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | |||
| namespace mindspore::dataset::api { | |||
| namespace mindspore::dataset { | |||
| /// DatasetCache is the IR of CacheClient | |||
| class DatasetCacheImpl : public DatasetCache { | |||
| @@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache { | |||
| std::optional<int32_t> num_connections_; | |||
| std::optional<int32_t> prefetch_sz_; | |||
| }; | |||
| } // namespace mindspore::dataset::api | |||
| } // namespace mindspore::dataset | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_ | |||
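These cache files use C++17 nested namespace definitions, so dropping the `api` layer touches only the one-line namespace declaration and its closing comment. A minimal before/after illustration (the class name is just a stand-in for the file contents):

```cpp
// Before: three logical levels spelled as one nested declaration.
namespace mindspore::dataset::api {
class DatasetCache;
}  // namespace mindspore::dataset::api

// After: the api level is gone; the nested form still expands to
// namespace mindspore { namespace dataset { ... } }.
namespace mindspore::dataset {
class DatasetCache;
}  // namespace mindspore::dataset
```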
| @@ -26,7 +26,6 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| #ifdef ENABLE_PYTHON | |||
| // constructor #1, called by Pybind | |||
| @@ -96,6 +95,5 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() { | |||
| return node_ops; | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -27,7 +27,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class BatchNode : public DatasetNode { | |||
| public: | |||
| @@ -66,7 +65,6 @@ class BatchNode : public DatasetNode { | |||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_ | |||
| @@ -27,7 +27,7 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| BucketBatchByLengthNode::BucketBatchByLengthNode( | |||
| std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | |||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | |||
| @@ -121,6 +121,5 @@ Status BucketBatchByLengthNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -27,7 +27,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class BucketBatchByLengthNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| @@ -58,7 +58,6 @@ class BucketBatchByLengthNode : public DatasetNode { | |||
| bool drop_remainder_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_ | |||
| @@ -26,7 +26,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child, | |||
| std::shared_ptr<SentencePieceVocab> vocab, | |||
| @@ -77,6 +76,6 @@ Status BuildSentenceVocabNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -27,7 +27,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class BuildSentenceVocabNode : public DatasetNode { | |||
| public: | |||
| @@ -56,7 +55,6 @@ class BuildSentenceVocabNode : public DatasetNode { | |||
| std::unordered_map<std::string, std::string> params_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_ | |||
| @@ -26,7 +26,6 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab, | |||
| const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, | |||
| @@ -78,6 +77,6 @@ Status BuildVocabNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -26,7 +26,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class BuildVocabNode : public DatasetNode { | |||
| public: | |||
| @@ -55,7 +54,6 @@ class BuildVocabNode : public DatasetNode { | |||
| bool special_first_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_ | |||
| @@ -25,7 +25,7 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| // Function to build ConcatOp | |||
| ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; } | |||
| @@ -53,6 +53,5 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() { | |||
| return node_ops; | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -25,7 +25,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class ConcatNode : public DatasetNode { | |||
| public: | |||
| @@ -44,7 +43,6 @@ class ConcatNode : public DatasetNode { | |||
| Status ValidateParams() override; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_ | |||
| @@ -16,11 +16,187 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include "minddata/dataset/util/random.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| // Helper function to compute a default shuffle size | |||
| Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||
| int64_t *shuffle_size) { | |||
| const int64_t average_files_multiplier = 4; | |||
| const int64_t shuffle_max = 10000; | |||
| int64_t avg_rows_per_file = 0; | |||
| // Adjust the num rows per shard if sharding was given | |||
| if (num_devices > 0) { | |||
| if (num_rows % num_devices == 0) { | |||
| num_rows = num_rows / num_devices; | |||
| } else { | |||
| num_rows = (num_rows / num_devices) + 1; | |||
| } | |||
| } | |||
| // Cap based on the total rows directive. Some ops do not have this and pass a value of 0. | |||
| if (total_rows > 0) { | |||
| num_rows = std::min(num_rows, total_rows); | |||
| } | |||
| // Get the average number of rows per file. | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0."); | |||
| avg_rows_per_file = num_rows / num_files; | |||
| *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); | |||
| return Status::OK(); | |||
| } | |||
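To make the heuristic concrete, a hypothetical call: with 8 files holding 100000 rows sharded over 4 devices and no total-rows cap, the per-shard row count is 25000, the per-file average is 3125, and the result is max(3125 * 4, 10000) = 12500.

```cpp
int64_t shuffle_size = 0;
// 100000 rows / 4 devices = 25000 rows per shard;
// 25000 rows / 8 files = 3125 rows per file on average;
// max(3125 * 4, 10000) = 12500.
Status rc = ComputeShuffleSize(/*num_files=*/8, /*num_devices=*/4,
                               /*num_rows=*/100000, /*total_rows=*/0, &shuffle_size);
// On success, shuffle_size == 12500.
```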
| // Helper function to inject a shuffle operator on top of the current operator being built | |||
| Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||
| int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) { | |||
| int64_t shuffle_size = 0; | |||
| RETURN_IF_NOT_OK(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size)); | |||
| MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size; | |||
| // Add the shuffle op | |||
| *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer); | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to validate dataset directory parameter | |||
| Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { | |||
| if (dataset_dir.empty()) { | |||
| std::string err_msg = dataset_name + ": dataset_dir is not specified."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| Path dir(dataset_dir); | |||
| if (!dir.IsDirectory()) { | |||
| std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (access(dataset_dir.c_str(), R_OK) == -1) { | |||
| std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to validate dataset files parameter | |||
| Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) { | |||
| if (dataset_files.empty()) { | |||
| std::string err_msg = dataset_name + ": dataset_files is not specified."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| for (const auto &f : dataset_files) { | |||
| Path dataset_file(f); | |||
| if (!dataset_file.Exists()) { | |||
| std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (access(dataset_file.toString().c_str(), R_OK) == -1) { | |||
| std::string err_msg = dataset_name + ": No access to specified dataset file: " + f; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to validate dataset num_shards and shard_id parameters | |||
| Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { | |||
| if (num_shards <= 0) { | |||
| std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if (shard_id < 0 || shard_id >= num_shards) { | |||
| std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) + | |||
| ", num_shards: " + std::to_string(num_shards); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to validate dataset sampler parameter | |||
| Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) { | |||
| if (sampler == nullptr) { | |||
| std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status ValidateStringValue(const std::string &dataset_name, const std::string &str, | |||
| const std::unordered_set<std::string> &valid_strings) { | |||
| if (valid_strings.find(str) == valid_strings.end()) { | |||
| std::string mode; | |||
| mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, | |||
| [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); | |||
| std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to validate dataset input/output column parameter | |||
| Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, | |||
| const std::vector<std::string> &columns) { | |||
| if (columns.empty()) { | |||
| std::string err_msg = dataset_name + ":" + column_param + " should not be empty string"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| for (uint32_t i = 0; i < columns.size(); ++i) { | |||
| if (columns[i].empty()) { | |||
| std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| } | |||
| std::set<std::string> columns_set(columns.begin(), columns.end()); | |||
| if (columns_set.size() != columns.size()) { | |||
| std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) { | |||
| if (shuffle) { | |||
| if (num_shards > 1) { | |||
| // If shuffle enabled, sharding enabled, use distributed random sampler | |||
| return DistributedSampler(num_shards, shard_id, shuffle, num_samples); | |||
| } | |||
| // If shuffle enabled, sharding disabled, use random sampler | |||
| return RandomSampler(num_samples >= 0, num_samples); | |||
| } | |||
| if (num_shards > 1) { | |||
| // If shuffle disabled, sharding enabled, use distributed sequential sampler | |||
| return DistributedSampler(num_shards, shard_id, shuffle, num_samples); | |||
| } | |||
| // If shuffle disabled, sharding disabled, use sequential sampler | |||
| return SequentialSampler(0, num_samples); | |||
| } | |||
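`SelectSampler` maps the (shuffle, num_shards) combination onto the client-side `SamplerObj` factories; the runtime `SamplerRT` classes renamed throughout this change are what those IR objects eventually build into. Illustrative calls (the argument values are examples only):

```cpp
// shuffle + sharding: distributed random order across 4 shards.
std::shared_ptr<SamplerObj> s1 = SelectSampler(/*num_samples=*/0, /*shuffle=*/true,
                                               /*num_shards=*/4, /*shard_id=*/1);
// no shuffle, no sharding: plain sequential read of every row.
std::shared_ptr<SamplerObj> s2 = SelectSampler(/*num_samples=*/0, /*shuffle=*/false,
                                               /*num_shards=*/1, /*shard_id=*/0);
```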
| Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| if (cache_ != nullptr) { | |||
| @@ -60,6 +236,5 @@ DatasetNode::DatasetNode() { | |||
| worker_connector_size_ = cfg->worker_connector_size(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -28,7 +28,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class Dataset; | |||
| class SamplerObj; | |||
| @@ -120,7 +119,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| int32_t worker_connector_size_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | |||
| @@ -26,7 +26,6 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | |||
| std::vector<std::string> input_columns, std::vector<std::string> output_columns, | |||
| @@ -86,6 +85,5 @@ Status MapNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -25,7 +25,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class MapNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| @@ -51,7 +51,6 @@ class MapNode : public DatasetNode { | |||
| std::vector<std::string> project_columns_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_ | |||
| @@ -25,7 +25,6 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| // Function to build ProjectOp | |||
| ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) | |||
| @@ -53,6 +52,5 @@ std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() { | |||
| return node_ops; | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -26,8 +26,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class ProjectNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| @@ -48,7 +46,6 @@ class ProjectNode : public DatasetNode { | |||
| std::vector<std::string> columns_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_ | |||
| @@ -25,7 +25,7 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| // Function to build RenameOp | |||
| RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, | |||
| const std::vector<std::string> &output_columns) | |||
| @@ -54,6 +54,6 @@ std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() { | |||
| node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_)); | |||
| return node_ops; | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -26,8 +26,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class RenameNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| @@ -50,7 +48,6 @@ class RenameNode : public DatasetNode { | |||
| std::vector<std::string> output_columns_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_ | |||
| @@ -25,7 +25,6 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { | |||
| this->children.push_back(child); | |||
| @@ -49,6 +48,6 @@ Status RepeatNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -28,8 +28,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class RepeatNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| @@ -50,7 +48,6 @@ class RepeatNode : public DatasetNode { | |||
| int32_t repeat_count_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_ | |||
| @@ -25,7 +25,6 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| // Constructor for ShuffleNode | |||
| ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) | |||
| @@ -54,6 +53,5 @@ Status ShuffleNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -28,8 +28,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class ShuffleNode : public DatasetNode { | |||
| public: | |||
| ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch); | |||
| @@ -46,7 +44,6 @@ class ShuffleNode : public DatasetNode { | |||
| bool reset_every_epoch_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_ | |||
| @@ -25,7 +25,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| // Constructor for SkipNode | |||
| SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { | |||
| @@ -52,6 +51,5 @@ Status SkipNode::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -26,7 +26,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class SkipNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| @@ -46,7 +45,7 @@ class SkipNode : public DatasetNode { | |||
| private: | |||
| int32_t skip_count_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_ | |||
| @@ -27,7 +27,7 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| // Constructor for AlbumNode | |||
| AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | |||
| const std::vector<std::string> &column_names, bool decode, | |||
| @@ -78,6 +78,5 @@ Status AlbumNode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -25,7 +25,6 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class AlbumNode : public DatasetNode { | |||
| public: | |||
| @@ -57,7 +56,6 @@ class AlbumNode : public DatasetNode { | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_ | |||
| @@ -26,7 +26,7 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| // Constructor for CelebANode | |||
| CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | |||
| const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | |||
| @@ -76,6 +76,5 @@ Status CelebANode::GetShardId(int32_t *shard_id) { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||