| @@ -21,7 +21,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Config operations for setting and getting the configuration. | // Config operations for setting and getting the configuration. | ||||
| namespace config { | namespace config { | ||||
| @@ -104,6 +103,5 @@ bool load(std::string file) { | |||||
| } | } | ||||
| } // namespace config | } // namespace config | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,36 +21,14 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/include/samplers.h" | #include "minddata/dataset/include/samplers.h" | ||||
| #include "minddata/dataset/include/transforms.h" | #include "minddata/dataset/include/transforms.h" | ||||
| // Source dataset headers (in alphabetical order) | |||||
| #include "minddata/dataset/engine/dataset_iterator.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/album_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/clue_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||||
| #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | ||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/text_file_op.h" | |||||
| #ifndef ENABLE_ANDROID | |||||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||||
| #endif | |||||
| // Dataset operator headers (in alphabetical order) | |||||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/skip_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/zip_op.h" | |||||
| // Sampler headers (in alphabetical order) | // Sampler headers (in alphabetical order) | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||||
| // IR non-leaf nodes | // IR non-leaf nodes | ||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| @@ -99,7 +77,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Function to create the iterator, which will build and launch the execution tree. | // Function to create the iterator, which will build and launch the execution tree. | ||||
| std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) { | std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) { | ||||
| @@ -317,7 +294,7 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) { | |||||
| return schema->init() ? schema : nullptr; | return schema->init() ? schema : nullptr; | ||||
| } | } | ||||
| // FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS | |||||
| // FUNCTIONS TO CREATE DATASETS FOR LEAF CLASSES | |||||
| // (In alphabetical order) | // (In alphabetical order) | ||||
| // Function to create a AlbumDataset. | // Function to create a AlbumDataset. | ||||
| @@ -466,7 +443,7 @@ std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::strin | |||||
| } | } | ||||
| #endif | #endif | ||||
| // Function to create a ZipNode. | |||||
| // Function to create a ZipDatset. | |||||
| std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) { | std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) { | ||||
| auto ds = std::make_shared<ZipDataset>(datasets); | auto ds = std::make_shared<ZipDataset>(datasets); | ||||
| return ds; | return ds; | ||||
| @@ -639,7 +616,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab( | |||||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | ||||
| Status rc = runtime_context->Init(); | Status rc = runtime_context->Init(); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; | |||||
| MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -647,15 +624,15 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab( | |||||
| BuildVocabConsumer *bv_consumer = consumer.get(); | BuildVocabConsumer *bv_consumer = consumer.get(); | ||||
| rc = consumer->Init(ds); | rc = consumer->Init(ds); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc; | |||||
| MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| runtime_context->AssignConsumer(std::move(consumer)); | runtime_context->AssignConsumer(std::move(consumer)); | ||||
| // Run tree here to starting building vocab | |||||
| // Run tree here to starting building SentencePieceVocab | |||||
| rc = bv_consumer->Start(); | rc = bv_consumer->Start(); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc; | |||||
| MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return vocab; | return vocab; | ||||
| @@ -671,7 +648,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum | |||||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | ||||
| Status rc = runtime_context->Init(); | Status rc = runtime_context->Init(); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; | |||||
| MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -679,7 +656,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum | |||||
| BuildVocabConsumer *bv_consumer = consumer.get(); | BuildVocabConsumer *bv_consumer = consumer.get(); | ||||
| rc = consumer->Init(ds); | rc = consumer->Init(ds); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc; | |||||
| MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| runtime_context->AssignConsumer(std::move(consumer)); | runtime_context->AssignConsumer(std::move(consumer)); | ||||
| @@ -687,11 +664,14 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum | |||||
| // Run tree here to starting building vocab | // Run tree here to starting building vocab | ||||
| rc = bv_consumer->Start(); | rc = bv_consumer->Start(); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc; | |||||
| MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return vocab; | return vocab; | ||||
| } | } | ||||
| std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) { | |||||
| return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder); | |||||
| } | |||||
| #endif | #endif | ||||
| SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} | SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} | ||||
| @@ -856,162 +836,6 @@ bool SchemaObj::from_json(nlohmann::json json_obj) { | |||||
| // OTHER FUNCTIONS | // OTHER FUNCTIONS | ||||
| // Helper function to compute a default shuffle size | |||||
| Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||||
| int64_t *shuffle_size) { | |||||
| const int64_t average_files_multiplier = 4; | |||||
| const int64_t shuffle_max = 10000; | |||||
| int64_t avg_rows_per_file = 0; | |||||
| // Adjust the num rows per shard if sharding was given | |||||
| if (num_devices > 0) { | |||||
| if (num_rows % num_devices == 0) { | |||||
| num_rows = num_rows / num_devices; | |||||
| } else { | |||||
| num_rows = (num_rows / num_devices) + 1; | |||||
| } | |||||
| } | |||||
| // Cap based on total rows directive. Some ops do not have this and give value of 0. | |||||
| if (total_rows > 0) { | |||||
| num_rows = std::min(num_rows, total_rows); | |||||
| } | |||||
| // get the average per file | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must greater than 0."); | |||||
| avg_rows_per_file = num_rows / num_files; | |||||
| *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to inject a shuffle operator over top of current operator being built | |||||
| Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||||
| int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) { | |||||
| std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr; | |||||
| int64_t shuffle_size = 0; | |||||
| RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size)); | |||||
| MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size; | |||||
| // Add the shuffle op | |||||
| *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset directory parameter | |||||
| Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { | |||||
| if (dataset_dir.empty()) { | |||||
| std::string err_msg = dataset_name + ": dataset_dir is not specified."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| Path dir(dataset_dir); | |||||
| if (!dir.IsDirectory()) { | |||||
| std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (access(dataset_dir.c_str(), R_OK) == -1) { | |||||
| std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset files parameter | |||||
| Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) { | |||||
| if (dataset_files.empty()) { | |||||
| std::string err_msg = dataset_name + ": dataset_files is not specified."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| for (auto f : dataset_files) { | |||||
| Path dataset_file(f); | |||||
| if (!dataset_file.Exists()) { | |||||
| std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (access(dataset_file.toString().c_str(), R_OK) == -1) { | |||||
| std::string err_msg = dataset_name + ": No access to specified dataset file: " + f; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset num_shards and shard_id parameters | |||||
| Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { | |||||
| if (num_shards <= 0) { | |||||
| std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (shard_id < 0 || shard_id >= num_shards) { | |||||
| // num_shards; | |||||
| std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) + | |||||
| ", num_shards: " + std::to_string(num_shards); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset sampler parameter | |||||
| Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) { | |||||
| if (sampler == nullptr) { | |||||
| std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ValidateStringValue(const std::string &dataset_name, const std::string &str, | |||||
| const std::unordered_set<std::string> &valid_strings) { | |||||
| if (valid_strings.find(str) == valid_strings.end()) { | |||||
| std::string mode; | |||||
| mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, | |||||
| [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); | |||||
| std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset input/output column parameter | |||||
| Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, | |||||
| const std::vector<std::string> &columns) { | |||||
| if (columns.empty()) { | |||||
| std::string err_msg = dataset_name + ":" + column_param + " should not be empty string"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| for (uint32_t i = 0; i < columns.size(); ++i) { | |||||
| if (columns[i].empty()) { | |||||
| std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| std::set<std::string> columns_set(columns.begin(), columns.end()); | |||||
| if (columns_set.size() != columns.size()) { | |||||
| std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill, | std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill, | ||||
| @@ -1153,22 +977,5 @@ TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files, | |||||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | ||||
| } | } | ||||
| #endif | #endif | ||||
| std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) { | |||||
| if (shuffle) { | |||||
| if (num_shards > 1) { | |||||
| // If shuffle enabled, sharding enabled, use distributed random sampler | |||||
| return DistributedSampler(num_shards, shard_id, shuffle, num_samples); | |||||
| } | |||||
| // If shuffle enabled, sharding disabled, use random sampler | |||||
| return RandomSampler(num_samples >= 0, num_samples); | |||||
| } | |||||
| if (num_shards > 1) { | |||||
| // If shuffle disabled, sharding enabled, use distributed sequential sampler | |||||
| return DistributedSampler(num_shards, shard_id, shuffle, num_samples); | |||||
| } | |||||
| // If shuffle disabled, sharding disabled, use sequential sampler | |||||
| return SequentialSampler(0, num_samples); | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,7 +26,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {} | Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {} | ||||
| @@ -54,6 +53,5 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MS | |||||
| return std::make_shared<tensor::DETensor>(std::move(de_output)); | return std::make_shared<tensor::DETensor>(std::move(de_output)); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,7 +20,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Get the next row from the data pipeline. | // Get the next row from the data pipeline. | ||||
| bool Iterator::GetNextRow(TensorMap *row) { | bool Iterator::GetNextRow(TensorMap *row) { | ||||
| @@ -45,19 +44,18 @@ bool Iterator::GetNextRow(TensorVec *row) { | |||||
| } | } | ||||
| // Shut down the data pipeline. | // Shut down the data pipeline. | ||||
| void Iterator::Stop() { runtime_context->Terminate(); } | |||||
| void Iterator::Stop() { runtime_context_->Terminate(); } | |||||
| // Function to build and launch the execution tree. | // Function to build and launch the execution tree. | ||||
| Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | ||||
| runtime_context = std::make_unique<RuntimeContext>(); | |||||
| RETURN_IF_NOT_OK(runtime_context->Init()); | |||||
| runtime_context_ = std::make_unique<RuntimeContext>(); | |||||
| RETURN_IF_NOT_OK(runtime_context_->Init()); | |||||
| auto consumer = std::make_unique<IteratorConsumer>(); | auto consumer = std::make_unique<IteratorConsumer>(); | ||||
| consumer_ = consumer.get(); | consumer_ = consumer.get(); | ||||
| RETURN_IF_NOT_OK(consumer->Init(ds->IRNode())); | RETURN_IF_NOT_OK(consumer->Init(ds->IRNode())); | ||||
| runtime_context->AssignConsumer(std::move(consumer)); | |||||
| runtime_context_->AssignConsumer(std::move(consumer)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,59 +27,59 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) { | |||||
| (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler") | |||||
| PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) { | |||||
| (void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler") | |||||
| .def("set_num_rows", | .def("set_num_rows", | ||||
| [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) | |||||
| [](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) | |||||
| .def("set_num_samples", | .def("set_num_samples", | ||||
| [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) | |||||
| .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) | |||||
| [](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) | |||||
| .def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); }) | |||||
| .def("get_indices", | .def("get_indices", | ||||
| [](Sampler &self) { | |||||
| [](SamplerRT &self) { | |||||
| py::array ret; | py::array ret; | ||||
| THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); | THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); | ||||
| return ret; | return ret; | ||||
| }) | }) | ||||
| .def("add_child", [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { | |||||
| .def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) { | |||||
| THROW_IF_ERROR(self->AddChild(child)); | THROW_IF_ERROR(self->AddChild(child)); | ||||
| }); | }); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) { | |||||
| (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>( | |||||
| PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) { | |||||
| (void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>( | |||||
| *m, "DistributedSampler") | *m, "DistributedSampler") | ||||
| .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>()); | .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) { | |||||
| (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler") | |||||
| PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) { | |||||
| (void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler") | |||||
| .def(py::init<int64_t, int64_t, bool>()); | .def(py::init<int64_t, int64_t, bool>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) { | |||||
| (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler") | |||||
| PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) { | |||||
| (void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler") | |||||
| .def(py::init<int64_t, py::object>()); | .def(py::init<int64_t, py::object>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) { | |||||
| (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler") | |||||
| PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) { | |||||
| (void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler") | |||||
| .def(py::init<int64_t, bool, bool>()); | .def(py::init<int64_t, bool, bool>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) { | |||||
| (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, | |||||
| "SequentialSampler") | |||||
| PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) { | |||||
| (void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>( | |||||
| *m, "SequentialSampler") | |||||
| .def(py::init<int64_t, int64_t>()); | .def(py::init<int64_t, int64_t>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) { | |||||
| (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>( | |||||
| PYBIND_REGISTER(SubsetRandomSamplerRT, 1, ([](const py::module *m) { | |||||
| (void)py::class_<SubsetRandomSamplerRT, SamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>( | |||||
| *m, "SubsetRandomSampler") | *m, "SubsetRandomSampler") | ||||
| .def(py::init<int64_t, std::vector<int64_t>>()); | .def(py::init<int64_t, std::vector<int64_t>>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) { | |||||
| (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>( | |||||
| PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) { | |||||
| (void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>( | |||||
| *m, "WeightedRandomSampler") | *m, "WeightedRandomSampler") | ||||
| .def(py::init<int64_t, std::vector<double>, bool>()); | .def(py::init<int64_t, std::vector<double>, bool>()); | ||||
| })); | })); | ||||
| @@ -1140,7 +1140,7 @@ Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp | |||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "sampler") { | if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } | } | ||||
| if (key == "children_flag_and_nums") { | if (key == "children_flag_and_nums") { | ||||
| @@ -1164,7 +1164,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| // Required arguments | // Required arguments | ||||
| std::vector<std::string> files_list; | std::vector<std::string> files_list; | ||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | std::shared_ptr<CacheClient> cache_client = nullptr; | ||||
| std::shared_ptr<Sampler> sampler = nullptr; | |||||
| std::shared_ptr<SamplerRT> sampler = nullptr; | |||||
| int num_workers = 0; | int num_workers = 0; | ||||
| std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>(); | std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>(); | ||||
| if (!args["dataset_files"].is_none()) { | if (!args["dataset_files"].is_none()) { | ||||
| @@ -1210,7 +1210,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | cache_client = value.cast<std::shared_ptr<CacheClient>>(); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1234,7 +1234,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| } else if (cache_client) { | } else if (cache_client) { | ||||
| const int64_t num_samples = 0; | const int64_t num_samples = 0; | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } | } | ||||
| @@ -1308,7 +1308,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data | |||||
| (void)builder->SetNumWorkers(num_workers); | (void)builder->SetNumWorkers(num_workers); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "extensions") { | } else if (key == "extensions") { | ||||
| (void)builder->SetExtensions(ToStringSet(value)); | (void)builder->SetExtensions(ToStringSet(value)); | ||||
| @@ -1363,7 +1363,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| (void)builder->SetNumWorkers(num_workers); | (void)builder->SetNumWorkers(num_workers); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "class_indexing") { | } else if (key == "class_indexing") { | ||||
| (void)builder->SetClassIndex(ToStringMap(value)); | (void)builder->SetClassIndex(ToStringMap(value)); | ||||
| @@ -1416,7 +1416,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| (void)builder->SetNumWorkers(num_workers); | (void)builder->SetNumWorkers(num_workers); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "decode") { | } else if (key == "decode") { | ||||
| (void)builder->SetDecode(ToBool(value)); | (void)builder->SetDecode(ToBool(value)); | ||||
| @@ -1478,7 +1478,7 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| (void)builder->SetNumWorkers(num_workers); | (void)builder->SetNumWorkers(num_workers); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "decode") { | } else if (key == "decode") { | ||||
| (void)builder->SetDecode(ToBool(value)); | (void)builder->SetDecode(ToBool(value)); | ||||
| @@ -1529,7 +1529,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO | |||||
| (void)builder->SetNumWorkers(num_workers); | (void)builder->SetNumWorkers(num_workers); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "usage") { | } else if (key == "usage") { | ||||
| (void)builder->SetUsage(ToString(value)); | (void)builder->SetUsage(ToString(value)); | ||||
| @@ -1583,7 +1583,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset | |||||
| (void)builder->SetNumWorkers(num_workers); | (void)builder->SetNumWorkers(num_workers); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "usage") { | } else if (key == "usage") { | ||||
| (void)builder->SetUsage(ToString(value)); | (void)builder->SetUsage(ToString(value)); | ||||
| @@ -1618,7 +1618,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas | |||||
| // Required arguments | // Required arguments | ||||
| RandomDataOp::Builder builder; | RandomDataOp::Builder builder; | ||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | std::shared_ptr<CacheClient> cache_client = nullptr; | ||||
| std::shared_ptr<Sampler> sampler = nullptr; | |||||
| std::shared_ptr<SamplerRT> sampler = nullptr; | |||||
| int num_workers = 0; | int num_workers = 0; | ||||
| if (args["total_rows"].is_none()) { | if (args["total_rows"].is_none()) { | ||||
| @@ -1646,7 +1646,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | cache_client = value.cast<std::shared_ptr<CacheClient>>(); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1670,7 +1670,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas | |||||
| } else if (cache_client) { | } else if (cache_client) { | ||||
| const int64_t num_samples = 0; | const int64_t num_samples = 0; | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||||
| (void)builder.SetSampler(std::move(sampler)); | (void)builder.SetSampler(std::move(sampler)); | ||||
| } | } | ||||
| @@ -1715,7 +1715,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| (void)builder->SetNumWorkers(num_workers); | (void)builder->SetNumWorkers(num_workers); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "usage") { | } else if (key == "usage") { | ||||
| (void)builder->SetUsage(ToString(value)); | (void)builder->SetUsage(ToString(value)); | ||||
| @@ -1768,7 +1768,7 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp | |||||
| (void)builder->SetNumWorkers(num_workers); | (void)builder->SetNumWorkers(num_workers); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "decode") { | } else if (key == "decode") { | ||||
| (void)builder->SetDecode(ToBool(value)); | (void)builder->SetDecode(ToBool(value)); | ||||
| @@ -1806,7 +1806,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| // Required arguments | // Required arguments | ||||
| std::vector<std::string> files_list; | std::vector<std::string> files_list; | ||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | std::shared_ptr<CacheClient> cache_client = nullptr; | ||||
| std::shared_ptr<Sampler> sampler = nullptr; | |||||
| std::shared_ptr<SamplerRT> sampler = nullptr; | |||||
| int num_workers = 0; | int num_workers = 0; | ||||
| std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>(); | std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>(); | ||||
| if (!args["dataset_files"].is_none()) { | if (!args["dataset_files"].is_none()) { | ||||
| @@ -1840,7 +1840,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | cache_client = value.cast<std::shared_ptr<CacheClient>>(); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1855,7 +1855,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| } else if (cache_client) { | } else if (cache_client) { | ||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } | } | ||||
| @@ -1991,7 +1991,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| std::shared_ptr<DatasetOp> *bottom) { | std::shared_ptr<DatasetOp> *bottom) { | ||||
| std::vector<std::string> files_list; | std::vector<std::string> files_list; | ||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | std::shared_ptr<CacheClient> cache_client = nullptr; | ||||
| std::shared_ptr<Sampler> sampler = nullptr; | |||||
| std::shared_ptr<SamplerRT> sampler = nullptr; | |||||
| int num_workers = 0; | int num_workers = 0; | ||||
| std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>(); | std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>(); | ||||
| @@ -2036,7 +2036,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | cache_client = value.cast<std::shared_ptr<CacheClient>>(); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -2051,7 +2051,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| } else if (cache_client) { | } else if (cache_client) { | ||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } | } | ||||
| @@ -2116,7 +2116,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| std::shared_ptr<DatasetOp> *bottom) { | std::shared_ptr<DatasetOp> *bottom) { | ||||
| std::vector<std::string> files_list; | std::vector<std::string> files_list; | ||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | std::shared_ptr<CacheClient> cache_client = nullptr; | ||||
| std::shared_ptr<Sampler> sampler = nullptr; | |||||
| std::shared_ptr<SamplerRT> sampler = nullptr; | |||||
| int num_workers = 0; | int num_workers = 0; | ||||
| std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>(); | std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>(); | ||||
| if (!args["dataset_files"].is_none()) { | if (!args["dataset_files"].is_none()) { | ||||
| @@ -2173,7 +2173,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | cache_client = value.cast<std::shared_ptr<CacheClient>>(); | ||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -2188,7 +2188,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| } else if (cache_client) { | } else if (cache_client) { | ||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } | } | ||||
| @@ -35,7 +35,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| #define RETURN_NULL_IF_ERROR(_s) \ | #define RETURN_NULL_IF_ERROR(_s) \ | ||||
| do { \ | do { \ | ||||
| @@ -151,10 +150,10 @@ bool DistributedSamplerObj::ValidateParams() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::shared_ptr<Sampler> DistributedSamplerObj::Build() { | |||||
| std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() { | |||||
| // runtime sampler object | // runtime sampler object | ||||
| auto sampler = std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | |||||
| offset_, even_dist_); | |||||
| auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | |||||
| offset_, even_dist_); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -184,9 +183,9 @@ bool PKSamplerObj::ValidateParams() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::shared_ptr<Sampler> PKSamplerObj::Build() { | |||||
| std::shared_ptr<SamplerRT> PKSamplerObj::Build() { | |||||
| // runtime sampler object | // runtime sampler object | ||||
| auto sampler = std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_); | |||||
| auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -218,10 +217,10 @@ bool RandomSamplerObj::ValidateParams() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::shared_ptr<Sampler> RandomSamplerObj::Build() { | |||||
| std::shared_ptr<SamplerRT> RandomSamplerObj::Build() { | |||||
| // runtime sampler object | // runtime sampler object | ||||
| bool reshuffle_each_epoch = true; | bool reshuffle_each_epoch = true; | ||||
| auto sampler = std::make_shared<dataset::RandomSampler>(num_samples_, replacement_, reshuffle_each_epoch); | |||||
| auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -255,9 +254,9 @@ bool SequentialSamplerObj::ValidateParams() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::shared_ptr<Sampler> SequentialSamplerObj::Build() { | |||||
| std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() { | |||||
| // runtime sampler object | // runtime sampler object | ||||
| auto sampler = std::make_shared<dataset::SequentialSampler>(num_samples_, start_index_); | |||||
| auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -284,9 +283,9 @@ bool SubsetRandomSamplerObj::ValidateParams() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() { | |||||
| std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() { | |||||
| // runtime sampler object | // runtime sampler object | ||||
| auto sampler = std::make_shared<dataset::SubsetRandomSampler>(num_samples_, indices_); | |||||
| auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| @@ -330,11 +329,10 @@ bool WeightedRandomSamplerObj::ValidateParams() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::shared_ptr<Sampler> WeightedRandomSamplerObj::Build() { | |||||
| auto sampler = std::make_shared<dataset::WeightedRandomSampler>(num_samples_, weights_, replacement_); | |||||
| std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() { | |||||
| auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Transform operations for text. | // Transform operations for text. | ||||
| namespace text { | namespace text { | ||||
| @@ -130,6 +129,5 @@ std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() { | |||||
| } | } | ||||
| } // namespace text | } // namespace text | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| TensorOperation::TensorOperation() {} | TensorOperation::TensorOperation() {} | ||||
| @@ -94,6 +93,5 @@ Status TypeCastOperation::ValidateParams() { | |||||
| std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); } | std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); } | ||||
| } // namespace transforms | } // namespace transforms | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -65,7 +65,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Transform operations for computer vision. | // Transform operations for computer vision. | ||||
| namespace vision { | namespace vision { | ||||
| @@ -1702,6 +1701,5 @@ std::shared_ptr<TensorOp> UniformAugOperation::Build() { | |||||
| #endif | #endif | ||||
| } // namespace vision | } // namespace vision | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,11 +34,11 @@ namespace mindspore::dataset { | |||||
| // TreeConsumer | // TreeConsumer | ||||
| TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | ||||
| Status TreeConsumer::Init(std::shared_ptr<api::DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } | |||||
| Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } | |||||
| Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); } | Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); } | ||||
| // IteratorConsumer | // IteratorConsumer | ||||
| Status IteratorConsumer::Init(std::shared_ptr<api::DatasetNode> d) { | |||||
| Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) { | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | ||||
| } | } | ||||
| @@ -74,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> | |||||
| } | } | ||||
| // ToDevice | // ToDevice | ||||
| Status ToDevice::Init(std::shared_ptr<api::DatasetNode> d) { | |||||
| Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | ||||
| } | } | ||||
| @@ -385,8 +385,8 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal | |||||
| tree_adapter_ = std::make_unique<TreeAdapter>(); | tree_adapter_ = std::make_unique<TreeAdapter>(); | ||||
| } | } | ||||
| Status TreeGetters::Init(std::shared_ptr<api::DatasetNode> d) { | |||||
| Status s = tree_adapter_->BuildAndPrepare(std::move(d)); | |||||
| Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) { | |||||
| Status s = tree_adapter_->BuildAndPrepare(std::move(d), 1); | |||||
| if (!s.IsError()) { | if (!s.IsError()) { | ||||
| init_flag_ = true; | init_flag_ = true; | ||||
| } | } | ||||
| @@ -464,7 +464,7 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) { | |||||
| RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); | RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status BuildVocabConsumer::Init(std::shared_ptr<api::DatasetNode> d) { | |||||
| Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), 1); | return tree_adapter_->BuildAndPrepare(std::move(d), 1); | ||||
| } | } | ||||
| Status BuildVocabConsumer::Start() { | Status BuildVocabConsumer::Start() { | ||||
| @@ -29,10 +29,7 @@ | |||||
| namespace mindspore::dataset { | namespace mindspore::dataset { | ||||
| // Forward declare | // Forward declare | ||||
| class TreeAdapter; | class TreeAdapter; | ||||
| namespace api { | |||||
| class DatasetNode; | class DatasetNode; | ||||
| } | |||||
| /// A base class for tree consumers which would fetch rows from the tree pipeline | /// A base class for tree consumers which would fetch rows from the tree pipeline | ||||
| class TreeConsumer { | class TreeConsumer { | ||||
| @@ -42,7 +39,7 @@ class TreeConsumer { | |||||
| /// Initializes the consumer, this involves constructing and preparing the tree. | /// Initializes the consumer, this involves constructing and preparing the tree. | ||||
| /// \param d The dataset node that represent the root of the IR tree. | /// \param d The dataset node that represent the root of the IR tree. | ||||
| /// \return Status error code. | /// \return Status error code. | ||||
| virtual Status Init(std::shared_ptr<api::DatasetNode> d); | |||||
| virtual Status Init(std::shared_ptr<DatasetNode> d); | |||||
| Status Terminate(); | Status Terminate(); | ||||
| @@ -61,7 +58,7 @@ class IteratorConsumer : public TreeConsumer { | |||||
| /// \param num_epochs number of epochs. Default to -1 (infinite epochs). | /// \param num_epochs number of epochs. Default to -1 (infinite epochs). | ||||
| explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} | explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} | ||||
| Status Init(std::shared_ptr<api::DatasetNode> d) override; | |||||
| Status Init(std::shared_ptr<DatasetNode> d) override; | |||||
| /// Returns the next row in a vector format | /// Returns the next row in a vector format | ||||
| /// \param[out] out std::vector of Tensors | /// \param[out] out std::vector of Tensors | ||||
| @@ -133,7 +130,7 @@ class ToDevice : public TreeConsumer { | |||||
| explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1) | explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1) | ||||
| : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | ||||
| Status Init(std::shared_ptr<api::DatasetNode> d) override; | |||||
| Status Init(std::shared_ptr<DatasetNode> d) override; | |||||
| /// Send the data to device | /// Send the data to device | ||||
| /// \return Status error code | /// \return Status error code | ||||
| @@ -162,7 +159,7 @@ class ToDevice : public TreeConsumer { | |||||
| class TreeGetters : public TreeConsumer { | class TreeGetters : public TreeConsumer { | ||||
| public: | public: | ||||
| TreeGetters(); | TreeGetters(); | ||||
| Status Init(std::shared_ptr<api::DatasetNode> d) override; | |||||
| Status Init(std::shared_ptr<DatasetNode> d) override; | |||||
| Status GetDatasetSize(int64_t *size); | Status GetDatasetSize(int64_t *size); | ||||
| Status GetOutputTypes(std::vector<DataType> *types); | Status GetOutputTypes(std::vector<DataType> *types); | ||||
| Status GetOutputShapes(std::vector<TensorShape> *shapes); | Status GetOutputShapes(std::vector<TensorShape> *shapes); | ||||
| @@ -185,10 +182,9 @@ class BuildVocabConsumer : public TreeConsumer { | |||||
| /// BuildVocabConsumer Constructor which will call the base class default constructor. | /// BuildVocabConsumer Constructor which will call the base class default constructor. | ||||
| BuildVocabConsumer() = default; | BuildVocabConsumer() = default; | ||||
| Status Init(std::shared_ptr<api::DatasetNode> d) override; | |||||
| Status Init(std::shared_ptr<DatasetNode> d) override; | |||||
| /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows | |||||
| /// would be written to disk) | |||||
| /// Start consuming | |||||
| /// \return Status error code | /// \return Status error code | ||||
| Status Start(); | Status Start(); | ||||
| @@ -46,7 +46,7 @@ Status CacheBase::Reset() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | ||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | |||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | ||||
| row_cnt_(0), | row_cnt_(0), | ||||
| num_cache_miss_(0), | num_cache_miss_(0), | ||||
| @@ -46,7 +46,7 @@ class CacheBase : public ParallelOp { | |||||
| /// \param cache_client CacheClient for communication to the CacheServer | /// \param cache_client CacheClient for communication to the CacheServer | ||||
| /// \param sampler Sampler which is mandatory | /// \param sampler Sampler which is mandatory | ||||
| CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | ||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler); | |||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CacheBase(); | ~CacheBase(); | ||||
| @@ -87,7 +87,7 @@ Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| leaf_op_wp_.Set(); | leaf_op_wp_.Set(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } | |||||
| Status CacheLookupOp::InitSampler() { return SamplerRT::InitSampler(); } | |||||
| void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } | void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } | ||||
| Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | ||||
| std::vector<row_id_type> cache_miss; | std::vector<row_id_type> cache_miss; | ||||
| @@ -28,7 +28,7 @@ namespace dataset { | |||||
| /// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset. | /// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset. | ||||
| /// \note For non-mappable dataset, please see CacheOp | /// \note For non-mappable dataset, please see CacheOp | ||||
| /// \see CacheOp | /// \see CacheOp | ||||
| class CacheLookupOp : public CacheBase, public Sampler { | |||||
| class CacheLookupOp : public CacheBase, public SamplerRT { | |||||
| public: | public: | ||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| @@ -62,7 +62,7 @@ class CacheLookupOp : public CacheBase, public Sampler { | |||||
| /// \brief Setter method. | /// \brief Setter method. | ||||
| /// \return Builder setter method returns reference to the builder. | /// \return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| build_sampler_ = std::move(sampler); | build_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -77,7 +77,7 @@ class CacheLookupOp : public CacheBase, public Sampler { | |||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| int32_t build_op_connector_size_; | int32_t build_op_connector_size_; | ||||
| std::shared_ptr<CacheClient> build_cache_client_; | std::shared_ptr<CacheClient> build_cache_client_; | ||||
| std::shared_ptr<Sampler> build_sampler_; | |||||
| std::shared_ptr<SamplerRT> build_sampler_; | |||||
| // Check if the required parameters are set by the builder. | // Check if the required parameters are set by the builder. | ||||
| // \return Status The error code return | // \return Status The error code return | ||||
| @@ -87,8 +87,8 @@ class CacheLookupOp : public CacheBase, public Sampler { | |||||
| /// \note It takes the same argument as the base class. | /// \note It takes the same argument as the base class. | ||||
| /// \see CacheBase | /// \see CacheBase | ||||
| CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | ||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | |||||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {} | |||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler) | |||||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), SamplerRT(*(sampler.get())) {} | |||||
| ~CacheLookupOp() = default; | ~CacheLookupOp() = default; | ||||
| // As a parallel op, we override these two functions | // As a parallel op, we override these two functions | ||||
| Status operator()() override; | Status operator()() override; | ||||
| @@ -46,7 +46,7 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const { | |||||
| } | } | ||||
| CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | ||||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler) | |||||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler) | |||||
| : ParallelOp(numWorkers, opConnectorSize, sampler), | : ParallelOp(numWorkers, opConnectorSize, sampler), | ||||
| num_cleaners_(numCleaners), | num_cleaners_(numCleaners), | ||||
| cache_client_(std::move(cache_client)), | cache_client_(std::move(cache_client)), | ||||
| @@ -110,7 +110,7 @@ class CacheMergeOp : public ParallelOp { | |||||
| /// \brief Setter method | /// \brief Setter method | ||||
| /// \param sampler | /// \param sampler | ||||
| /// \return Builder setter method returns reference to the builder. | /// \return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| build_sampler_ = std::move(sampler); | build_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -133,7 +133,7 @@ class CacheMergeOp : public ParallelOp { | |||||
| int32_t build_op_connector_size_; | int32_t build_op_connector_size_; | ||||
| int32_t build_num_cleaners_; | int32_t build_num_cleaners_; | ||||
| std::shared_ptr<CacheClient> build_cache_client_; | std::shared_ptr<CacheClient> build_cache_client_; | ||||
| std::shared_ptr<Sampler> build_sampler_; | |||||
| std::shared_ptr<SamplerRT> build_sampler_; | |||||
| /// Check if the required parameters are set by the builder. | /// Check if the required parameters are set by the builder. | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| @@ -147,7 +147,7 @@ class CacheMergeOp : public ParallelOp { | |||||
| /// \param cache_client CacheClient to commmunicate with the Cache server | /// \param cache_client CacheClient to commmunicate with the Cache server | ||||
| /// \param sampler as a derived class of ParallelOp | /// \param sampler as a derived class of ParallelOp | ||||
| CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | ||||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler); | |||||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler); | |||||
| ~CacheMergeOp(); | ~CacheMergeOp(); | ||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| std::string Name() const override { return kCacheMergeOp; } | std::string Name() const override { return kCacheMergeOp; } | ||||
| @@ -68,7 +68,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) { | |||||
| // Constructor of CacheOp | // Constructor of CacheOp | ||||
| CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | ||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | |||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler) | |||||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)), | : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)), | ||||
| num_guys_in_(0), | num_guys_in_(0), | ||||
| phase_(Phase::kBuildPhase) {} | phase_(Phase::kBuildPhase) {} | ||||
| @@ -81,7 +81,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||||
| /// \brief Setter method | /// \brief Setter method | ||||
| /// \param sampler | /// \param sampler | ||||
| /// \return Builder setter method returns reference to the builder. | /// \return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| build_sampler_ = std::move(sampler); | build_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -96,7 +96,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| int32_t build_op_connector_size_; | int32_t build_op_connector_size_; | ||||
| std::shared_ptr<CacheClient> build_cache_client_; | std::shared_ptr<CacheClient> build_cache_client_; | ||||
| std::shared_ptr<Sampler> build_sampler_; | |||||
| std::shared_ptr<SamplerRT> build_sampler_; | |||||
| /// \brief Check if the required parameters are set by the builder. | /// \brief Check if the required parameters are set by the builder. | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| @@ -108,7 +108,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||||
| /// \param num_workers The number of worker threads. | /// \param num_workers The number of worker threads. | ||||
| /// \param op_connector_size The size of each queue in the connector. | /// \param op_connector_size The size of each queue in the connector. | ||||
| CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | ||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler); | |||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler); | |||||
| // Destructor | // Destructor | ||||
| ~CacheOp(); | ~CacheOp(); | ||||
| @@ -36,7 +36,7 @@ ConcatOp::Builder::Builder() { | |||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) { | Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) { | ||||
| if (builder_sampler_ == nullptr) { | if (builder_sampler_ == nullptr) { | ||||
| builder_sampler_ = std::make_shared<DistributedSampler>(0, 1, 0, false); | |||||
| builder_sampler_ = std::make_shared<DistributedSamplerRT>(0, 1, 0, false); | |||||
| } | } | ||||
| *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_, | *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_, | ||||
| children_start_end_index_); | children_start_end_index_); | ||||
| @@ -44,7 +44,7 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) { | |||||
| } | } | ||||
| // Constructor of the ConcatOp. | // Constructor of the ConcatOp. | ||||
| ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler, | |||||
| ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler, | |||||
| std::vector<std::pair<int, int>> children_flag_and_nums, | std::vector<std::pair<int, int>> children_flag_and_nums, | ||||
| std::vector<std::pair<int, int>> children_start_end_index) | std::vector<std::pair<int, int>> children_start_end_index) | ||||
| : PipelineOp(op_connector_size), | : PipelineOp(op_connector_size), | ||||
| @@ -80,7 +80,7 @@ Status ConcatOp::operator()() { | |||||
| bool is_not_mappable = true; | bool is_not_mappable = true; | ||||
| int num_shard = 1; | int num_shard = 1; | ||||
| int shard_index = 0; | int shard_index = 0; | ||||
| std::shared_ptr<DistributedSampler> distribute_sampler = std::dynamic_pointer_cast<DistributedSampler>(sampler_); | |||||
| std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_); | |||||
| if (distribute_sampler != nullptr) { | if (distribute_sampler != nullptr) { | ||||
| num_shard = distribute_sampler->GetDeviceNum(); | num_shard = distribute_sampler->GetDeviceNum(); | ||||
| shard_index = distribute_sampler->GetDeviceID(); | shard_index = distribute_sampler->GetDeviceID(); | ||||
| @@ -44,7 +44,7 @@ class ConcatOp : public PipelineOp { | |||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| // @return shared_ptr to the new ConcatOp object | // @return shared_ptr to the new ConcatOp object | ||||
| Status Build(std::shared_ptr<ConcatOp> *); | Status Build(std::shared_ptr<ConcatOp> *); | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -61,7 +61,7 @@ class ConcatOp : public PipelineOp { | |||||
| private: | private: | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | std::vector<std::pair<int, int>> children_flag_and_nums_; | ||||
| std::vector<std::pair<int, int>> children_start_end_index_; | std::vector<std::pair<int, int>> children_start_end_index_; | ||||
| }; | }; | ||||
| @@ -70,7 +70,7 @@ class ConcatOp : public PipelineOp { | |||||
| // @note The builder class should be used to call it | // @note The builder class should be used to call it | ||||
| // @param op_connector_size - connector size | // @param op_connector_size - connector size | ||||
| explicit ConcatOp(int32_t op_connector_size); | explicit ConcatOp(int32_t op_connector_size); | ||||
| explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler, | |||||
| explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler, | |||||
| std::vector<std::pair<int, int>> children_flag_and_nums, | std::vector<std::pair<int, int>> children_flag_and_nums, | ||||
| std::vector<std::pair<int, int>> children_start_end_index); | std::vector<std::pair<int, int>> children_start_end_index); | ||||
| @@ -123,7 +123,7 @@ class ConcatOp : public PipelineOp { | |||||
| std::unordered_map<std::string, int32_t> column_name_id_; // Mapping between col index and col name | std::unordered_map<std::string, int32_t> column_name_id_; // Mapping between col index and col name | ||||
| std::vector<DataType> data_type_; | std::vector<DataType> data_type_; | ||||
| std::vector<dsize_t> data_rank_; | std::vector<dsize_t> data_rank_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| std::shared_ptr<SamplerRT> sampler_; | |||||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | std::vector<std::pair<int, int>> children_flag_and_nums_; | ||||
| std::vector<std::pair<int, int>> children_start_end_index_; | std::vector<std::pair<int, int>> children_start_end_index_; | ||||
| }; | }; | ||||
| @@ -40,7 +40,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | // Constructor | ||||
| DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler) | |||||
| DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler) | |||||
| : oc_queue_size_(op_connector_size), | : oc_queue_size_(op_connector_size), | ||||
| sampler_(sampler), | sampler_(sampler), | ||||
| operator_id_(kInvalidOperatorId), | operator_id_(kInvalidOperatorId), | ||||
| @@ -409,7 +409,7 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) { | |||||
| } | } | ||||
| // Getter for the sampler, and it also removes the sampler from the op | // Getter for the sampler, and it also removes the sampler from the op | ||||
| Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) { | |||||
| Status DatasetOp::FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler) { | |||||
| *sampler = sampler_; // It's okay if it sampler_ points to nullptr | *sampler = sampler_; // It's okay if it sampler_ points to nullptr | ||||
| sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler | sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -62,7 +62,7 @@ class DataBuffer; | |||||
| class NodePass; | class NodePass; | ||||
| class Sampler; | |||||
| class SamplerRT; | |||||
| /// \brief The base class DatasetOp is the main tree node. It is an abstract class, so | /// \brief The base class DatasetOp is the main tree node. It is an abstract class, so | ||||
| /// the actual implementation of the operators will be derived from here. | /// the actual implementation of the operators will be derived from here. | ||||
| @@ -80,7 +80,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// Constructor | /// Constructor | ||||
| /// \param op_connector_size - The size for the output connector of this operator. | /// \param op_connector_size - The size for the output connector of this operator. | ||||
| /// \param sampler - The sampler for the op | /// \param sampler - The sampler for the op | ||||
| explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler); | |||||
| explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler); | |||||
| /// Destructor | /// Destructor | ||||
| virtual ~DatasetOp() { tree_ = nullptr; } | virtual ~DatasetOp() { tree_ = nullptr; } | ||||
| @@ -347,12 +347,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// Getter for the sampler | /// Getter for the sampler | ||||
| /// \return Shared pointer to the sampler (may return nullptr) | /// \return Shared pointer to the sampler (may return nullptr) | ||||
| std::shared_ptr<Sampler> sampler() { return sampler_; } | |||||
| std::shared_ptr<SamplerRT> sampler() { return sampler_; } | |||||
| /// \brief Getter for the sampler, and it also removes the sampler from the op | /// \brief Getter for the sampler, and it also removes the sampler from the op | ||||
| /// \param[out] sampler A pointer to the output sampler that was removed | /// \param[out] sampler A pointer to the output sampler that was removed | ||||
| /// \return Status error code | /// \return Status error code | ||||
| Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler); | |||||
| Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler); | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // Computes a CRC value for the operator | // Computes a CRC value for the operator | ||||
| @@ -368,7 +368,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| } | } | ||||
| /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. | /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. | ||||
| void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; } | |||||
| void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; } | |||||
| /// \brief Checks if this is a leaf node (0 children) | /// \brief Checks if this is a leaf node (0 children) | ||||
| /// \return boolean returns true if it's a leaf | /// \return boolean returns true if it's a leaf | ||||
| @@ -409,7 +409,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes | std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes | ||||
| std::vector<DatasetOp *> parent_; // Parent nodes. No ownership | std::vector<DatasetOp *> parent_; // Parent nodes. No ownership | ||||
| std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler | |||||
| std::shared_ptr<SamplerRT> sampler_; // Some leaf ops might have a sampler | |||||
| int32_t oc_queue_size_; // Capacity for each out_connector_ | int32_t oc_queue_size_; // Capacity for each out_connector_ | ||||
| int32_t operator_id_; // Generated id for the node | int32_t operator_id_; // Generated id for the node | ||||
| ExecutionTree *tree_; // Back pointer to our tree. | ExecutionTree *tree_; // Back pointer to our tree. | ||||
| @@ -26,7 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | // Constructor | ||||
| ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler) | |||||
| ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler) | |||||
| : DatasetOp(op_connector_size, sampler), | : DatasetOp(op_connector_size, sampler), | ||||
| num_workers_(num_workers), | num_workers_(num_workers), | ||||
| num_producers_(num_workers), | num_producers_(num_workers), | ||||
| @@ -41,7 +41,7 @@ class ParallelOp : public DatasetOp { | |||||
| // @param num_workers | // @param num_workers | ||||
| // @param op_connector_size - size of the output connector for this operator | // @param op_connector_size - size of the output connector for this operator | ||||
| // @param sampler - The sampler for the op | // @param sampler - The sampler for the op | ||||
| ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr); | |||||
| ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr); | |||||
| // Destructor | // Destructor | ||||
| ~ParallelOp() = default; | ~ParallelOp() = default; | ||||
| @@ -20,7 +20,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | // Constructor | ||||
| PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler) | |||||
| PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler) | |||||
| : DatasetOp(op_connector_size, sampler) {} | : DatasetOp(op_connector_size, sampler) {} | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -34,7 +34,7 @@ class PipelineOp : public DatasetOp { | |||||
| // @param op_connector_size - size of the output connector | // @param op_connector_size - size of the output connector | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| // @param sampler - The sampler for the op | // @param sampler - The sampler for the op | ||||
| explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr); | |||||
| explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr); | |||||
| // Destructor | // Destructor | ||||
| ~PipelineOp() = default; | ~PipelineOp() = default; | ||||
| @@ -42,7 +42,7 @@ Status AlbumOp::Builder::Build(std::shared_ptr<AlbumOp> *ptr) { | |||||
| if (builder_sampler_ == nullptr) { | if (builder_sampler_ == nullptr) { | ||||
| const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data | const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||||
| } | } | ||||
| builder_schema_ = std::make_unique<DataSchema>(); | builder_schema_ = std::make_unique<DataSchema>(); | ||||
| @@ -73,7 +73,7 @@ Status AlbumOp::Builder::SanityCheck() { | |||||
| AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, | AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, | ||||
| const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, | const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler) | |||||
| std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_wkrs, queue_size, std::move(sampler)), | : ParallelOp(num_wkrs, queue_size, std::move(sampler)), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| folder_path_(file_dir), | folder_path_(file_dir), | ||||
| @@ -100,7 +100,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||||
| /// \brief Setter method | /// \brief Setter method | ||||
| /// \param[in] sampler | /// \param[in] sampler | ||||
| /// \return Builder setter method returns reference to the builder | /// \return Builder setter method returns reference to the builder | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -147,7 +147,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||||
| int32_t builder_rows_per_buffer_; | int32_t builder_rows_per_buffer_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| std::set<std::string> builder_extensions_; | std::set<std::string> builder_extensions_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| std::unique_ptr<DataSchema> builder_schema_; | std::unique_ptr<DataSchema> builder_schema_; | ||||
| }; | }; | ||||
| @@ -161,7 +161,8 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||||
| /// \param[in] data_schema - schema of dataset | /// \param[in] data_schema - schema of dataset | ||||
| /// \param[in] sampler - sampler tells AlbumOp what to read | /// \param[in] sampler - sampler tells AlbumOp what to read | ||||
| AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, | AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, | ||||
| const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||||
| const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, | |||||
| std::shared_ptr<SamplerRT> sampler); | |||||
| /// \brief Destructor. | /// \brief Destructor. | ||||
| ~AlbumOp() = default; | ~AlbumOp() = default; | ||||
| @@ -46,7 +46,7 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) { | |||||
| if (builder_sampler_ == nullptr) { | if (builder_sampler_ == nullptr) { | ||||
| const int64_t num_samples = 0; | const int64_t num_samples = 0; | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||||
| } | } | ||||
| builder_schema_ = std::make_unique<DataSchema>(); | builder_schema_ = std::make_unique<DataSchema>(); | ||||
| @@ -79,7 +79,7 @@ Status CelebAOp::Builder::SanityCheck() { | |||||
| CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, | CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, | ||||
| bool decode, const std::string &usage, const std::set<std::string> &exts, | bool decode, const std::string &usage, const std::set<std::string> &exts, | ||||
| std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler) | |||||
| std::unique_ptr<DataSchema> schema, std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | : ParallelOp(num_workers, queue_size, std::move(sampler)), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| folder_path_(dir), | folder_path_(dir), | ||||
| @@ -95,7 +95,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -131,7 +131,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| int32_t builder_rows_per_buffer_; | int32_t builder_rows_per_buffer_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| std::set<std::string> builder_extensions_; | std::set<std::string> builder_extensions_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| std::unique_ptr<DataSchema> builder_schema_; | std::unique_ptr<DataSchema> builder_schema_; | ||||
| std::string builder_usage_; | std::string builder_usage_; | ||||
| }; | }; | ||||
| @@ -144,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read | // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read | ||||
| CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, | CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, | ||||
| const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema, | const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema, | ||||
| std::shared_ptr<Sampler> sampler); | |||||
| std::shared_ptr<SamplerRT> sampler); | |||||
| ~CelebAOp() override = default; | ~CelebAOp() override = default; | ||||
| @@ -50,7 +50,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) { | |||||
| if (sampler_ == nullptr) { | if (sampler_ == nullptr) { | ||||
| const int64_t num_samples = 0; | const int64_t num_samples = 0; | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||||
| sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||||
| } | } | ||||
| schema_ = std::make_unique<DataSchema>(); | schema_ = std::make_unique<DataSchema>(); | ||||
| TensorShape scalar = TensorShape::CreateScalar(); | TensorShape scalar = TensorShape::CreateScalar(); | ||||
| @@ -88,7 +88,7 @@ Status CifarOp::Builder::SanityCheck() { | |||||
| CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, | CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, | ||||
| const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, | const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler) | |||||
| std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_works, queue_size, std::move(sampler)), | : ParallelOp(num_works, queue_size, std::move(sampler)), | ||||
| cifar_type_(type), | cifar_type_(type), | ||||
| usage_(usage), | usage_(usage), | ||||
| @@ -75,7 +75,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| sampler_ = std::move(sampler); | sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -123,7 +123,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| int32_t num_workers_; | int32_t num_workers_; | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| int32_t op_connect_size_; | int32_t op_connect_size_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| std::shared_ptr<SamplerRT> sampler_; | |||||
| std::unique_ptr<DataSchema> schema_; | std::unique_ptr<DataSchema> schema_; | ||||
| CifarType cifar_type_; | CifarType cifar_type_; | ||||
| }; | }; | ||||
| @@ -138,7 +138,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | ||||
| CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, | CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, | ||||
| const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, | const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler); | |||||
| std::shared_ptr<SamplerRT> sampler); | |||||
| // Destructor. | // Destructor. | ||||
| ~CifarOp() = default; | ~CifarOp() = default; | ||||
| @@ -94,7 +94,7 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim | |||||
| ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ||||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ||||
| bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler) | |||||
| bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| num_rows_per_shard_(0), | num_rows_per_shard_(0), | ||||
| @@ -125,7 +125,7 @@ class ClueOp : public ParallelOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -141,13 +141,13 @@ class ClueOp : public ParallelOp { | |||||
| std::vector<std::string> builder_clue_files_list_; | std::vector<std::string> builder_clue_files_list_; | ||||
| bool builder_shuffle_files_; | bool builder_shuffle_files_; | ||||
| std::map<std::string, std::string> builder_cols_to_keyword_; | std::map<std::string, std::string> builder_cols_to_keyword_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| }; | }; | ||||
| // Constructor of ClueOp | // Constructor of ClueOp | ||||
| ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ||||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ||||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler); | |||||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler); | |||||
| // Default destructor | // Default destructor | ||||
| ~ClueOp() = default; | ~ClueOp() = default; | ||||
| @@ -60,7 +60,7 @@ Status CocoOp::Builder::Build(std::shared_ptr<CocoOp> *ptr) { | |||||
| if (builder_sampler_ == nullptr) { | if (builder_sampler_ == nullptr) { | ||||
| const int64_t num_samples = 0; | const int64_t num_samples = 0; | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||||
| } | } | ||||
| builder_schema_ = std::make_unique<DataSchema>(); | builder_schema_ = std::make_unique<DataSchema>(); | ||||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | RETURN_IF_NOT_OK(builder_schema_->AddColumn( | ||||
| @@ -123,7 +123,7 @@ Status CocoOp::Builder::SanityCheck() { | |||||
| CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, | CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, | ||||
| int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, | int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, | ||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | : ParallelOp(num_workers, queue_size, std::move(sampler)), | ||||
| decode_(decode), | decode_(decode), | ||||
| row_cnt_(0), | row_cnt_(0), | ||||
| @@ -119,7 +119,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| // Setter method. | // Setter method. | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -149,7 +149,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| int32_t builder_rows_per_buffer_; | int32_t builder_rows_per_buffer_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| std::unique_ptr<DataSchema> builder_schema_; | std::unique_ptr<DataSchema> builder_schema_; | ||||
| }; | }; | ||||
| @@ -166,7 +166,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read | // @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read | ||||
| CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, | CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, | ||||
| int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, | int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, | ||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||||
| // Destructor | // Destructor | ||||
| ~CocoOp() = default; | ~CocoOp() = default; | ||||
| @@ -77,7 +77,7 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | |||||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, | const std::vector<std::shared_ptr<BaseRecord>> &column_default, | ||||
| const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | ||||
| int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | ||||
| int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler) | |||||
| int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | ||||
| csv_files_list_(std::move(csv_files_list)), | csv_files_list_(std::move(csv_files_list)), | ||||
| field_delim_(field_delim), | field_delim_(field_delim), | ||||
| @@ -243,7 +243,7 @@ class CsvOp : public ParallelOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -261,7 +261,7 @@ class CsvOp : public ParallelOp { | |||||
| char builder_field_delim_; | char builder_field_delim_; | ||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_; | std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_; | ||||
| std::vector<std::string> builder_column_name_list_; | std::vector<std::string> builder_column_name_list_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| }; | }; | ||||
| // Constructor of CsvOp | // Constructor of CsvOp | ||||
| @@ -271,7 +271,7 @@ class CsvOp : public ParallelOp { | |||||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name, | const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name, | ||||
| int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, | int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, | ||||
| std::shared_ptr<Sampler> sampler); | |||||
| std::shared_ptr<SamplerRT> sampler); | |||||
| // Default destructor | // Default destructor | ||||
| ~CsvOp() = default; | ~CsvOp() = default; | ||||
| @@ -38,7 +38,7 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) { | |||||
| if (builder_sampler_ == nullptr) { | if (builder_sampler_ == nullptr) { | ||||
| const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data | const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||||
| } | } | ||||
| builder_schema_ = std::make_unique<DataSchema>(); | builder_schema_ = std::make_unique<DataSchema>(); | ||||
| TensorShape scalar = TensorShape::CreateScalar(); | TensorShape scalar = TensorShape::CreateScalar(); | ||||
| @@ -68,7 +68,7 @@ Status ImageFolderOp::Builder::SanityCheck() { | |||||
| ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, | ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, | ||||
| bool recursive, bool do_decode, const std::set<std::string> &exts, | bool recursive, bool do_decode, const std::set<std::string> &exts, | ||||
| const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema, | const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler) | |||||
| std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_wkrs, queue_size, std::move(sampler)), | : ParallelOp(num_wkrs, queue_size, std::move(sampler)), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| folder_path_(file_dir), | folder_path_(file_dir), | ||||
| @@ -113,7 +113,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -151,7 +151,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| int32_t builder_rows_per_buffer_; | int32_t builder_rows_per_buffer_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| std::set<std::string> builder_extensions_; | std::set<std::string> builder_extensions_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| std::unique_ptr<DataSchema> builder_schema_; | std::unique_ptr<DataSchema> builder_schema_; | ||||
| std::map<std::string, int32_t> builder_labels_to_read_; | std::map<std::string, int32_t> builder_labels_to_read_; | ||||
| }; | }; | ||||
| @@ -165,7 +165,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | ||||
| ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, | ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, | ||||
| bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map, | bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map, | ||||
| std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler); | |||||
| std::unique_ptr<DataSchema>, std::shared_ptr<SamplerRT> sampler); | |||||
| // Destructor. | // Destructor. | ||||
| ~ImageFolderOp() = default; | ~ImageFolderOp() = default; | ||||
| @@ -43,7 +43,7 @@ Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) { | |||||
| if (builder_sampler_ == nullptr) { | if (builder_sampler_ == nullptr) { | ||||
| const int64_t num_samples = 0; | const int64_t num_samples = 0; | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||||
| } | } | ||||
| builder_schema_ = std::make_unique<DataSchema>(); | builder_schema_ = std::make_unique<DataSchema>(); | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -67,7 +67,7 @@ Status ManifestOp::Builder::SanityCheck() { | |||||
| ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | ||||
| const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler, std::string usage) | |||||
| std::shared_ptr<SamplerRT> sampler, std::string usage) | |||||
| : ParallelOp(num_works, queue_size, std::move(sampler)), | : ParallelOp(num_works, queue_size, std::move(sampler)), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| io_block_pushed_(0), | io_block_pushed_(0), | ||||
| @@ -88,7 +88,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -119,7 +119,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||||
| Status Build(std::shared_ptr<ManifestOp> *op); | Status Build(std::shared_ptr<ManifestOp> *op); | ||||
| private: | private: | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| bool builder_decode_; | bool builder_decode_; | ||||
| std::string builder_file_; | std::string builder_file_; | ||||
| @@ -139,7 +139,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | ||||
| ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | ||||
| const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler, std::string usage); | |||||
| std::shared_ptr<SamplerRT> sampler, std::string usage); | |||||
| // Destructor. | // Destructor. | ||||
| ~ManifestOp() = default; | ~ManifestOp() = default; | ||||
| @@ -45,7 +45,7 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) { | |||||
| if (builder_sampler_ == nullptr) { | if (builder_sampler_ == nullptr) { | ||||
| const int64_t num_samples = 0; | const int64_t num_samples = 0; | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||||
| } | } | ||||
| builder_schema_ = std::make_unique<DataSchema>(); | builder_schema_ = std::make_unique<DataSchema>(); | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| @@ -75,7 +75,7 @@ Status MnistOp::Builder::SanityCheck() { | |||||
| } | } | ||||
| MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, | MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, | ||||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | : ParallelOp(num_workers, queue_size, std::move(sampler)), | ||||
| usage_(usage), | usage_(usage), | ||||
| buf_cnt_(0), | buf_cnt_(0), | ||||
| @@ -78,7 +78,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -113,7 +113,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| int32_t builder_rows_per_buffer_; | int32_t builder_rows_per_buffer_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| std::unique_ptr<DataSchema> builder_schema_; | std::unique_ptr<DataSchema> builder_schema_; | ||||
| }; | }; | ||||
| @@ -126,7 +126,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset | // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset | ||||
| // @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read | // @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read | ||||
| MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, | MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, | ||||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||||
| // Destructor. | // Destructor. | ||||
| ~MnistOp() = default; | ~MnistOp() = default; | ||||
| @@ -65,7 +65,7 @@ Status RandomDataOp::Builder::SanityCheck() const { | |||||
| // Constructor for RandomDataOp | // Constructor for RandomDataOp | ||||
| RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | ||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | ||||
| buffer_id_(0), | buffer_id_(0), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| @@ -120,7 +120,7 @@ class RandomDataOp : public ParallelOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -133,7 +133,7 @@ class RandomDataOp : public ParallelOp { | |||||
| Status SanityCheck() const; | Status SanityCheck() const; | ||||
| std::unique_ptr<DataSchema> builder_data_schema_; | std::unique_ptr<DataSchema> builder_data_schema_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| int64_t builder_rows_per_buffer_; | int64_t builder_rows_per_buffer_; | ||||
| @@ -152,7 +152,7 @@ class RandomDataOp : public ParallelOp { | |||||
| * @return Builder - The modified builder by reference | * @return Builder - The modified builder by reference | ||||
| */ | */ | ||||
| RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | ||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||||
| /** | /** | ||||
| * Destructor | * Destructor | ||||
| @@ -23,9 +23,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||||
| uint32_t seed, int64_t offset, bool even_dist) | |||||
| : Sampler(num_samples, std::numeric_limits<int64_t>::max()), | |||||
| DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||||
| uint32_t seed, int64_t offset, bool even_dist) | |||||
| : SamplerRT(num_samples, std::numeric_limits<int64_t>::max()), | |||||
| cnt_(0), | cnt_(0), | ||||
| seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed), | seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed), | ||||
| device_id_(dev_id), | device_id_(dev_id), | ||||
| @@ -35,7 +35,7 @@ DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int | |||||
| offset_(offset), | offset_(offset), | ||||
| non_empty_(true) {} | non_empty_(true) {} | ||||
| Status DistributedSampler::InitSampler() { | |||||
| Status DistributedSamplerRT::InitSampler() { | |||||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | ||||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | ||||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | if (num_samples_ == 0 || num_samples_ > num_rows_) { | ||||
| @@ -74,7 +74,7 @@ Status DistributedSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status DistributedSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (cnt_ > samples_per_buffer_) { | if (cnt_ > samples_per_buffer_) { | ||||
| RETURN_STATUS_UNEXPECTED( | RETURN_STATUS_UNEXPECTED( | ||||
| "Number of samples(cnt) that have already been filled in to buffer should be less than or " | "Number of samples(cnt) that have already been filled in to buffer should be less than or " | ||||
| @@ -143,7 +143,7 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DistributedSampler::ResetSampler() { | |||||
| Status DistributedSamplerRT::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); | CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); | ||||
| cnt_ = 0; | cnt_ = 0; | ||||
| @@ -160,10 +160,10 @@ Status DistributedSampler::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void DistributedSampler::Print(std::ostream &out, bool show_all) const { | |||||
| void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: DistributedSampler"; | out << "\nSampler: DistributedSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| Sampler::Print(out, show_all); | |||||
| SamplerRT::Print(out, show_all); | |||||
| out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ | out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ | ||||
| << "\nshuffle: " << shuffle_; | << "\nshuffle: " << shuffle_; | ||||
| } | } | ||||
| @@ -25,7 +25,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class DistributedSampler : public Sampler { | |||||
| class DistributedSamplerRT : public SamplerRT { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param[in] num_samples The total number of rows in the dataset | /// \param[in] num_samples The total number of rows in the dataset | ||||
| @@ -40,11 +40,12 @@ class DistributedSampler : public Sampler { | |||||
| /// This option is not exposed in the python API. Current behavior is that the remainder will always | /// This option is not exposed in the python API. Current behavior is that the remainder will always | ||||
| /// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set, | /// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set, | ||||
| /// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario. | /// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario. | ||||
| DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||||
| uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true); | |||||
| DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||||
| uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, | |||||
| bool even_dist = true); | |||||
| /// \brief default destructor | /// \brief default destructor | ||||
| ~DistributedSampler() = default; | |||||
| ~DistributedSamplerRT() = default; | |||||
| /// \param std::unique_ptr<DataBuffer> * pBuffer | /// \param std::unique_ptr<DataBuffer> * pBuffer | ||||
| /// \param int32_t workerId | /// \param int32_t workerId | ||||
| @@ -20,14 +20,14 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) | |||||
| : Sampler(num_samples, samples_per_buffer), | |||||
| PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) | |||||
| : SamplerRT(num_samples, samples_per_buffer), | |||||
| shuffle_(shuffle), | shuffle_(shuffle), | ||||
| seed_(GetSeed()), | seed_(GetSeed()), | ||||
| next_id_(0), | next_id_(0), | ||||
| samples_per_class_(val) {} | samples_per_class_(val) {} | ||||
| Status PKSampler::InitSampler() { | |||||
| Status PKSamplerRT::InitSampler() { | |||||
| labels_.reserve(label_to_ids_.size()); | labels_.reserve(label_to_ids_.size()); | ||||
| for (const auto &pair : label_to_ids_) { | for (const auto &pair : label_to_ids_) { | ||||
| if (pair.second.empty() == false) { | if (pair.second.empty() == false) { | ||||
| @@ -61,7 +61,7 @@ Status PKSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status PKSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (next_id_ > num_samples_ || num_samples_ == 0) { | if (next_id_ > num_samples_ || num_samples_ == 0) { | ||||
| RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_)); | RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_)); | ||||
| } else if (next_id_ == num_samples_) { | } else if (next_id_ == num_samples_) { | ||||
| @@ -96,7 +96,7 @@ Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PKSampler::ResetSampler() { | |||||
| Status PKSamplerRT::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | ||||
| next_id_ = 0; | next_id_ = 0; | ||||
| rnd_.seed(seed_++); | rnd_.seed(seed_++); | ||||
| @@ -108,18 +108,18 @@ Status PKSampler::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| RETURN_UNEXPECTED_IF_NULL(op); | RETURN_UNEXPECTED_IF_NULL(op); | ||||
| RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); | RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); | ||||
| RETURN_IF_NOT_OK(InitSampler()); | RETURN_IF_NOT_OK(InitSampler()); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void PKSampler::Print(std::ostream &out, bool show_all) const { | |||||
| void PKSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: PKSampler"; | out << "\nSampler: PKSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| Sampler::Print(out, show_all); | |||||
| SamplerRT::Print(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -26,17 +26,17 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class PKSampler : public Sampler { // NOT YET FINISHED | |||||
| class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||||
| public: | public: | ||||
| // @param num_samples - the number of samples to draw. value of 0 means to take the full amount | // @param num_samples - the number of samples to draw. value of 0 means to take the full amount | ||||
| // @param int64_t val | // @param int64_t val | ||||
| // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 | // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 | ||||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| explicit PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // default destructor | // default destructor | ||||
| ~PKSampler() = default; | |||||
| ~PKSamplerRT() = default; | |||||
| // @param std::unique_ptr<DataBuffer pBuffer | // @param std::unique_ptr<DataBuffer pBuffer | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| @@ -20,10 +20,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | |||||
| : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||||
| PythonSamplerRT::PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | |||||
| : SamplerRT(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||||
| Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (need_to_reset_) { | if (need_to_reset_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| @@ -64,7 +64,7 @@ Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PythonSampler::InitSampler() { | |||||
| Status PythonSamplerRT::InitSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | CHECK_FAIL_RETURN_UNEXPECTED( | ||||
| num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); | num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); | ||||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | ||||
| @@ -86,7 +86,7 @@ Status PythonSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PythonSampler::ResetSampler() { | |||||
| Status PythonSamplerRT::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); | CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); | ||||
| need_to_reset_ = false; | need_to_reset_ = false; | ||||
| py::gil_scoped_acquire gil_acquire; | py::gil_scoped_acquire gil_acquire; | ||||
| @@ -106,11 +106,11 @@ Status PythonSampler::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void PythonSampler::Print(std::ostream &out, bool show_all) const { | |||||
| void PythonSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: PythonSampler"; | out << "\nSampler: PythonSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| Sampler::Print(out, show_all); | |||||
| SamplerRT::Print(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,18 +23,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class PythonSampler : public Sampler { | |||||
| class PythonSamplerRT : public SamplerRT { | |||||
| public: | public: | ||||
| // Constructor | // Constructor | ||||
| // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the | // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the | ||||
| // data from the dataset. | // data from the dataset. | ||||
| // @param py_sampler_instance - the python instance of the sampler | // @param py_sampler_instance - the python instance of the sampler | ||||
| // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| explicit PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // Destructor. | // Destructor. | ||||
| ~PythonSampler() = default; | |||||
| ~PythonSamplerRT() = default; | |||||
| // Initialize the sampler. | // Initialize the sampler. | ||||
| // @return Status | // @return Status | ||||
| @@ -22,16 +22,16 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||||
| int64_t samples_per_buffer) | |||||
| : Sampler(num_samples, samples_per_buffer), | |||||
| RandomSamplerRT::RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||||
| int64_t samples_per_buffer) | |||||
| : SamplerRT(num_samples, samples_per_buffer), | |||||
| seed_(GetSeed()), | seed_(GetSeed()), | ||||
| replacement_(replacement), | replacement_(replacement), | ||||
| next_id_(0), | next_id_(0), | ||||
| reshuffle_each_epoch_(reshuffle_each_epoch), | reshuffle_each_epoch_(reshuffle_each_epoch), | ||||
| dist(nullptr) {} | dist(nullptr) {} | ||||
| Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (next_id_ > num_samples_) { | if (next_id_ > num_samples_) { | ||||
| RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); | RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); | ||||
| } else if (next_id_ == num_samples_) { | } else if (next_id_ == num_samples_) { | ||||
| @@ -68,7 +68,7 @@ Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status RandomSampler::InitSampler() { | |||||
| Status RandomSamplerRT::InitSampler() { | |||||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | ||||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | ||||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | if (num_samples_ == 0 || num_samples_ > num_rows_) { | ||||
| @@ -94,7 +94,7 @@ Status RandomSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status RandomSampler::ResetSampler() { | |||||
| Status RandomSamplerRT::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | ||||
| next_id_ = 0; | next_id_ = 0; | ||||
| @@ -115,11 +115,11 @@ Status RandomSampler::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void RandomSampler::Print(std::ostream &out, bool show_all) const { | |||||
| void RandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: RandomSampler"; | out << "\nSampler: RandomSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| Sampler::Print(out, show_all); | |||||
| SamplerRT::Print(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,18 +24,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class RandomSampler : public Sampler { | |||||
| class RandomSamplerRT : public SamplerRT { | |||||
| public: | public: | ||||
| // Constructor | // Constructor | ||||
| // @param int64_t num_samples - number samples to draw | // @param int64_t num_samples - number samples to draw | ||||
| // @param bool replacement - put he id back / or not after a sample | // @param bool replacement - put he id back / or not after a sample | ||||
| // @param reshuffle_each_epoch - T/F to reshuffle after epoch | // @param reshuffle_each_epoch - T/F to reshuffle after epoch | ||||
| // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| explicit RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // Destructor. | // Destructor. | ||||
| ~RandomSampler() = default; | |||||
| ~RandomSamplerRT() = default; | |||||
| // Op calls this to get next Buffer that contains all the sampleIds | // Op calls this to get next Buffer that contains all the sampleIds | ||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| @@ -32,13 +32,13 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) | |||||
| SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) | |||||
| : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} | : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} | ||||
| Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| std::shared_ptr<Sampler> child_sampler; | |||||
| Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| std::shared_ptr<SamplerRT> child_sampler; | |||||
| if (HasChildSampler()) { | if (HasChildSampler()) { | ||||
| child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]); | |||||
| child_sampler = std::dynamic_pointer_cast<SamplerRT>(child_[0]); | |||||
| if (!child_sampler) { | if (!child_sampler) { | ||||
| std::string err_msg("Cannot handshake, child is not a sampler object."); | std::string err_msg("Cannot handshake, child is not a sampler object."); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| @@ -64,7 +64,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) { | |||||
| Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) { | |||||
| if (num_elements == 0) { | if (num_elements == 0) { | ||||
| RETURN_STATUS_UNEXPECTED("Invalid data, num of elements cannot be 0."); | RETURN_STATUS_UNEXPECTED("Invalid data, num of elements cannot be 0."); | ||||
| } | } | ||||
| @@ -77,7 +77,7 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void Sampler::Print(std::ostream &out, bool show_all) const { | |||||
| void SamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| // Sampler printing is usually only called in the show_all mode. | // Sampler printing is usually only called in the show_all mode. | ||||
| // Derived classes will display the name, then call back to this base | // Derived classes will display the name, then call back to this base | ||||
| // for common info. | // for common info. | ||||
| @@ -88,7 +88,7 @@ void Sampler::Print(std::ostream &out, bool show_all) const { | |||||
| } | } | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| Status Sampler::GetAllIdsThenReset(py::array *data) { | |||||
| Status SamplerRT::GetAllIdsThenReset(py::array *data) { | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| std::shared_ptr<Tensor> sample_ids; | std::shared_ptr<Tensor> sample_ids; | ||||
| TensorRow sample_row; | TensorRow sample_row; | ||||
| @@ -123,27 +123,27 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| Status Sampler::SetNumSamples(int64_t num_samples) { | |||||
| Status SamplerRT::SetNumSamples(int64_t num_samples) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0."); | CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0."); | ||||
| num_samples_ = num_samples; | num_samples_ = num_samples; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| int64_t Sampler::GetNumSamples() { return num_samples_; } | |||||
| int64_t SamplerRT::GetNumSamples() { return num_samples_; } | |||||
| Status Sampler::SetNumRowsInDataset(int64_t num_rows) { | |||||
| Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0."); | CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0."); | ||||
| num_rows_ = num_rows; | num_rows_ = num_rows; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Sampler::AddChild(std::shared_ptr<Sampler> child) { | |||||
| Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) { | |||||
| if (child == nullptr) { | if (child == nullptr) { | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Only samplers can be added, not any other DatasetOp. | // Only samplers can be added, not any other DatasetOp. | ||||
| std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child); | |||||
| std::shared_ptr<SamplerRT> sampler = std::dynamic_pointer_cast<SamplerRT>(child); | |||||
| if (!sampler) { | if (!sampler) { | ||||
| std::string err_msg("Cannot add child, child is not a sampler object."); | std::string err_msg("Cannot add child, child is not a sampler object."); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| @@ -160,9 +160,9 @@ Status Sampler::AddChild(std::shared_ptr<Sampler> child) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| bool Sampler::HasChildSampler() { return !child_.empty(); } | |||||
| bool SamplerRT::HasChildSampler() { return !child_.empty(); } | |||||
| Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | |||||
| Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | |||||
| if (child_ids_ == nullptr) { | if (child_ids_ == nullptr) { | ||||
| RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); | RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); | ||||
| } | } | ||||
| @@ -51,21 +51,21 @@ class RandomAccessOp { | |||||
| protected: | protected: | ||||
| // The amount of rows in the dataset itself. This is the before-sampling value, the | // The amount of rows in the dataset itself. This is the before-sampling value, the | ||||
| // total count of rows. A sampler may choose to sample less than this amount. | // total count of rows. A sampler may choose to sample less than this amount. | ||||
| int64_t num_rows_; | |||||
| int64_t num_rows_ = -1; | |||||
| }; | }; | ||||
| class Sampler { | |||||
| class SamplerRT { | |||||
| public: | public: | ||||
| // Constructor | // Constructor | ||||
| // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | ||||
| // indicates that the sampler should produce the complete set of ids. | // indicates that the sampler should produce the complete set of ids. | ||||
| // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); | |||||
| explicit SamplerRT(int64_t num_samples, int64_t samples_per_buffer); | |||||
| Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {} | |||||
| SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {} | |||||
| // default destructor | // default destructor | ||||
| ~Sampler() = default; | |||||
| ~SamplerRT() = default; | |||||
| // Get a list of sample ids. | // Get a list of sample ids. | ||||
| // @note It is Sampler responsibility to make sure that the id is not out of bound. | // @note It is Sampler responsibility to make sure that the id is not out of bound. | ||||
| @@ -111,7 +111,7 @@ class Sampler { | |||||
| // Adds a sampler to become our child. | // Adds a sampler to become our child. | ||||
| // @param std::shared_ptr<DatasetOp> - The sampler to add as a child. | // @param std::shared_ptr<DatasetOp> - The sampler to add as a child. | ||||
| // @return - The error code returned. | // @return - The error code returned. | ||||
| Status AddChild(std::shared_ptr<Sampler> child); | |||||
| Status AddChild(std::shared_ptr<SamplerRT> child); | |||||
| // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | ||||
| // @param std::shared_ptr<Tensor>* sampleIds | // @param std::shared_ptr<Tensor>* sampleIds | ||||
| @@ -129,7 +129,7 @@ class Sampler { | |||||
| // @param out - reference to the output stream being overloaded | // @param out - reference to the output stream being overloaded | ||||
| // @param sampler - reference to teh sampler to print | // @param sampler - reference to teh sampler to print | ||||
| // @return - the output stream must be returned | // @return - the output stream must be returned | ||||
| friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { | |||||
| friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) { | |||||
| sampler.Print(out, false); | sampler.Print(out, false); | ||||
| return out; | return out; | ||||
| } | } | ||||
| @@ -158,7 +158,7 @@ class Sampler { | |||||
| int64_t samples_per_buffer_; | int64_t samples_per_buffer_; | ||||
| std::unique_ptr<ColDescriptor> col_desc_; | std::unique_ptr<ColDescriptor> col_desc_; | ||||
| std::vector<std::shared_ptr<Sampler>> child_; // Child nodes | |||||
| std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes | |||||
| std::unique_ptr<DataBuffer> child_ids_; | std::unique_ptr<DataBuffer> child_ids_; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -20,10 +20,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | |||||
| : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | |||||
| SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | |||||
| : SamplerRT(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | |||||
| Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (id_count_ > num_samples_) { | if (id_count_ > num_samples_) { | ||||
| RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | ||||
| } else if (id_count_ == num_samples_) { | } else if (id_count_ == num_samples_) { | ||||
| @@ -62,7 +62,7 @@ Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status SequentialSampler::InitSampler() { | |||||
| Status SequentialSamplerRT::InitSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, | CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, | ||||
| "Invalid parameter, start_index must be greater than or equal to 0, but got " + | "Invalid parameter, start_index must be greater than or equal to 0, but got " + | ||||
| std::to_string(start_index_) + ".\n"); | std::to_string(start_index_) + ".\n"); | ||||
| @@ -85,7 +85,7 @@ Status SequentialSampler::InitSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status SequentialSampler::ResetSampler() { | |||||
| Status SequentialSamplerRT::ResetSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); | CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); | ||||
| current_id_ = start_index_; | current_id_ = start_index_; | ||||
| id_count_ = 0; | id_count_ = 0; | ||||
| @@ -97,11 +97,11 @@ Status SequentialSampler::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void SequentialSampler::Print(std::ostream &out, bool show_all) const { | |||||
| void SequentialSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: SequentialSampler"; | out << "\nSampler: SequentialSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| Sampler::Print(out, show_all); | |||||
| SamplerRT::Print(out, show_all); | |||||
| // Then add our own info | // Then add our own info | ||||
| out << "\nStart index: " << start_index_; | out << "\nStart index: " << start_index_; | ||||
| } | } | ||||
| @@ -23,18 +23,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class SequentialSampler : public Sampler { | |||||
| class SequentialSamplerRT : public SamplerRT { | |||||
| public: | public: | ||||
| // Constructor | // Constructor | ||||
| // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the | // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the | ||||
| // full amount of ids from the dataset | // full amount of ids from the dataset | ||||
| // @param start_index - The starting index value | // @param start_index - The starting index value | ||||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | ||||
| explicit SequentialSampler(int64_t num_samples, int64_t start_index, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| explicit SequentialSamplerRT(int64_t num_samples, int64_t start_index, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // Destructor. | // Destructor. | ||||
| ~SequentialSampler() = default; | |||||
| ~SequentialSamplerRT() = default; | |||||
| // init sampler, called by python | // init sampler, called by python | ||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| @@ -27,12 +27,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor. | // Constructor. | ||||
| SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices, | |||||
| int64_t samples_per_buffer) | |||||
| : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} | |||||
| SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices, | |||||
| int64_t samples_per_buffer) | |||||
| : SamplerRT(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} | |||||
| // Initialized this Sampler. | // Initialized this Sampler. | ||||
| Status SubsetRandomSampler::InitSampler() { | |||||
| Status SubsetRandomSamplerRT::InitSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | CHECK_FAIL_RETURN_UNEXPECTED( | ||||
| num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); | num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); | ||||
| @@ -56,7 +56,7 @@ Status SubsetRandomSampler::InitSampler() { | |||||
| } | } | ||||
| // Reset the internal variable to the initial state. | // Reset the internal variable to the initial state. | ||||
| Status SubsetRandomSampler::ResetSampler() { | |||||
| Status SubsetRandomSamplerRT::ResetSampler() { | |||||
| // Reset the internal counters. | // Reset the internal counters. | ||||
| sample_id_ = 0; | sample_id_ = 0; | ||||
| buffer_id_ = 0; | buffer_id_ = 0; | ||||
| @@ -73,7 +73,7 @@ Status SubsetRandomSampler::ResetSampler() { | |||||
| } | } | ||||
| // Get the sample ids. | // Get the sample ids. | ||||
| Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status SubsetRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| // All samples have been drawn | // All samples have been drawn | ||||
| if (sample_id_ == num_samples_) { | if (sample_id_ == num_samples_) { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | ||||
| @@ -120,11 +120,11 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { | |||||
| void SubsetRandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: SubsetRandomSampler"; | out << "\nSampler: SubsetRandomSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| Sampler::Print(out, show_all); | |||||
| SamplerRT::Print(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,18 +25,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Randomly samples elements from a given list of indices, without replacement. | // Randomly samples elements from a given list of indices, without replacement. | ||||
| class SubsetRandomSampler : public Sampler { | |||||
| class SubsetRandomSamplerRT : public SamplerRT { | |||||
| public: | public: | ||||
| // Constructor. | // Constructor. | ||||
| // @param num_samples The number of samples to draw. 0 for the full amount. | // @param num_samples The number of samples to draw. 0 for the full amount. | ||||
| // @param indices List of indices from where we will randomly draw samples. | // @param indices List of indices from where we will randomly draw samples. | ||||
| // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | ||||
| // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | ||||
| explicit SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices, | |||||
| std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| explicit SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices, | |||||
| std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // Destructor. | // Destructor. | ||||
| ~SubsetRandomSampler() = default; | |||||
| ~SubsetRandomSamplerRT() = default; | |||||
| // Initialize the sampler. | // Initialize the sampler. | ||||
| // @return Status | // @return Status | ||||
| @@ -27,16 +27,16 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor. | // Constructor. | ||||
| WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement, | |||||
| int64_t samples_per_buffer) | |||||
| : Sampler(num_samples, samples_per_buffer), | |||||
| WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std::vector<double> &weights, | |||||
| bool replacement, int64_t samples_per_buffer) | |||||
| : SamplerRT(num_samples, samples_per_buffer), | |||||
| weights_(weights), | weights_(weights), | ||||
| replacement_(replacement), | replacement_(replacement), | ||||
| sample_id_(0), | sample_id_(0), | ||||
| buffer_id_(0) {} | buffer_id_(0) {} | ||||
| // Initialized this Sampler. | // Initialized this Sampler. | ||||
| Status WeightedRandomSampler::InitSampler() { | |||||
| Status WeightedRandomSamplerRT::InitSampler() { | |||||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | ||||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | ||||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | if (num_samples_ == 0 || num_samples_ > num_rows_) { | ||||
| @@ -78,7 +78,7 @@ Status WeightedRandomSampler::InitSampler() { | |||||
| } | } | ||||
| // Initialized the computation for generating weighted random numbers without replacement using onepass method. | // Initialized the computation for generating weighted random numbers without replacement using onepass method. | ||||
| void WeightedRandomSampler::InitOnePassSampling() { | |||||
| void WeightedRandomSamplerRT::InitOnePassSampling() { | |||||
| exp_dist_->reset(); | exp_dist_->reset(); | ||||
| onepass_ids_.clear(); | onepass_ids_.clear(); | ||||
| std::vector<std::pair<double, int64_t>> val_idx; | std::vector<std::pair<double, int64_t>> val_idx; | ||||
| @@ -94,7 +94,7 @@ void WeightedRandomSampler::InitOnePassSampling() { | |||||
| } | } | ||||
| // Reset the internal variable to the initial state and reshuffle the indices. | // Reset the internal variable to the initial state and reshuffle the indices. | ||||
| Status WeightedRandomSampler::ResetSampler() { | |||||
| Status WeightedRandomSamplerRT::ResetSampler() { | |||||
| sample_id_ = 0; | sample_id_ = 0; | ||||
| buffer_id_ = 0; | buffer_id_ = 0; | ||||
| rand_gen_.seed(GetSeed()); | rand_gen_.seed(GetSeed()); | ||||
| @@ -112,7 +112,7 @@ Status WeightedRandomSampler::ResetSampler() { | |||||
| } | } | ||||
| // Get the sample ids. | // Get the sample ids. | ||||
| Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| if (weights_.size() > static_cast<size_t>(num_rows_)) { | if (weights_.size() > static_cast<size_t>(num_rows_)) { | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | ||||
| "Invalid parameter, size of sample weights must be less than or equal to num of data, " | "Invalid parameter, size of sample weights must be less than or equal to num of data, " | ||||
| @@ -180,11 +180,11 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const { | |||||
| void WeightedRandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: WeightedRandomSampler"; | out << "\nSampler: WeightedRandomSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| Sampler::Print(out, show_all); | |||||
| SamplerRT::Print(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -26,7 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights). | // Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights). | ||||
| class WeightedRandomSampler : public Sampler { | |||||
| class WeightedRandomSamplerRT : public SamplerRT { | |||||
| public: | public: | ||||
| // Constructor. | // Constructor. | ||||
| // @param num_samples Number of samples to be drawn. | // @param num_samples Number of samples to be drawn. | ||||
| @@ -34,11 +34,11 @@ class WeightedRandomSampler : public Sampler { | |||||
| // @param replacement Determine if samples are drawn with/without replacement. | // @param replacement Determine if samples are drawn with/without replacement. | ||||
| // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | ||||
| // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | ||||
| WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| WeightedRandomSamplerRT(int64_t num_samples, const std::vector<double> &weights, bool replacement, | |||||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||||
| // Destructor. | // Destructor. | ||||
| ~WeightedRandomSampler() = default; | |||||
| ~WeightedRandomSamplerRT() = default; | |||||
| // Initialize the sampler. | // Initialize the sampler. | ||||
| // @param op (Not used in this sampler) | // @param op (Not used in this sampler) | ||||
| @@ -84,7 +84,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | ||||
| std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | ||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, | int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, | ||||
| std::shared_ptr<Sampler> sampler) | |||||
| std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | ||||
| device_id_(device_id), | device_id_(device_id), | ||||
| num_devices_(num_device), | num_devices_(num_device), | ||||
| @@ -115,7 +115,7 @@ class TextFileOp : public ParallelOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -131,7 +131,7 @@ class TextFileOp : public ParallelOp { | |||||
| std::vector<std::string> builder_text_files_list_; | std::vector<std::string> builder_text_files_list_; | ||||
| bool builder_shuffle_files_; | bool builder_shuffle_files_; | ||||
| std::unique_ptr<DataSchema> builder_schema_; | std::unique_ptr<DataSchema> builder_schema_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| }; | }; | ||||
| // Constructor of TextFileOp | // Constructor of TextFileOp | ||||
| @@ -148,7 +148,7 @@ class TextFileOp : public ParallelOp { | |||||
| // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes | // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes | ||||
| TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | ||||
| std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | ||||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler); | |||||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler); | |||||
| // Default destructor | // Default destructor | ||||
| ~TextFileOp() = default; | ~TextFileOp() = default; | ||||
| @@ -58,7 +58,7 @@ TFReaderOp::Builder::Builder() | |||||
| builder_data_schema_ = std::make_unique<DataSchema>(); | builder_data_schema_ = std::make_unique<DataSchema>(); | ||||
| } | } | ||||
| bool ValidateFirstRowCrc(const std::string &filename) { | |||||
| bool TFReaderOp::ValidateFirstRowCrc(const std::string &filename) { | |||||
| std::ifstream reader; | std::ifstream reader; | ||||
| reader.open(filename); | reader.open(filename); | ||||
| if (!reader) { | if (!reader) { | ||||
| @@ -134,7 +134,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 | |||||
| int64_t total_num_rows, std::vector<std::string> dataset_files_list, | int64_t total_num_rows, std::vector<std::string> dataset_files_list, | ||||
| std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, | std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, | ||||
| std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device, | std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device, | ||||
| int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler) | |||||
| int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | ||||
| device_id_(device_id), | device_id_(device_id), | ||||
| num_devices_(num_device), | num_devices_(num_device), | ||||
| @@ -156,14 +156,14 @@ class TFReaderOp : public ParallelOp { | |||||
| // Setter method | // Setter method | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| private: | private: | ||||
| std::unique_ptr<DataSchema> builder_data_schema_; | std::unique_ptr<DataSchema> builder_data_schema_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| int32_t builder_device_id_; | int32_t builder_device_id_; | ||||
| int32_t builder_num_devices_; | int32_t builder_num_devices_; | ||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| @@ -193,7 +193,7 @@ class TFReaderOp : public ParallelOp { | |||||
| TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, | TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, | ||||
| std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema, | std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema, | ||||
| int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files, | int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files, | ||||
| int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler); | |||||
| int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler); | |||||
| // Default destructor | // Default destructor | ||||
| ~TFReaderOp() = default; | ~TFReaderOp() = default; | ||||
| @@ -262,6 +262,8 @@ class TFReaderOp : public ParallelOp { | |||||
| /// \return Status of the function | /// \return Status of the function | ||||
| Status GetDatasetSize(int64_t *dataset_size) override; | Status GetDatasetSize(int64_t *dataset_size) override; | ||||
| static bool ValidateFirstRowCrc(const std::string &filename); | |||||
| private: | private: | ||||
| // The entry point for when workers are launched. | // The entry point for when workers are launched. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| @@ -62,7 +62,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) { | |||||
| if (builder_sampler_ == nullptr) { | if (builder_sampler_ == nullptr) { | ||||
| const int64_t num_samples = 0; | const int64_t num_samples = 0; | ||||
| const int64_t start_index = 0; | const int64_t start_index = 0; | ||||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||||
| builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||||
| } | } | ||||
| builder_schema_ = std::make_unique<DataSchema>(); | builder_schema_ = std::make_unique<DataSchema>(); | ||||
| if (builder_task_type_ == TaskType::Segmentation) { | if (builder_task_type_ == TaskType::Segmentation) { | ||||
| @@ -102,7 +102,8 @@ Status VOCOp::Builder::SanityCheck() { | |||||
| VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | ||||
| const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | ||||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, | |||||
| std::shared_ptr<SamplerRT> sampler) | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | : ParallelOp(num_workers, queue_size, std::move(sampler)), | ||||
| decode_(decode), | decode_(decode), | ||||
| row_cnt_(0), | row_cnt_(0), | ||||
| @@ -118,7 +118,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| // Setter method. | // Setter method. | ||||
| // @param std::shared_ptr<Sampler> sampler | // @param std::shared_ptr<Sampler> sampler | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | builder_sampler_ = std::move(sampler); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -148,7 +148,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| int32_t builder_rows_per_buffer_; | int32_t builder_rows_per_buffer_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||||
| std::unique_ptr<DataSchema> builder_schema_; | std::unique_ptr<DataSchema> builder_schema_; | ||||
| std::map<std::string, int32_t> builder_labels_to_read_; | std::map<std::string, int32_t> builder_labels_to_read_; | ||||
| }; | }; | ||||
| @@ -166,7 +166,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read | // @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read | ||||
| VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | ||||
| const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | ||||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||||
| // Destructor | // Destructor | ||||
| ~VOCOp() = default; | ~VOCOp() = default; | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | #include "minddata/dataset/engine/datasetops/dataset_op.h" | ||||
| namespace mindspore::dataset::api { | |||||
| namespace mindspore::dataset { | |||||
| class DatasetCache { | class DatasetCache { | ||||
| public: | public: | ||||
| @@ -29,6 +29,6 @@ class DatasetCache { | |||||
| virtual Status ValidateParams() = 0; | virtual Status ValidateParams() = 0; | ||||
| virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0; | virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0; | ||||
| }; | }; | ||||
| } // namespace mindspore::dataset::api | |||||
| } // namespace mindspore::dataset | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_ | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | ||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | #include "minddata/dataset/engine/datasetops/cache_op.h" | ||||
| namespace mindspore::dataset::api { | |||||
| namespace mindspore::dataset { | |||||
| /// Method to initialize the DatasetCache by creating an instance of a CacheClient | /// Method to initialize the DatasetCache by creating an instance of a CacheClient | ||||
| /// \return Status Error code | /// \return Status Error code | ||||
| @@ -41,4 +41,4 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace mindspore::dataset::api | |||||
| } // namespace mindspore::dataset | |||||
| @@ -24,7 +24,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | #include "minddata/dataset/engine/datasetops/cache_op.h" | ||||
| #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | ||||
| namespace mindspore::dataset::api { | |||||
| namespace mindspore::dataset { | |||||
| /// DatasetCache is the IR of CacheClient | /// DatasetCache is the IR of CacheClient | ||||
| class DatasetCacheImpl : public DatasetCache { | class DatasetCacheImpl : public DatasetCache { | ||||
| @@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache { | |||||
| std::optional<int32_t> num_connections_; | std::optional<int32_t> num_connections_; | ||||
| std::optional<int32_t> prefetch_sz_; | std::optional<int32_t> prefetch_sz_; | ||||
| }; | }; | ||||
| } // namespace mindspore::dataset::api | |||||
| } // namespace mindspore::dataset | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_ | ||||
| @@ -26,7 +26,6 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| // constructor #1, called by Pybind | // constructor #1, called by Pybind | ||||
| @@ -96,6 +95,5 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,7 +27,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class BatchNode : public DatasetNode { | class BatchNode : public DatasetNode { | ||||
| public: | public: | ||||
| @@ -66,7 +65,6 @@ class BatchNode : public DatasetNode { | |||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_; | std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_ | ||||
| @@ -27,7 +27,7 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| BucketBatchByLengthNode::BucketBatchByLengthNode( | BucketBatchByLengthNode::BucketBatchByLengthNode( | ||||
| std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | ||||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | ||||
| @@ -121,6 +121,5 @@ Status BucketBatchByLengthNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class BucketBatchByLengthNode : public DatasetNode { | class BucketBatchByLengthNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -58,7 +58,6 @@ class BucketBatchByLengthNode : public DatasetNode { | |||||
| bool drop_remainder_; | bool drop_remainder_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_ | ||||
| @@ -26,7 +26,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child, | BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child, | ||||
| std::shared_ptr<SentencePieceVocab> vocab, | std::shared_ptr<SentencePieceVocab> vocab, | ||||
| @@ -77,6 +76,6 @@ Status BuildSentenceVocabNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,7 +27,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class BuildSentenceVocabNode : public DatasetNode { | class BuildSentenceVocabNode : public DatasetNode { | ||||
| public: | public: | ||||
| @@ -56,7 +55,6 @@ class BuildSentenceVocabNode : public DatasetNode { | |||||
| std::unordered_map<std::string, std::string> params_; | std::unordered_map<std::string, std::string> params_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_ | #endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_ | ||||
| @@ -26,7 +26,6 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab, | BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab, | ||||
| const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, | const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, | ||||
| @@ -78,6 +77,6 @@ Status BuildVocabNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,7 +26,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class BuildVocabNode : public DatasetNode { | class BuildVocabNode : public DatasetNode { | ||||
| public: | public: | ||||
| @@ -55,7 +54,6 @@ class BuildVocabNode : public DatasetNode { | |||||
| bool special_first_; | bool special_first_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_ | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Function to build ConcatOp | // Function to build ConcatOp | ||||
| ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; } | ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; } | ||||
| @@ -53,6 +53,5 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,7 +25,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class ConcatNode : public DatasetNode { | class ConcatNode : public DatasetNode { | ||||
| public: | public: | ||||
| @@ -44,7 +43,6 @@ class ConcatNode : public DatasetNode { | |||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_ | ||||
| @@ -16,11 +16,187 @@ | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | ||||
| #include <algorithm> | |||||
| #include <memory> | #include <memory> | ||||
| #include <set> | |||||
| #include "minddata/dataset/util/random.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Helper function to compute a default shuffle size | |||||
| Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||||
| int64_t *shuffle_size) { | |||||
| const int64_t average_files_multiplier = 4; | |||||
| const int64_t shuffle_max = 10000; | |||||
| int64_t avg_rows_per_file = 0; | |||||
| // Adjust the num rows per shard if sharding was given | |||||
| if (num_devices > 0) { | |||||
| if (num_rows % num_devices == 0) { | |||||
| num_rows = num_rows / num_devices; | |||||
| } else { | |||||
| num_rows = (num_rows / num_devices) + 1; | |||||
| } | |||||
| } | |||||
| // Cap based on total rows directive. Some ops do not have this and give value of 0. | |||||
| if (total_rows > 0) { | |||||
| num_rows = std::min(num_rows, total_rows); | |||||
| } | |||||
| // get the average per file | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must greater than 0."); | |||||
| avg_rows_per_file = num_rows / num_files; | |||||
| *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to inject a shuffle operator over top of current operator being built | |||||
| Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||||
| int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) { | |||||
| std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr; | |||||
| int64_t shuffle_size = 0; | |||||
| RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size)); | |||||
| MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size; | |||||
| // Add the shuffle op | |||||
| *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset directory parameter | |||||
| Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { | |||||
| if (dataset_dir.empty()) { | |||||
| std::string err_msg = dataset_name + ": dataset_dir is not specified."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| Path dir(dataset_dir); | |||||
| if (!dir.IsDirectory()) { | |||||
| std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (access(dataset_dir.c_str(), R_OK) == -1) { | |||||
| std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset files parameter | |||||
| Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) { | |||||
| if (dataset_files.empty()) { | |||||
| std::string err_msg = dataset_name + ": dataset_files is not specified."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| for (auto f : dataset_files) { | |||||
| Path dataset_file(f); | |||||
| if (!dataset_file.Exists()) { | |||||
| std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (access(dataset_file.toString().c_str(), R_OK) == -1) { | |||||
| std::string err_msg = dataset_name + ": No access to specified dataset file: " + f; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset num_shards and shard_id parameters | |||||
| Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { | |||||
| if (num_shards <= 0) { | |||||
| std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (shard_id < 0 || shard_id >= num_shards) { | |||||
| // num_shards; | |||||
| std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) + | |||||
| ", num_shards: " + std::to_string(num_shards); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset sampler parameter | |||||
| Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) { | |||||
| if (sampler == nullptr) { | |||||
| std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ValidateStringValue(const std::string &dataset_name, const std::string &str, | |||||
| const std::unordered_set<std::string> &valid_strings) { | |||||
| if (valid_strings.find(str) == valid_strings.end()) { | |||||
| std::string mode; | |||||
| mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, | |||||
| [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); | |||||
| std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Helper function to validate dataset input/output column parameter | |||||
| Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, | |||||
| const std::vector<std::string> &columns) { | |||||
| if (columns.empty()) { | |||||
| std::string err_msg = dataset_name + ":" + column_param + " should not be empty string"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| for (uint32_t i = 0; i < columns.size(); ++i) { | |||||
| if (columns[i].empty()) { | |||||
| std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| std::set<std::string> columns_set(columns.begin(), columns.end()); | |||||
| if (columns_set.size() != columns.size()) { | |||||
| std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) { | |||||
| if (shuffle) { | |||||
| if (num_shards > 1) { | |||||
| // If shuffle enabled, sharding enabled, use distributed random sampler | |||||
| return DistributedSampler(num_shards, shard_id, shuffle, num_samples); | |||||
| } | |||||
| // If shuffle enabled, sharding disabled, use random sampler | |||||
| return RandomSampler(num_samples >= 0, num_samples); | |||||
| } | |||||
| if (num_shards > 1) { | |||||
| // If shuffle disabled, sharding enabled, use distributed sequential sampler | |||||
| return DistributedSampler(num_shards, shard_id, shuffle, num_samples); | |||||
| } | |||||
| // If shuffle disabled, sharding disabled, use sequential sampler | |||||
| return SequentialSampler(0, num_samples); | |||||
| } | |||||
| Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | ||||
| if (cache_ != nullptr) { | if (cache_ != nullptr) { | ||||
| @@ -60,6 +236,5 @@ DatasetNode::DatasetNode() { | |||||
| worker_connector_size_ = cfg->worker_connector_size(); | worker_connector_size_ = cfg->worker_connector_size(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,7 +28,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class Dataset; | class Dataset; | ||||
| class SamplerObj; | class SamplerObj; | ||||
| @@ -120,7 +119,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| int32_t worker_connector_size_; | int32_t worker_connector_size_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | ||||
| @@ -26,7 +26,6 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | ||||
| std::vector<std::string> input_columns, std::vector<std::string> output_columns, | std::vector<std::string> input_columns, std::vector<std::string> output_columns, | ||||
| @@ -86,6 +85,5 @@ Status MapNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,7 +25,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class MapNode : public DatasetNode { | class MapNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -51,7 +51,6 @@ class MapNode : public DatasetNode { | |||||
| std::vector<std::string> project_columns_; | std::vector<std::string> project_columns_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_ | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Function to build ProjectOp | // Function to build ProjectOp | ||||
| ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) | ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) | ||||
| @@ -53,6 +52,5 @@ std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,8 +26,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class ProjectNode : public DatasetNode { | class ProjectNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -48,7 +46,6 @@ class ProjectNode : public DatasetNode { | |||||
| std::vector<std::string> columns_; | std::vector<std::string> columns_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_ | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Function to build RenameOp | // Function to build RenameOp | ||||
| RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, | RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, | ||||
| const std::vector<std::string> &output_columns) | const std::vector<std::string> &output_columns) | ||||
| @@ -54,6 +54,6 @@ std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() { | |||||
| node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_)); | node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_)); | ||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,8 +26,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class RenameNode : public DatasetNode { | class RenameNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -50,7 +48,6 @@ class RenameNode : public DatasetNode { | |||||
| std::vector<std::string> output_columns_; | std::vector<std::string> output_columns_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_ | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { | RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { | ||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| @@ -49,6 +48,6 @@ Status RepeatNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,8 +28,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class RepeatNode : public DatasetNode { | class RepeatNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -50,7 +48,6 @@ class RepeatNode : public DatasetNode { | |||||
| int32_t repeat_count_; | int32_t repeat_count_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_ | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Constructor for ShuffleNode | // Constructor for ShuffleNode | ||||
| ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) | ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) | ||||
| @@ -54,6 +53,5 @@ Status ShuffleNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,8 +28,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class ShuffleNode : public DatasetNode { | class ShuffleNode : public DatasetNode { | ||||
| public: | public: | ||||
| ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch); | ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch); | ||||
| @@ -46,7 +44,6 @@ class ShuffleNode : public DatasetNode { | |||||
| bool reset_every_epoch_; | bool reset_every_epoch_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_ | ||||
| @@ -25,7 +25,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Constructor for SkipNode | // Constructor for SkipNode | ||||
| SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { | SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { | ||||
| @@ -52,6 +51,5 @@ Status SkipNode::ValidateParams() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,7 +26,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class SkipNode : public DatasetNode { | class SkipNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -46,7 +45,7 @@ class SkipNode : public DatasetNode { | |||||
| private: | private: | ||||
| int32_t skip_count_; | int32_t skip_count_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_ | ||||
| @@ -27,7 +27,7 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Constructor for AlbumNode | // Constructor for AlbumNode | ||||
| AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | ||||
| const std::vector<std::string> &column_names, bool decode, | const std::vector<std::string> &column_names, bool decode, | ||||
| @@ -78,6 +78,5 @@ Status AlbumNode::GetShardId(int32_t *shard_id) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,7 +25,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class AlbumNode : public DatasetNode { | class AlbumNode : public DatasetNode { | ||||
| public: | public: | ||||
| @@ -57,7 +56,6 @@ class AlbumNode : public DatasetNode { | |||||
| std::shared_ptr<SamplerObj> sampler_; | std::shared_ptr<SamplerObj> sampler_; | ||||
| }; | }; | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_ | ||||
| @@ -26,7 +26,7 @@ | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| // Constructor for CelebANode | // Constructor for CelebANode | ||||
| CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | ||||
| const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | ||||
| @@ -76,6 +76,5 @@ Status CelebANode::GetShardId(int32_t *shard_id) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace api | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||