diff --git a/mindspore/ccsrc/minddata/dataset/api/config.cc b/mindspore/ccsrc/minddata/dataset/api/config.cc index 6e6556173a..ead8f1b80f 100644 --- a/mindspore/ccsrc/minddata/dataset/api/config.cc +++ b/mindspore/ccsrc/minddata/dataset/api/config.cc @@ -21,7 +21,6 @@ namespace mindspore { namespace dataset { -namespace api { // Config operations for setting and getting the configuration. namespace config { @@ -104,6 +103,5 @@ bool load(std::string file) { } } // namespace config -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 289abdd194..785fbb58f8 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -21,36 +21,14 @@ #include #include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/transforms.h" -// Source dataset headers (in alphabetical order) -#include "minddata/dataset/engine/dataset_iterator.h" -#include "minddata/dataset/engine/datasetops/source/album_op.h" -#include "minddata/dataset/engine/datasetops/source/celeba_op.h" -#include "minddata/dataset/engine/datasetops/source/cifar_op.h" -#include "minddata/dataset/engine/datasetops/source/clue_op.h" -#include "minddata/dataset/engine/datasetops/source/coco_op.h" -#include "minddata/dataset/engine/datasetops/source/csv_op.h" -#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" + #ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/datasetops/source/manifest_op.h" -#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" + #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" #endif -#include "minddata/dataset/engine/datasetops/source/mnist_op.h" -#include "minddata/dataset/engine/datasetops/source/random_data_op.h" -#include "minddata/dataset/engine/datasetops/source/text_file_op.h" -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" -#include "minddata/dataset/engine/datasetops/source/voc_op.h" -#endif -// Dataset operator headers (in alphabetical order) -#include "minddata/dataset/engine/datasetops/map_op/map_op.h" -#include "minddata/dataset/engine/datasetops/skip_op.h" -#include "minddata/dataset/engine/datasetops/zip_op.h" // Sampler headers (in alphabetical order) -#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" // IR non-leaf nodes #include "minddata/dataset/engine/ir/datasetops/batch_node.h" @@ -99,7 +77,6 @@ namespace mindspore { namespace dataset { -namespace api { // Function to create the iterator, which will build and launch the execution tree. std::shared_ptr Dataset::CreateIterator(std::vector columns) { @@ -317,7 +294,7 @@ std::shared_ptr Schema(const std::string &schema_file) { return schema->init() ? schema : nullptr; } -// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS +// FUNCTIONS TO CREATE DATASETS FOR LEAF CLASSES // (In alphabetical order) // Function to create a AlbumDataset. @@ -466,7 +443,7 @@ std::shared_ptr VOC(const std::string &dataset_dir, const std::strin } #endif -// Function to create a ZipNode. +// Function to create a ZipDatset. 
std::shared_ptr Zip(const std::vector> &datasets) { auto ds = std::make_shared(datasets); return ds; @@ -639,7 +616,7 @@ std::shared_ptr Dataset::BuildSentencePieceVocab( std::unique_ptr runtime_context = std::make_unique(); Status rc = runtime_context->Init(); if (rc.IsError()) { - MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; + MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc; return nullptr; } @@ -647,15 +624,15 @@ std::shared_ptr Dataset::BuildSentencePieceVocab( BuildVocabConsumer *bv_consumer = consumer.get(); rc = consumer->Init(ds); if (rc.IsError()) { - MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc; + MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc; return nullptr; } runtime_context->AssignConsumer(std::move(consumer)); - // Run tree here to starting building vocab + // Run tree here to starting building SentencePieceVocab rc = bv_consumer->Start(); if (rc.IsError()) { - MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc; + MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc; return nullptr; } return vocab; @@ -671,7 +648,7 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum std::unique_ptr runtime_context = std::make_unique(); Status rc = runtime_context->Init(); if (rc.IsError()) { - MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; + MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc; return nullptr; } @@ -679,7 +656,7 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum BuildVocabConsumer *bv_consumer = consumer.get(); rc = consumer->Init(ds); if (rc.IsError()) { - MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc; + MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc; return nullptr; } runtime_context->AssignConsumer(std::move(consumer)); @@ -687,11 +664,14 @@ std::shared_ptr Dataset::BuildVocab(const std::vector &colum // Run tree here to starting building vocab rc = bv_consumer->Start(); if (rc.IsError()) { - MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc; + MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc; return nullptr; } return vocab; } +std::shared_ptr Dataset::Batch(int32_t batch_size, bool drop_remainder) { + return std::make_shared(shared_from_this(), batch_size, drop_remainder); +} #endif SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} @@ -856,162 +836,6 @@ bool SchemaObj::from_json(nlohmann::json json_obj) { // OTHER FUNCTIONS -// Helper function to compute a default shuffle size -Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, - int64_t *shuffle_size) { - const int64_t average_files_multiplier = 4; - const int64_t shuffle_max = 10000; - int64_t avg_rows_per_file = 0; - - // Adjust the num rows per shard if sharding was given - if (num_devices > 0) { - if (num_rows % num_devices == 0) { - num_rows = num_rows / num_devices; - } else { - num_rows = (num_rows / num_devices) + 1; - } - } - - // Cap based on total rows directive. Some ops do not have this and give value of 0. 
- if (total_rows > 0) { - num_rows = std::min(num_rows, total_rows); - } - - // get the average per file - CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must greater than 0."); - avg_rows_per_file = num_rows / num_files; - - *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); - return Status::OK(); -} - -// Helper function to inject a shuffle operator over top of current operator being built -Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, - int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr *shuffle_op) { - std::shared_ptr new_shuffle_op = nullptr; - int64_t shuffle_size = 0; - RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size)); - MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size; - // Add the shuffle op - *shuffle_op = std::make_shared(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer); - return Status::OK(); -} - -// Helper function to validate dataset directory parameter -Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { - if (dataset_dir.empty()) { - std::string err_msg = dataset_name + ": dataset_dir is not specified."; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - Path dir(dataset_dir); - if (!dir.IsDirectory()) { - std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path."; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - if (access(dataset_dir.c_str(), R_OK) == -1) { - std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - return Status::OK(); -} - -// Helper function to validate dataset files parameter -Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector &dataset_files) { - if (dataset_files.empty()) { - std::string err_msg = dataset_name + ": dataset_files is not specified."; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - for (auto f : dataset_files) { - Path dataset_file(f); - if (!dataset_file.Exists()) { - std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist."; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - if (access(dataset_file.toString().c_str(), R_OK) == -1) { - std::string err_msg = dataset_name + ": No access to specified dataset file: " + f; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - } - - return Status::OK(); -} - -// Helper function to validate dataset num_shards and shard_id parameters -Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { - if (num_shards <= 0) { - std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards); - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - if (shard_id < 0 || shard_id >= num_shards) { - // num_shards; - std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) + - ", num_shards: " + std::to_string(num_shards); - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - - return Status::OK(); -} - -// Helper function to validate dataset sampler parameter -Status ValidateDatasetSampler(const std::string &dataset_name, const 
std::shared_ptr &sampler) { - if (sampler == nullptr) { - std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - return Status::OK(); -} - -Status ValidateStringValue(const std::string &dataset_name, const std::string &str, - const std::unordered_set &valid_strings) { - if (valid_strings.find(str) == valid_strings.end()) { - std::string mode; - mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, - [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); - std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]"; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - return Status::OK(); -} - -// Helper function to validate dataset input/output column parameter -Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, - const std::vector &columns) { - if (columns.empty()) { - std::string err_msg = dataset_name + ":" + column_param + " should not be empty string"; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - for (uint32_t i = 0; i < columns.size(); ++i) { - if (columns[i].empty()) { - std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty"; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - } - std::set columns_set(columns.begin(), columns.end()); - if (columns_set.size() != columns.size()) { - std::string err_msg = dataset_name + ":" + column_param + ": Every column name should not be same with others"; - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_SYNTAX_ERROR(err_msg); - } - return Status::OK(); -} - #ifndef ENABLE_ANDROID std::shared_ptr CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill, @@ -1153,22 +977,5 @@ TFRecordDataset::TFRecordDataset(const std::vector &dataset_files, ir_node_ = std::static_pointer_cast(ds); } #endif -std::shared_ptr SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) { - if (shuffle) { - if (num_shards > 1) { - // If shuffle enabled, sharding enabled, use distributed random sampler - return DistributedSampler(num_shards, shard_id, shuffle, num_samples); - } - // If shuffle enabled, sharding disabled, use random sampler - return RandomSampler(num_samples >= 0, num_samples); - } - if (num_shards > 1) { - // If shuffle disabled, sharding enabled, use distributed sequential sampler - return DistributedSampler(num_shards, shard_id, shuffle, num_samples); - } - // If shuffle disabled, sharding disabled, use sequential sampler - return SequentialSampler(0, num_samples); -} -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/execute.cc b/mindspore/ccsrc/minddata/dataset/api/execute.cc index d1db26c5c6..23222c4d47 100644 --- a/mindspore/ccsrc/minddata/dataset/api/execute.cc +++ b/mindspore/ccsrc/minddata/dataset/api/execute.cc @@ -26,7 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { Execute::Execute(std::shared_ptr op) : op_(std::move(op)) {} @@ -54,6 +53,5 @@ std::shared_ptr Execute::operator()(std::shared_ptr(std::move(de_output)); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc index 1e8e74ca1a..5410ba4052 100644 --- 
a/mindspore/ccsrc/minddata/dataset/api/iterator.cc +++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc @@ -20,7 +20,6 @@ namespace mindspore { namespace dataset { -namespace api { // Get the next row from the data pipeline. bool Iterator::GetNextRow(TensorMap *row) { @@ -45,19 +44,18 @@ bool Iterator::GetNextRow(TensorVec *row) { } // Shut down the data pipeline. -void Iterator::Stop() { runtime_context->Terminate(); } +void Iterator::Stop() { runtime_context_->Terminate(); } // Function to build and launch the execution tree. Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { - runtime_context = std::make_unique(); - RETURN_IF_NOT_OK(runtime_context->Init()); + runtime_context_ = std::make_unique(); + RETURN_IF_NOT_OK(runtime_context_->Init()); auto consumer = std::make_unique(); consumer_ = consumer.get(); RETURN_IF_NOT_OK(consumer->Init(ds->IRNode())); - runtime_context->AssignConsumer(std::move(consumer)); + runtime_context_->AssignConsumer(std::move(consumer)); return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc index b7bc88aeb9..0b41cc1c7e 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc @@ -27,59 +27,59 @@ namespace mindspore { namespace dataset { -PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) { - (void)py::class_>(*m, "Sampler") +PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) { + (void)py::class_>(*m, "Sampler") .def("set_num_rows", - [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) + [](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) .def("set_num_samples", - [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) - .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) + [](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) + .def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); }) .def("get_indices", - [](Sampler &self) { + [](SamplerRT &self) { py::array ret; THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); return ret; }) - .def("add_child", [](std::shared_ptr self, std::shared_ptr child) { + .def("add_child", [](std::shared_ptr self, std::shared_ptr child) { THROW_IF_ERROR(self->AddChild(child)); }); })); -PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) { - (void)py::class_>( +PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) { + (void)py::class_>( *m, "DistributedSampler") .def(py::init()); })); -PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) { - (void)py::class_>(*m, "PKSampler") +PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) { + (void)py::class_>(*m, "PKSampler") .def(py::init()); })); -PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) { - (void)py::class_>(*m, "PythonSampler") +PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) { + (void)py::class_>(*m, "PythonSampler") .def(py::init()); })); -PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) { - (void)py::class_>(*m, "RandomSampler") +PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) { + 
(void)py::class_>(*m, "RandomSampler") .def(py::init()); })); -PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) { - (void)py::class_>(*m, - "SequentialSampler") +PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "SequentialSampler") .def(py::init()); })); -PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) { - (void)py::class_>( +PYBIND_REGISTER(SubsetRandomSamplerRT, 1, ([](const py::module *m) { + (void)py::class_>( *m, "SubsetRandomSampler") .def(py::init>()); })); -PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) { - (void)py::class_>( +PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) { + (void)py::class_>( *m, "WeightedRandomSampler") .def(py::init, bool>()); })); diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc index 16632aad89..a484d611fb 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc @@ -1140,7 +1140,7 @@ Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } if (key == "children_flag_and_nums") { @@ -1164,7 +1164,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr files_list; std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; + std::shared_ptr sampler = nullptr; int num_workers = 0; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { @@ -1210,7 +1210,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr>(); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - sampler = create().cast>(); + sampler = create().cast>(); } } } @@ -1234,7 +1234,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr(num_samples, start_index); + sampler = std::make_shared(num_samples, start_index); (void)builder->SetSampler(std::move(sampler)); } @@ -1308,7 +1308,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptrSetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "extensions") { (void)builder->SetExtensions(ToStringSet(value)); @@ -1363,7 +1363,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptrSetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "class_indexing") { (void)builder->SetClassIndex(ToStringMap(value)); @@ -1416,7 +1416,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr * (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "decode") { (void)builder->SetDecode(ToBool(value)); @@ -1478,7 +1478,7 @@ Status 
DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "decode") { (void)builder->SetDecode(ToBool(value)); @@ -1529,7 +1529,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptrSetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "usage") { (void)builder->SetUsage(ToString(value)); @@ -1583,7 +1583,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptrSetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "usage") { (void)builder->SetUsage(ToString(value)); @@ -1618,7 +1618,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; + std::shared_ptr sampler = nullptr; int num_workers = 0; if (args["total_rows"].is_none()) { @@ -1646,7 +1646,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr>(); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - sampler = create().cast>(); + sampler = create().cast>(); } } } @@ -1670,7 +1670,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr(num_samples, start_index); + sampler = std::make_shared(num_samples, start_index); (void)builder.SetSampler(std::move(sampler)); } @@ -1715,7 +1715,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "usage") { (void)builder->SetUsage(ToString(value)); @@ -1768,7 +1768,7 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptrSetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); + std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "decode") { (void)builder->SetDecode(ToBool(value)); @@ -1806,7 +1806,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr files_list; std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; + std::shared_ptr sampler = nullptr; int num_workers = 0; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { @@ -1840,7 +1840,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr>(); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - sampler = create().cast>(); + sampler = create().cast>(); } } } @@ -1855,7 +1855,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr(num_samples, start_index); + 
sampler = std::make_shared(num_samples, start_index); (void)builder->SetSampler(std::move(sampler)); } @@ -1991,7 +1991,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr std::shared_ptr *bottom) { std::vector files_list; std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; + std::shared_ptr sampler = nullptr; int num_workers = 0; std::shared_ptr builder = std::make_shared(); @@ -2036,7 +2036,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr cache_client = value.cast>(); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - sampler = create().cast>(); + sampler = create().cast>(); } } } @@ -2051,7 +2051,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr } else if (cache_client) { int64_t num_samples = 0; int64_t start_index = 0; - sampler = std::make_shared(num_samples, start_index); + sampler = std::make_shared(num_samples, start_index); (void)builder->SetSampler(std::move(sampler)); } @@ -2116,7 +2116,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr * std::shared_ptr *bottom) { std::vector files_list; std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; + std::shared_ptr sampler = nullptr; int num_workers = 0; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { @@ -2173,7 +2173,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr * cache_client = value.cast>(); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); - sampler = create().cast>(); + sampler = create().cast>(); } } } @@ -2188,7 +2188,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr * } else if (cache_client) { int64_t num_samples = 0; int64_t start_index = 0; - sampler = std::make_shared(num_samples, start_index); + sampler = std::make_shared(num_samples, start_index); (void)builder->SetSampler(std::move(sampler)); } diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc index 9307f4dc25..469fe84c44 100644 --- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -35,7 +35,6 @@ namespace mindspore { namespace dataset { -namespace api { #define RETURN_NULL_IF_ERROR(_s) \ do { \ @@ -151,10 +150,10 @@ bool DistributedSamplerObj::ValidateParams() { return true; } -std::shared_ptr DistributedSamplerObj::Build() { +std::shared_ptr DistributedSamplerObj::Build() { // runtime sampler object - auto sampler = std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_, - offset_, even_dist_); + auto sampler = std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_, + offset_, even_dist_); return sampler; } @@ -184,9 +183,9 @@ bool PKSamplerObj::ValidateParams() { return true; } -std::shared_ptr PKSamplerObj::Build() { +std::shared_ptr PKSamplerObj::Build() { // runtime sampler object - auto sampler = std::make_shared(num_samples_, num_val_, shuffle_); + auto sampler = std::make_shared(num_samples_, num_val_, shuffle_); return sampler; } @@ -218,10 +217,10 @@ bool RandomSamplerObj::ValidateParams() { return true; } -std::shared_ptr RandomSamplerObj::Build() { +std::shared_ptr RandomSamplerObj::Build() { // runtime sampler object bool reshuffle_each_epoch = true; - auto sampler = std::make_shared(num_samples_, replacement_, reshuffle_each_epoch); + auto sampler = std::make_shared(num_samples_, replacement_, 
reshuffle_each_epoch); return sampler; } @@ -255,9 +254,9 @@ bool SequentialSamplerObj::ValidateParams() { return true; } -std::shared_ptr SequentialSamplerObj::Build() { +std::shared_ptr SequentialSamplerObj::Build() { // runtime sampler object - auto sampler = std::make_shared(num_samples_, start_index_); + auto sampler = std::make_shared(num_samples_, start_index_); return sampler; } @@ -284,9 +283,9 @@ bool SubsetRandomSamplerObj::ValidateParams() { return true; } -std::shared_ptr SubsetRandomSamplerObj::Build() { +std::shared_ptr SubsetRandomSamplerObj::Build() { // runtime sampler object - auto sampler = std::make_shared(num_samples_, indices_); + auto sampler = std::make_shared(num_samples_, indices_); return sampler; } @@ -330,11 +329,10 @@ bool WeightedRandomSamplerObj::ValidateParams() { return true; } -std::shared_ptr WeightedRandomSamplerObj::Build() { - auto sampler = std::make_shared(num_samples_, weights_, replacement_); +std::shared_ptr WeightedRandomSamplerObj::Build() { + auto sampler = std::make_shared(num_samples_, weights_, replacement_); return sampler; } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/text.cc b/mindspore/ccsrc/minddata/dataset/api/text.cc index 2adf7a188e..9fed10e30b 100644 --- a/mindspore/ccsrc/minddata/dataset/api/text.cc +++ b/mindspore/ccsrc/minddata/dataset/api/text.cc @@ -22,7 +22,6 @@ namespace mindspore { namespace dataset { -namespace api { // Transform operations for text. namespace text { @@ -130,6 +129,5 @@ std::shared_ptr SentencePieceTokenizerOperation::Build() { } } // namespace text -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index 3274f35725..46822e1323 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -22,7 +22,6 @@ namespace mindspore { namespace dataset { -namespace api { TensorOperation::TensorOperation() {} @@ -94,6 +93,5 @@ Status TypeCastOperation::ValidateParams() { std::shared_ptr TypeCastOperation::Build() { return std::make_shared(data_type_); } } // namespace transforms -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index 02b8ebf19b..0d16605e1f 100644 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -65,7 +65,6 @@ namespace mindspore { namespace dataset { -namespace api { // Transform operations for computer vision. 
namespace vision { @@ -1702,6 +1701,5 @@ std::shared_ptr UniformAugOperation::Build() { #endif } // namespace vision -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 5b62deefd2..f8c356259d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -34,11 +34,11 @@ namespace mindspore::dataset { // TreeConsumer TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique(); } -Status TreeConsumer::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } +Status TreeConsumer::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); } // IteratorConsumer -Status IteratorConsumer::Init(std::shared_ptr d) { +Status IteratorConsumer::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); } @@ -74,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map } // ToDevice -Status ToDevice::Init(std::shared_ptr d) { +Status ToDevice::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); } @@ -385,8 +385,8 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal tree_adapter_ = std::make_unique(); } -Status TreeGetters::Init(std::shared_ptr d) { - Status s = tree_adapter_->BuildAndPrepare(std::move(d)); +Status TreeGetters::Init(std::shared_ptr d) { + Status s = tree_adapter_->BuildAndPrepare(std::move(d), 1); if (!s.IsError()) { init_flag_ = true; } @@ -464,7 +464,7 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) { RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); return Status::OK(); } -Status BuildVocabConsumer::Init(std::shared_ptr d) { +Status BuildVocabConsumer::Init(std::shared_ptr d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); } Status BuildVocabConsumer::Start() { diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index 7e947f3909..07549cab82 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -29,10 +29,7 @@ namespace mindspore::dataset { // Forward declare class TreeAdapter; - -namespace api { class DatasetNode; -} /// A base class for tree consumers which would fetch rows from the tree pipeline class TreeConsumer { @@ -42,7 +39,7 @@ class TreeConsumer { /// Initializes the consumer, this involves constructing and preparing the tree. /// \param d The dataset node that represent the root of the IR tree. /// \return Status error code. - virtual Status Init(std::shared_ptr d); + virtual Status Init(std::shared_ptr d); Status Terminate(); @@ -61,7 +58,7 @@ class IteratorConsumer : public TreeConsumer { /// \param num_epochs number of epochs. Default to -1 (infinite epochs). 
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} - Status Init(std::shared_ptr d) override; + Status Init(std::shared_ptr d) override; /// Returns the next row in a vector format /// \param[out] out std::vector of Tensors @@ -133,7 +130,7 @@ class ToDevice : public TreeConsumer { explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1) : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} - Status Init(std::shared_ptr d) override; + Status Init(std::shared_ptr d) override; /// Send the data to device /// \return Status error code @@ -162,7 +159,7 @@ class ToDevice : public TreeConsumer { class TreeGetters : public TreeConsumer { public: TreeGetters(); - Status Init(std::shared_ptr d) override; + Status Init(std::shared_ptr d) override; Status GetDatasetSize(int64_t *size); Status GetOutputTypes(std::vector *types); Status GetOutputShapes(std::vector *shapes); @@ -185,10 +182,9 @@ class BuildVocabConsumer : public TreeConsumer { /// BuildVocabConsumer Constructor which will call the base class default constructor. BuildVocabConsumer() = default; - Status Init(std::shared_ptr d) override; + Status Init(std::shared_ptr d) override; - /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows - /// would be written to disk) + /// Start consuming /// \return Status error code Status Start(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc index 7ffcd3569e..9a31d35b36 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc @@ -46,7 +46,7 @@ Status CacheBase::Reset() { return Status::OK(); } CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler) + std::shared_ptr cache_client, std::shared_ptr sampler) : ParallelOp(num_workers, op_connector_size, std::move(sampler)), row_cnt_(0), num_cache_miss_(0), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h index af47748b1c..cfdbde1440 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h @@ -46,7 +46,7 @@ class CacheBase : public ParallelOp { /// \param cache_client CacheClient for communication to the CacheServer /// \param sampler Sampler which is mandatory CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler); + std::shared_ptr cache_client, std::shared_ptr sampler); /// \brief Destructor ~CacheBase(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc index 1cdaec76f8..967287d312 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc @@ -87,7 +87,7 @@ Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { leaf_op_wp_.Set(); return Status::OK(); } -Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } +Status CacheLookupOp::InitSampler() { return SamplerRT::InitSampler(); } void CacheLookupOp::Print(std::ostream 
&out, bool show_all) const { CacheBase::Print(out, show_all); } Status CacheLookupOp::GetNextSample(std::unique_ptr *out_buffer) { std::vector cache_miss; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h index adec3d8283..fdf9b530ef 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h @@ -28,7 +28,7 @@ namespace dataset { /// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset. /// \note For non-mappable dataset, please see CacheOp /// \see CacheOp -class CacheLookupOp : public CacheBase, public Sampler { +class CacheLookupOp : public CacheBase, public SamplerRT { public: class Builder { public: @@ -62,7 +62,7 @@ class CacheLookupOp : public CacheBase, public Sampler { /// \brief Setter method. /// \return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { build_sampler_ = std::move(sampler); return *this; } @@ -77,7 +77,7 @@ class CacheLookupOp : public CacheBase, public Sampler { int32_t rows_per_buffer_; int32_t build_op_connector_size_; std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; + std::shared_ptr build_sampler_; // Check if the required parameters are set by the builder. // \return Status The error code return @@ -87,8 +87,8 @@ class CacheLookupOp : public CacheBase, public Sampler { /// \note It takes the same argument as the base class. /// \see CacheBase CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler) - : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {} + std::shared_ptr cache_client, std::shared_ptr sampler) + : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), SamplerRT(*(sampler.get())) {} ~CacheLookupOp() = default; // As a parallel op, we override these two functions Status operator()() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index ca18109381..d9908b2003 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -46,7 +46,7 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const { } CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, - std::shared_ptr cache_client, const std::shared_ptr &sampler) + std::shared_ptr cache_client, const std::shared_ptr &sampler) : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(std::move(cache_client)), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h index db702c03db..d56fbdceb5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h @@ -110,7 +110,7 @@ class CacheMergeOp : public ParallelOp { /// \brief Setter method /// \param sampler /// \return Builder setter method returns reference to the builder. 
- Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { build_sampler_ = std::move(sampler); return *this; } @@ -133,7 +133,7 @@ class CacheMergeOp : public ParallelOp { int32_t build_op_connector_size_; int32_t build_num_cleaners_; std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; + std::shared_ptr build_sampler_; /// Check if the required parameters are set by the builder. /// \return Status The error code return @@ -147,7 +147,7 @@ class CacheMergeOp : public ParallelOp { /// \param cache_client CacheClient to commmunicate with the Cache server /// \param sampler as a derived class of ParallelOp CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, - std::shared_ptr cache_client, const std::shared_ptr &sampler); + std::shared_ptr cache_client, const std::shared_ptr &sampler); ~CacheMergeOp(); void Print(std::ostream &out, bool show_all) const override; std::string Name() const override { return kCacheMergeOp; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index f579bf165c..fea1913988 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -68,7 +68,7 @@ Status CacheOp::Builder::Build(std::shared_ptr *ptr) { // Constructor of CacheOp CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler) + std::shared_ptr cache_client, std::shared_ptr sampler) : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)), num_guys_in_(0), phase_(Phase::kBuildPhase) {} diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h index f6af02fdba..b3fffe6be5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h @@ -81,7 +81,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { /// \brief Setter method /// \param sampler /// \return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { build_sampler_ = std::move(sampler); return *this; } @@ -96,7 +96,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { int32_t rows_per_buffer_; int32_t build_op_connector_size_; std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; + std::shared_ptr build_sampler_; /// \brief Check if the required parameters are set by the builder. /// \return Status The error code return @@ -108,7 +108,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { /// \param num_workers The number of worker threads. /// \param op_connector_size The size of each queue in the connector. 
CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler); + std::shared_ptr cache_client, std::shared_ptr sampler); // Destructor ~CacheOp(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc index 1e934e7259..a31bcdc2a5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc @@ -36,7 +36,7 @@ ConcatOp::Builder::Builder() { // The builder "build" method creates the final object. Status ConcatOp::Builder::Build(std::shared_ptr *ptr) { if (builder_sampler_ == nullptr) { - builder_sampler_ = std::make_shared(0, 1, 0, false); + builder_sampler_ = std::make_shared(0, 1, 0, false); } *ptr = std::make_shared(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_, children_start_end_index_); @@ -44,7 +44,7 @@ Status ConcatOp::Builder::Build(std::shared_ptr *ptr) { } // Constructor of the ConcatOp. -ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr sampler, +ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr sampler, std::vector> children_flag_and_nums, std::vector> children_start_end_index) : PipelineOp(op_connector_size), @@ -80,7 +80,7 @@ Status ConcatOp::operator()() { bool is_not_mappable = true; int num_shard = 1; int shard_index = 0; - std::shared_ptr distribute_sampler = std::dynamic_pointer_cast(sampler_); + std::shared_ptr distribute_sampler = std::dynamic_pointer_cast(sampler_); if (distribute_sampler != nullptr) { num_shard = distribute_sampler->GetDeviceNum(); shard_index = distribute_sampler->GetDeviceID(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h index 7639aa18d5..9d23f4b909 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h @@ -44,7 +44,7 @@ class ConcatOp : public PipelineOp { // The builder "build" method creates the final object. 
// @return shared_ptr to the new ConcatOp object Status Build(std::shared_ptr *); - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -61,7 +61,7 @@ class ConcatOp : public PipelineOp { private: int32_t builder_op_connector_size_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; std::vector> children_flag_and_nums_; std::vector> children_start_end_index_; }; @@ -70,7 +70,7 @@ class ConcatOp : public PipelineOp { // @note The builder class should be used to call it // @param op_connector_size - connector size explicit ConcatOp(int32_t op_connector_size); - explicit ConcatOp(int32_t op_connector_size, std::shared_ptr sampler, + explicit ConcatOp(int32_t op_connector_size, std::shared_ptr sampler, std::vector> children_flag_and_nums, std::vector> children_start_end_index); @@ -123,7 +123,7 @@ class ConcatOp : public PipelineOp { std::unordered_map column_name_id_; // Mapping between col index and col name std::vector data_type_; std::vector data_rank_; - std::shared_ptr sampler_; + std::shared_ptr sampler_; std::vector> children_flag_and_nums_; std::vector> children_start_end_index_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index e4f759c240..cda64420eb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -40,7 +40,7 @@ namespace mindspore { namespace dataset { // Constructor -DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler) +DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler) : oc_queue_size_(op_connector_size), sampler_(sampler), operator_id_(kInvalidOperatorId), @@ -409,7 +409,7 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) { } // Getter for the sampler, and it also removes the sampler from the op -Status DatasetOp::FetchRemoveSampler(std::shared_ptr *sampler) { +Status DatasetOp::FetchRemoveSampler(std::shared_ptr *sampler) { *sampler = sampler_; // It's okay if it sampler_ points to nullptr sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index 1eebb8f08f..6c9fd15dbb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -62,7 +62,7 @@ class DataBuffer; class NodePass; -class Sampler; +class SamplerRT; /// \brief The base class DatasetOp is the main tree node. It is an abstract class, so /// the actual implementation of the operators will be derived from here. @@ -80,7 +80,7 @@ class DatasetOp : public std::enable_shared_from_this { /// Constructor /// \param op_connector_size - The size for the output connector of this operator. 
/// \param sampler - The sampler for the op - explicit DatasetOp(int32_t op_connector_size, std::shared_ptr sampler); + explicit DatasetOp(int32_t op_connector_size, std::shared_ptr sampler); /// Destructor virtual ~DatasetOp() { tree_ = nullptr; } @@ -347,12 +347,12 @@ class DatasetOp : public std::enable_shared_from_this { /// Getter for the sampler /// \return Shared pointer to the sampler (may return nullptr) - std::shared_ptr sampler() { return sampler_; } + std::shared_ptr sampler() { return sampler_; } /// \brief Getter for the sampler, and it also removes the sampler from the op /// \param[out] sampler A pointer to the output sampler that was removed /// \return Status error code - Status FetchRemoveSampler(std::shared_ptr *sampler); + Status FetchRemoveSampler(std::shared_ptr *sampler); #ifndef ENABLE_ANDROID // Computes a CRC value for the operator @@ -368,7 +368,7 @@ class DatasetOp : public std::enable_shared_from_this { } /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. - void SetSampler(std::shared_ptr sampler) { sampler_ = sampler; } + void SetSampler(std::shared_ptr sampler) { sampler_ = sampler; } /// \brief Checks if this is a leaf node (0 children) /// \return boolean returns true if it's a leaf @@ -409,7 +409,7 @@ class DatasetOp : public std::enable_shared_from_this { std::vector> child_; // Child nodes std::vector parent_; // Parent nodes. No ownership - std::shared_ptr sampler_; // Some leaf ops might have a sampler + std::shared_ptr sampler_; // Some leaf ops might have a sampler int32_t oc_queue_size_; // Capacity for each out_connector_ int32_t operator_id_; // Generated id for the node ExecutionTree *tree_; // Back pointer to our tree. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc index 52d3c4645c..63921aec69 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { // Constructor -ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler) +ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler) : DatasetOp(op_connector_size, sampler), num_workers_(num_workers), num_producers_(num_workers), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h index e09ef52e2c..8487eacdfd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h @@ -41,7 +41,7 @@ class ParallelOp : public DatasetOp { // @param num_workers // @param op_connector_size - size of the output connector for this operator // @param sampler - The sampler for the op - ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler = nullptr); + ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler = nullptr); // Destructor ~ParallelOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc index 6e4f533eb7..18d15ec7cc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc @@ -20,7 +20,7 @@ namespace mindspore { 
namespace dataset { // Constructor -PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr sampler) +PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr sampler) : DatasetOp(op_connector_size, sampler) {} // A print method typically used for debugging diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h index 88faad8265..3c137c49d5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h @@ -34,7 +34,7 @@ class PipelineOp : public DatasetOp { // @param op_connector_size - size of the output connector // @return Builder setter method returns reference to the builder. // @param sampler - The sampler for the op - explicit PipelineOp(int32_t op_connector_size, std::shared_ptr sampler = nullptr); + explicit PipelineOp(int32_t op_connector_size, std::shared_ptr sampler = nullptr); // Destructor ~PipelineOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc index cf424f4cb0..a2cc45a192 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc @@ -42,7 +42,7 @@ Status AlbumOp::Builder::Build(std::shared_ptr *ptr) { if (builder_sampler_ == nullptr) { const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); + builder_sampler_ = std::make_shared(start_index, num_samples); } builder_schema_ = std::make_unique(); @@ -73,7 +73,7 @@ Status AlbumOp::Builder::SanityCheck() { AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, const std::set &exts, std::unique_ptr data_schema, - std::shared_ptr sampler) + std::shared_ptr sampler) : ParallelOp(num_wkrs, queue_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), folder_path_(file_dir), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h index 2d87160fb7..f3589099ed 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h @@ -100,7 +100,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { /// \brief Setter method /// \param[in] sampler /// \return Builder setter method returns reference to the builder - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -147,7 +147,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { int32_t builder_rows_per_buffer_; int32_t builder_op_connector_size_; std::set builder_extensions_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; std::unique_ptr builder_schema_; }; @@ -161,7 +161,8 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { /// \param[in] data_schema - schema of dataset /// \param[in] sampler - sampler tells AlbumOp what to read AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, - const std::set &exts, std::unique_ptr data_schema, std::shared_ptr sampler); + const std::set &exts, 
std::unique_ptr data_schema, + std::shared_ptr sampler); /// \brief Destructor. ~AlbumOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc index 295400b865..44a9806e29 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc @@ -46,7 +46,7 @@ Status CelebAOp::Builder::Build(std::shared_ptr *op) { if (builder_sampler_ == nullptr) { const int64_t num_samples = 0; const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); + builder_sampler_ = std::make_shared(start_index, num_samples); } builder_schema_ = std::make_unique(); @@ -79,7 +79,7 @@ Status CelebAOp::Builder::SanityCheck() { CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, const std::string &usage, const std::set &exts, - std::unique_ptr schema, std::shared_ptr sampler) + std::unique_ptr schema, std::shared_ptr sampler) : ParallelOp(num_workers, queue_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), folder_path_(dir), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h index fdd7781132..2a9de5e493 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h @@ -95,7 +95,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -131,7 +131,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { int32_t builder_rows_per_buffer_; int32_t builder_op_connector_size_; std::set builder_extensions_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; std::unique_ptr builder_schema_; std::string builder_usage_; }; @@ -144,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { // @param std::unique_ptr sampler - sampler tells CelebAOp what to read CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, const std::string &usage, const std::set &exts, std::unique_ptr schema, - std::shared_ptr sampler); + std::shared_ptr sampler); ~CelebAOp() override = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc index 93324dccfa..ade39e0e2e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -50,7 +50,7 @@ Status CifarOp::Builder::Build(std::shared_ptr *ptr) { if (sampler_ == nullptr) { const int64_t num_samples = 0; const int64_t start_index = 0; - sampler_ = std::make_shared(start_index, num_samples); + sampler_ = std::make_shared(start_index, num_samples); } schema_ = std::make_unique(); TensorShape scalar = TensorShape::CreateScalar(); @@ -88,7 +88,7 @@ Status CifarOp::Builder::SanityCheck() { CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size, 
std::unique_ptr data_schema, - std::shared_ptr sampler) + std::shared_ptr sampler) : ParallelOp(num_works, queue_size, std::move(sampler)), cifar_type_(type), usage_(usage), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h index a2e15dbfbd..60ee4848f0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h @@ -75,7 +75,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { sampler_ = std::move(sampler); return *this; } @@ -123,7 +123,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { int32_t num_workers_; int32_t rows_per_buffer_; int32_t op_connect_size_; - std::shared_ptr sampler_; + std::shared_ptr sampler_; std::unique_ptr schema_; CifarType cifar_type_; }; @@ -138,7 +138,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { // @param std::unique_ptr sampler - sampler tells ImageFolderOp what to read CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size, std::unique_ptr data_schema, - std::shared_ptr sampler); + std::shared_ptr sampler); // Destructor. ~CifarOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index 9d3ad80c7a..f3e9bf40d4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -94,7 +94,7 @@ std::vector ClueOp::Builder::split(const std::string &s, char delim ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr sampler) + bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr sampler) : ParallelOp(num_workers, op_connector_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), num_rows_per_shard_(0), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h index 9c2ae5ac03..09b7c9e6a9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -125,7 +125,7 @@ class ClueOp : public ParallelOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. 
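
For reference, a minimal sketch of the fallback these Builder::Build() hunks apply when the caller supplies no sampler; the instantiated type is assumed to be SequentialSamplerRT under this patch's rename, and num_samples == 0 keeps its documented meaning of sampling the entire dataset:

    // Sketch of the default-sampler fallback inside a leaf op's Builder::Build().
    if (builder_sampler_ == nullptr) {
      const int64_t num_samples = 0;  // 0 => take every row in the dataset
      const int64_t start_index = 0;  // begin at the first row
      // Both values are 0 here, so the (start_index, num_samples) ordering used by the
      // builders and the (num_samples, start_index) ordering of the constructor coincide.
      builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
    }
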
- Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -141,13 +141,13 @@ class ClueOp : public ParallelOp { std::vector builder_clue_files_list_; bool builder_shuffle_files_; std::map builder_cols_to_keyword_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; }; // Constructor of ClueOp ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); + bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); // Default destructor ~ClueOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index 69ac8a28d3..915873cff0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -60,7 +60,7 @@ Status CocoOp::Builder::Build(std::shared_ptr *ptr) { if (builder_sampler_ == nullptr) { const int64_t num_samples = 0; const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); + builder_sampler_ = std::make_shared(start_index, num_samples); } builder_schema_ = std::make_unique(); RETURN_IF_NOT_OK(builder_schema_->AddColumn( @@ -123,7 +123,7 @@ Status CocoOp::Builder::SanityCheck() { CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, - std::unique_ptr data_schema, std::shared_ptr sampler) + std::unique_ptr data_schema, std::shared_ptr sampler) : ParallelOp(num_workers, queue_size, std::move(sampler)), decode_(decode), row_cnt_(0), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h index 096caa9e63..5b7600764e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -119,7 +119,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { // Setter method. // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. 
- Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -149,7 +149,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { int32_t builder_num_workers_; int32_t builder_op_connector_size_; int32_t builder_rows_per_buffer_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; std::unique_ptr builder_schema_; }; @@ -166,7 +166,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { // @param std::shared_ptr sampler - sampler tells CocoOp what to read CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, - std::unique_ptr data_schema, std::shared_ptr sampler); + std::unique_ptr data_schema, std::shared_ptr sampler); // Destructor ~CocoOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 2e7b59b7ff..333a374c3b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -77,7 +77,7 @@ CsvOp::CsvOp(const std::vector &csv_files_list, char field_delim, const std::vector> &column_default, const std::vector &column_name, int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, - int32_t num_device, int32_t device_id, std::shared_ptr sampler) + int32_t num_device, int32_t device_id, std::shared_ptr sampler) : ParallelOp(num_workers, op_connector_size, std::move(sampler)), csv_files_list_(std::move(csv_files_list)), field_delim_(field_delim), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h index 35d95daf3b..154d027744 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -243,7 +243,7 @@ class CsvOp : public ParallelOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. 
- Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -261,7 +261,7 @@ class CsvOp : public ParallelOp { char builder_field_delim_; std::vector> builder_column_default_list_; std::vector builder_column_name_list_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; }; // Constructor of CsvOp @@ -271,7 +271,7 @@ class CsvOp : public ParallelOp { const std::vector> &column_default, const std::vector &column_name, int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, - std::shared_ptr sampler); + std::shared_ptr sampler); // Default destructor ~CsvOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc index 7a9d810eb0..478883ffd5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -38,7 +38,7 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr *ptr) { if (builder_sampler_ == nullptr) { const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); + builder_sampler_ = std::make_shared(start_index, num_samples); } builder_schema_ = std::make_unique(); TensorShape scalar = TensorShape::CreateScalar(); @@ -68,7 +68,7 @@ Status ImageFolderOp::Builder::SanityCheck() { ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode, const std::set &exts, const std::map &map, std::unique_ptr data_schema, - std::shared_ptr sampler) + std::shared_ptr sampler) : ParallelOp(num_wkrs, queue_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), folder_path_(file_dir), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h index f2e2e21f8e..979642ecce 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h @@ -113,7 +113,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. 
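
A hedged usage sketch for the updated setter; only the sampler-related call is shown, and the remaining ImageFolderOp::Builder setters (dataset directory, decode flag, and so on) are assumed to be configured elsewhere:

    // Hand an explicit SamplerRT to a leaf-op builder (sketch only).
    std::shared_ptr<SamplerRT> sampler =
        std::make_shared<SequentialSamplerRT>(/*num_samples=*/0, /*start_index=*/0);
    ImageFolderOp::Builder builder;
    builder.SetSampler(std::move(sampler));  // setter now accepts std::shared_ptr<SamplerRT>
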
- Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -151,7 +151,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { int32_t builder_rows_per_buffer_; int32_t builder_op_connector_size_; std::set builder_extensions_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; std::unique_ptr builder_schema_; std::map builder_labels_to_read_; }; @@ -165,7 +165,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { // @param td::unique_ptr sampler - sampler tells ImageFolderOp what to read ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, bool do_decode, const std::set &exts, const std::map &map, - std::unique_ptr, std::shared_ptr sampler); + std::unique_ptr, std::shared_ptr sampler); // Destructor. ~ImageFolderOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index 48344f410e..fe1919fc8a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -43,7 +43,7 @@ Status ManifestOp::Builder::Build(std::shared_ptr *ptr) { if (builder_sampler_ == nullptr) { const int64_t num_samples = 0; const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); + builder_sampler_ = std::make_shared(start_index, num_samples); } builder_schema_ = std::make_unique(); RETURN_IF_NOT_OK( @@ -67,7 +67,7 @@ Status ManifestOp::Builder::SanityCheck() { ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, const std::map &class_index, std::unique_ptr data_schema, - std::shared_ptr sampler, std::string usage) + std::shared_ptr sampler, std::string usage) : ParallelOp(num_works, queue_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), io_block_pushed_(0), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h index fa98b9c2a7..fe35caeff8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h @@ -88,7 +88,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -119,7 +119,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { Status Build(std::shared_ptr *op); private: - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; bool builder_decode_; std::string builder_file_; @@ -139,7 +139,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { // @param td::unique_ptr sampler - sampler tells ImageFolderOp what to read ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, const std::map &class_index, std::unique_ptr data_schema, - std::shared_ptr sampler, std::string usage); + std::shared_ptr sampler, std::string usage); // Destructor. 
~ManifestOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc index 7d5923c475..72ce81cdb5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -45,7 +45,7 @@ Status MnistOp::Builder::Build(std::shared_ptr *ptr) { if (builder_sampler_ == nullptr) { const int64_t num_samples = 0; const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); + builder_sampler_ = std::make_shared(start_index, num_samples); } builder_schema_ = std::make_unique(); RETURN_IF_NOT_OK( @@ -75,7 +75,7 @@ Status MnistOp::Builder::SanityCheck() { } MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, - int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) + int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) : ParallelOp(num_workers, queue_size, std::move(sampler)), usage_(usage), buf_cnt_(0), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h index c845ad1217..2accd8eb8c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h @@ -78,7 +78,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -113,7 +113,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { int32_t builder_num_workers_; int32_t builder_rows_per_buffer_; int32_t builder_op_connector_size_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; std::unique_ptr builder_schema_; }; @@ -126,7 +126,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { // @param std::unique_ptr data_schema - the schema of the mnist dataset // @param td::unique_ptr sampler - sampler tells MnistOp what to read MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, - int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler); + int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler); // Destructor. 
~MnistOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index 2f5d3b0e39..75c43d8c61 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -65,7 +65,7 @@ Status RandomDataOp::Builder::SanityCheck() const { // Constructor for RandomDataOp RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema, std::shared_ptr sampler) + std::unique_ptr data_schema, std::shared_ptr sampler) : ParallelOp(num_workers, op_connector_size, std::move(sampler)), buffer_id_(0), rows_per_buffer_(rows_per_buffer), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h index e558aded55..b2ed68f5ad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h @@ -120,7 +120,7 @@ class RandomDataOp : public ParallelOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -133,7 +133,7 @@ class RandomDataOp : public ParallelOp { Status SanityCheck() const; std::unique_ptr builder_data_schema_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; int32_t builder_num_workers_; int32_t builder_op_connector_size_; int64_t builder_rows_per_buffer_; @@ -152,7 +152,7 @@ class RandomDataOp : public ParallelOp { * @return Builder - The modified builder by reference */ RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema, std::shared_ptr sampler); + std::unique_ptr data_schema, std::shared_ptr sampler); /** * Destructor diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 979da9e8b6..1f4671a092 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -23,9 +23,9 @@ namespace mindspore { namespace dataset { -DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, - uint32_t seed, int64_t offset, bool even_dist) - : Sampler(num_samples, std::numeric_limits::max()), +DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, + uint32_t seed, int64_t offset, bool even_dist) + : SamplerRT(num_samples, std::numeric_limits::max()), cnt_(0), seed_(seed == std::numeric_limits::max() ? GetSeed() : seed), device_id_(dev_id), @@ -35,7 +35,7 @@ DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int offset_(offset), non_empty_(true) {} -Status DistributedSampler::InitSampler() { +Status DistributedSamplerRT::InitSampler() { // Special value of 0 for num_samples means that the user wants to sample the entire set of data. 
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. if (num_samples_ == 0 || num_samples_ > num_rows_) { @@ -74,7 +74,7 @@ Status DistributedSampler::InitSampler() { return Status::OK(); } -Status DistributedSampler::GetNextSample(std::unique_ptr *out_buffer) { +Status DistributedSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { if (cnt_ > samples_per_buffer_) { RETURN_STATUS_UNEXPECTED( "Number of samples(cnt) that have already been filled in to buffer should be less than or " @@ -143,7 +143,7 @@ Status DistributedSampler::GetNextSample(std::unique_ptr *out_buffer return Status::OK(); } -Status DistributedSampler::ResetSampler() { +Status DistributedSamplerRT::ResetSampler() { CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); cnt_ = 0; @@ -160,10 +160,10 @@ Status DistributedSampler::ResetSampler() { return Status::OK(); } -void DistributedSampler::Print(std::ostream &out, bool show_all) const { +void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const { out << "\nSampler: DistributedSampler"; if (show_all) { - Sampler::Print(out, show_all); + SamplerRT::Print(out, show_all); out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ << "\nshuffle: " << shuffle_; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h index d9425b052e..015ad23fd3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -class DistributedSampler : public Sampler { +class DistributedSamplerRT : public SamplerRT { public: /// \brief Constructor /// \param[in] num_samples The total number of rows in the dataset @@ -40,11 +40,12 @@ class DistributedSampler : public Sampler { /// This option is not exposed in the python API. Current behavior is that the remainder will always /// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set, /// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario. 
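
A hedged construction example for the renamed class: one shard of a four-way sharded run on device 2, shuffled, sampling every row assigned to the shard; seed, offset, and even_dist are left at the defaults given in the header hunk just below:

    // Shard 2 of 4; num_samples == 0 follows the convention of taking all rows.
    auto shard_sampler = std::make_shared<DistributedSamplerRT>(
        /*num_samples=*/0, /*num_dev=*/4, /*dev_id=*/2, /*shuffle=*/true);
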
- DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, - uint32_t seed = std::numeric_limits::max(), int64_t offset = -1, bool even_dist = true); + DistributedSamplerRT(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, + uint32_t seed = std::numeric_limits::max(), int64_t offset = -1, + bool even_dist = true); /// \brief default destructor - ~DistributedSampler() = default; + ~DistributedSamplerRT() = default; /// \param std::unique_ptr * pBuffer /// \param int32_t workerId diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc index 9ec0ef2aad..b9e6bdf041 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -20,14 +20,14 @@ namespace mindspore { namespace dataset { -PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), +PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) + : SamplerRT(num_samples, samples_per_buffer), shuffle_(shuffle), seed_(GetSeed()), next_id_(0), samples_per_class_(val) {} -Status PKSampler::InitSampler() { +Status PKSamplerRT::InitSampler() { labels_.reserve(label_to_ids_.size()); for (const auto &pair : label_to_ids_) { if (pair.second.empty() == false) { @@ -61,7 +61,7 @@ Status PKSampler::InitSampler() { return Status::OK(); } -Status PKSampler::GetNextSample(std::unique_ptr *out_buffer) { +Status PKSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { if (next_id_ > num_samples_ || num_samples_ == 0) { RETURN_STATUS_UNEXPECTED("Index must be less than or equal to num_samples, but got: " + std::to_string(next_id_)); } else if (next_id_ == num_samples_) { @@ -96,7 +96,7 @@ Status PKSampler::GetNextSample(std::unique_ptr *out_buffer) { return Status::OK(); } -Status PKSampler::ResetSampler() { +Status PKSamplerRT::ResetSampler() { CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); next_id_ = 0; rnd_.seed(seed_++); @@ -108,18 +108,18 @@ Status PKSampler::ResetSampler() { return Status::OK(); } -Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { +Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { RETURN_UNEXPECTED_IF_NULL(op); RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); RETURN_IF_NOT_OK(InitSampler()); return Status::OK(); } -void PKSampler::Print(std::ostream &out, bool show_all) const { +void PKSamplerRT::Print(std::ostream &out, bool show_all) const { out << "\nSampler: PKSampler"; if (show_all) { // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); + SamplerRT::Print(out, show_all); // Then add our own info if any } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h index e51c419cd4..fc05f261a8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -26,17 +26,17 @@ namespace mindspore { namespace dataset { -class PKSampler : public Sampler { // NOT YET FINISHED +class PKSamplerRT : public SamplerRT { // NOT YET FINISHED public: // @param num_samples - the 
number of samples to draw. value of 0 means to take the full amount // @param int64_t val // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle, - int64_t samples_per_buffer = std::numeric_limits::max()); + explicit PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, + int64_t samples_per_buffer = std::numeric_limits::max()); // default destructor - ~PKSampler() = default; + ~PKSamplerRT() = default; // @param std::unique_ptr *out_buffer) { +Status PythonSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { if (need_to_reset_) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); } else { @@ -64,7 +64,7 @@ Status PythonSampler::GetNextSample(std::unique_ptr *out_buffer) { return Status::OK(); } -Status PythonSampler::InitSampler() { +Status PythonSamplerRT::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED( num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); // Special value of 0 for num_samples means that the user wants to sample the entire set of data. @@ -86,7 +86,7 @@ Status PythonSampler::InitSampler() { return Status::OK(); } -Status PythonSampler::ResetSampler() { +Status PythonSamplerRT::ResetSampler() { CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); need_to_reset_ = false; py::gil_scoped_acquire gil_acquire; @@ -106,11 +106,11 @@ Status PythonSampler::ResetSampler() { return Status::OK(); } -void PythonSampler::Print(std::ostream &out, bool show_all) const { +void PythonSamplerRT::Print(std::ostream &out, bool show_all) const { out << "\nSampler: PythonSampler"; if (show_all) { // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); + SamplerRT::Print(out, show_all); // Then add our own info if any } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h index 0700edee27..1bf6c7979a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -23,18 +23,18 @@ namespace mindspore { namespace dataset { -class PythonSampler : public Sampler { +class PythonSamplerRT : public SamplerRT { public: // Constructor // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the // data from the dataset. // @param py_sampler_instance - the python instance of the sampler // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance, - int64_t samples_per_buffer = std::numeric_limits::max()); + explicit PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, + int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor. - ~PythonSampler() = default; + ~PythonSamplerRT() = default; // Initialize the sampler. 
// @return Status diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc index a55659eff7..71eb34a3e3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -22,16 +22,16 @@ namespace mindspore { namespace dataset { -RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), +RandomSamplerRT::RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, + int64_t samples_per_buffer) + : SamplerRT(num_samples, samples_per_buffer), seed_(GetSeed()), replacement_(replacement), next_id_(0), reshuffle_each_epoch_(reshuffle_each_epoch), dist(nullptr) {} -Status RandomSampler::GetNextSample(std::unique_ptr *out_buffer) { +Status RandomSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { if (next_id_ > num_samples_) { RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); } else if (next_id_ == num_samples_) { @@ -68,7 +68,7 @@ Status RandomSampler::GetNextSample(std::unique_ptr *out_buffer) { return Status::OK(); } -Status RandomSampler::InitSampler() { +Status RandomSamplerRT::InitSampler() { // Special value of 0 for num_samples means that the user wants to sample the entire set of data. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. if (num_samples_ == 0 || num_samples_ > num_rows_) { @@ -94,7 +94,7 @@ Status RandomSampler::InitSampler() { return Status::OK(); } -Status RandomSampler::ResetSampler() { +Status RandomSamplerRT::ResetSampler() { CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); next_id_ = 0; @@ -115,11 +115,11 @@ Status RandomSampler::ResetSampler() { return Status::OK(); } -void RandomSampler::Print(std::ostream &out, bool show_all) const { +void RandomSamplerRT::Print(std::ostream &out, bool show_all) const { out << "\nSampler: RandomSampler"; if (show_all) { // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); + SamplerRT::Print(out, show_all); // Then add our own info if any } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h index fe5330a42f..6a6e3f52a1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h @@ -24,18 +24,18 @@ namespace mindspore { namespace dataset { -class RandomSampler : public Sampler { +class RandomSamplerRT : public SamplerRT { public: // Constructor // @param int64_t num_samples - number samples to draw // @param bool replacement - put he id back / or not after a sample // @param reshuffle_each_epoch - T/F to reshuffle after epoch // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, - int64_t samples_per_buffer = std::numeric_limits::max()); + explicit RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, + int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor. 
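
The ResetSampler() checks in these hunks enforce a simple driving contract: ids are pulled until an end-of-epoch buffer arrives, and only then may the sampler be reset. A hedged sketch of that loop, written inside a function returning Status and assuming DataBuffer exposes its usual eoe() flag accessor:

    // Drain one epoch of sample ids, then reset the sampler for the next epoch.
    std::unique_ptr<DataBuffer> buffer;
    do {
      RETURN_IF_NOT_OK(sampler->GetNextSample(&buffer));
      // ... hand the ids in `buffer` to the consuming op ...
    } while (!buffer->eoe());
    RETURN_IF_NOT_OK(sampler->ResetSampler());
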
- ~RandomSampler() = default; + ~RandomSamplerRT() = default; // Op calls this to get next Buffer that contains all the sampleIds // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc index ebf00413b4..7441b62771 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -32,13 +32,13 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { return Status::OK(); } -Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) +SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} -Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { - std::shared_ptr child_sampler; +Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { + std::shared_ptr child_sampler; if (HasChildSampler()) { - child_sampler = std::dynamic_pointer_cast(child_[0]); + child_sampler = std::dynamic_pointer_cast(child_[0]); if (!child_sampler) { std::string err_msg("Cannot handshake, child is not a sampler object."); RETURN_STATUS_UNEXPECTED(err_msg); @@ -64,7 +64,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { return Status::OK(); } -Status Sampler::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t num_elements) { +Status SamplerRT::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t num_elements) { if (num_elements == 0) { RETURN_STATUS_UNEXPECTED("Invalid data, num of elements cannot be 0."); } @@ -77,7 +77,7 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t return Status::OK(); } -void Sampler::Print(std::ostream &out, bool show_all) const { +void SamplerRT::Print(std::ostream &out, bool show_all) const { // Sampler printing is usually only called in the show_all mode. // Derived classes will display the name, then call back to this base // for common info. @@ -88,7 +88,7 @@ void Sampler::Print(std::ostream &out, bool show_all) const { } #ifdef ENABLE_PYTHON -Status Sampler::GetAllIdsThenReset(py::array *data) { +Status SamplerRT::GetAllIdsThenReset(py::array *data) { std::unique_ptr db; std::shared_ptr sample_ids; TensorRow sample_row; @@ -123,27 +123,27 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { } #endif -Status Sampler::SetNumSamples(int64_t num_samples) { +Status SamplerRT::SetNumSamples(int64_t num_samples) { CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0."); num_samples_ = num_samples; return Status::OK(); } -int64_t Sampler::GetNumSamples() { return num_samples_; } +int64_t SamplerRT::GetNumSamples() { return num_samples_; } -Status Sampler::SetNumRowsInDataset(int64_t num_rows) { +Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) { CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "Invalid parameter, num_rows must be greater than 0."); num_rows_ = num_rows; return Status::OK(); } -Status Sampler::AddChild(std::shared_ptr child) { +Status SamplerRT::AddChild(std::shared_ptr child) { if (child == nullptr) { return Status::OK(); } // Only samplers can be added, not any other DatasetOp. 
- std::shared_ptr sampler = std::dynamic_pointer_cast(child); + std::shared_ptr sampler = std::dynamic_pointer_cast(child); if (!sampler) { std::string err_msg("Cannot add child, child is not a sampler object."); RETURN_STATUS_UNEXPECTED(err_msg); @@ -160,9 +160,9 @@ Status Sampler::AddChild(std::shared_ptr child) { return Status::OK(); } -bool Sampler::HasChildSampler() { return !child_.empty(); } +bool SamplerRT::HasChildSampler() { return !child_.empty(); } -Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { +Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { if (child_ids_ == nullptr) { RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h index 1aa061558c..76a8dee4a8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h @@ -51,21 +51,21 @@ class RandomAccessOp { protected: // The amount of rows in the dataset itself. This is the before-sampling value, the // total count of rows. A sampler may choose to sample less than this amount. - int64_t num_rows_; + int64_t num_rows_ = -1; }; -class Sampler { +class SamplerRT { public: // Constructor // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 // indicates that the sampler should produce the complete set of ids. // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); + explicit SamplerRT(int64_t num_samples, int64_t samples_per_buffer); - Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {} + SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {} // default destructor - ~Sampler() = default; + ~SamplerRT() = default; // Get a list of sample ids. // @note It is Sampler responsibility to make sure that the id is not out of bound. @@ -111,7 +111,7 @@ class Sampler { // Adds a sampler to become our child. // @param std::shared_ptr - The sampler to add as a child. // @return - The error code returned. 
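
Child samplers must themselves be SamplerRT instances, as the AddChild() hunks here enforce. A hypothetical chaining sketch, with constructor arguments taken from the sampler declarations elsewhere in this patch:

    // Parent draws 100 ids per epoch (no replacement, reshuffled each epoch) from the
    // ids produced by its sequential child; AddChild() rejects non-sampler children.
    auto child  = std::make_shared<SequentialSamplerRT>(/*num_samples=*/0, /*start_index=*/0);
    auto parent = std::make_shared<RandomSamplerRT>(/*num_samples=*/100, /*replacement=*/false,
                                                    /*reshuffle_each_epoch=*/true);
    Status rc = parent->AddChild(child);
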
- Status AddChild(std::shared_ptr child); + Status AddChild(std::shared_ptr child); // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler // @param std::shared_ptr* sampleIds @@ -129,7 +129,7 @@ class Sampler { // @param out - reference to the output stream being overloaded // @param sampler - reference to teh sampler to print // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { + friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) { sampler.Print(out, false); return out; } @@ -158,7 +158,7 @@ class Sampler { int64_t samples_per_buffer_; std::unique_ptr col_desc_; - std::vector> child_; // Child nodes + std::vector> child_; // Child nodes std::unique_ptr child_ids_; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index ced0ed1eea..92f8384fe0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -20,10 +20,10 @@ namespace mindspore { namespace dataset { -SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} +SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) + : SamplerRT(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} -Status SequentialSampler::GetNextSample(std::unique_ptr *out_buffer) { +Status SequentialSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { if (id_count_ > num_samples_) { RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); } else if (id_count_ == num_samples_) { @@ -62,7 +62,7 @@ Status SequentialSampler::GetNextSample(std::unique_ptr *out_buffer) return Status::OK(); } -Status SequentialSampler::InitSampler() { +Status SequentialSamplerRT::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "Invalid parameter, start_index must be greater than or equal to 0, but got " + std::to_string(start_index_) + ".\n"); @@ -85,7 +85,7 @@ Status SequentialSampler::InitSampler() { return Status::OK(); } -Status SequentialSampler::ResetSampler() { +Status SequentialSamplerRT::ResetSampler() { CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); current_id_ = start_index_; id_count_ = 0; @@ -97,11 +97,11 @@ Status SequentialSampler::ResetSampler() { return Status::OK(); } -void SequentialSampler::Print(std::ostream &out, bool show_all) const { +void SequentialSamplerRT::Print(std::ostream &out, bool show_all) const { out << "\nSampler: SequentialSampler"; if (show_all) { // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); + SamplerRT::Print(out, show_all); // Then add our own info out << "\nStart index: " << start_index_; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h index 2a313347f1..78349aba8b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ 
b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -23,18 +23,18 @@ namespace mindspore { namespace dataset { -class SequentialSampler : public Sampler { +class SequentialSamplerRT : public SamplerRT { public: // Constructor // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the // full amount of ids from the dataset // @param start_index - The starting index value // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit SequentialSampler(int64_t num_samples, int64_t start_index, - int64_t samples_per_buffer = std::numeric_limits::max()); + explicit SequentialSamplerRT(int64_t num_samples, int64_t start_index, + int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor. - ~SequentialSampler() = default; + ~SequentialSamplerRT() = default; // init sampler, called by python Status InitSampler() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc index b1f251a8d5..311b305927 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -27,12 +27,12 @@ namespace mindspore { namespace dataset { // Constructor. -SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector &indices, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} +SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vector &indices, + int64_t samples_per_buffer) + : SamplerRT(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} // Initialized this Sampler. -Status SubsetRandomSampler::InitSampler() { +Status SubsetRandomSamplerRT::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED( num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); @@ -56,7 +56,7 @@ Status SubsetRandomSampler::InitSampler() { } // Reset the internal variable to the initial state. -Status SubsetRandomSampler::ResetSampler() { +Status SubsetRandomSamplerRT::ResetSampler() { // Reset the internal counters. sample_id_ = 0; buffer_id_ = 0; @@ -73,7 +73,7 @@ Status SubsetRandomSampler::ResetSampler() { } // Get the sample ids. 
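
A hedged example for the subset sampler defined above; the element type of the index vector is assumed to be int64_t, and num_samples == 0 follows the base-class convention of producing the complete set, here every listed index once in random order:

    // Randomly ordered draw over an explicit list of row ids (sketch only).
    std::vector<int64_t> indices = {3, 7, 11, 42};
    auto subset_sampler = std::make_shared<SubsetRandomSamplerRT>(/*num_samples=*/0, indices);
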
-Status SubsetRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { +Status SubsetRandomSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { // All samples have been drawn if (sample_id_ == num_samples_) { (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); @@ -120,11 +120,11 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr *out_buffe return Status::OK(); } -void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { +void SubsetRandomSamplerRT::Print(std::ostream &out, bool show_all) const { out << "\nSampler: SubsetRandomSampler"; if (show_all) { // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); + SamplerRT::Print(out, show_all); // Then add our own info if any } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h index 0a1feef0a9..09fa93d713 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -25,18 +25,18 @@ namespace mindspore { namespace dataset { // Randomly samples elements from a given list of indices, without replacement. -class SubsetRandomSampler : public Sampler { +class SubsetRandomSamplerRT : public SamplerRT { public: // Constructor. // @param num_samples The number of samples to draw. 0 for the full amount. // @param indices List of indices from where we will randomly draw samples. // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. - explicit SubsetRandomSampler(int64_t num_samples, const std::vector &indices, - std::int64_t samples_per_buffer = std::numeric_limits::max()); + explicit SubsetRandomSamplerRT(int64_t num_samples, const std::vector &indices, + std::int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor. - ~SubsetRandomSampler() = default; + ~SubsetRandomSamplerRT() = default; // Initialize the sampler. // @return Status diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index 555bdbb55e..4dc9c0b718 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -27,16 +27,16 @@ namespace mindspore { namespace dataset { // Constructor. -WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), +WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std::vector &weights, + bool replacement, int64_t samples_per_buffer) + : SamplerRT(num_samples, samples_per_buffer), weights_(weights), replacement_(replacement), sample_id_(0), buffer_id_(0) {} // Initialized this Sampler. -Status WeightedRandomSampler::InitSampler() { +Status WeightedRandomSamplerRT::InitSampler() { // Special value of 0 for num_samples means that the user wants to sample the entire set of data. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. 
if (num_samples_ == 0 || num_samples_ > num_rows_) { @@ -78,7 +78,7 @@ Status WeightedRandomSampler::InitSampler() { } // Initialized the computation for generating weighted random numbers without replacement using onepass method. -void WeightedRandomSampler::InitOnePassSampling() { +void WeightedRandomSamplerRT::InitOnePassSampling() { exp_dist_->reset(); onepass_ids_.clear(); std::vector> val_idx; @@ -94,7 +94,7 @@ void WeightedRandomSampler::InitOnePassSampling() { } // Reset the internal variable to the initial state and reshuffle the indices. -Status WeightedRandomSampler::ResetSampler() { +Status WeightedRandomSamplerRT::ResetSampler() { sample_id_ = 0; buffer_id_ = 0; rand_gen_.seed(GetSeed()); @@ -112,7 +112,7 @@ Status WeightedRandomSampler::ResetSampler() { } // Get the sample ids. -Status WeightedRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { +Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { if (weights_.size() > static_cast(num_rows_)) { return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Invalid parameter, size of sample weights must be less than or equal to num of data, " @@ -180,11 +180,11 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr *out_buf return Status::OK(); } -void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const { +void WeightedRandomSamplerRT::Print(std::ostream &out, bool show_all) const { out << "\nSampler: WeightedRandomSampler"; if (show_all) { // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); + SamplerRT::Print(out, show_all); // Then add our own info if any } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h index 9bcb2bac22..134f288914 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { // Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights). -class WeightedRandomSampler : public Sampler { +class WeightedRandomSamplerRT : public SamplerRT { public: // Constructor. // @param num_samples Number of samples to be drawn. @@ -34,11 +34,11 @@ class WeightedRandomSampler : public Sampler { // @param replacement Determine if samples are drawn with/without replacement. // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. - WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, - int64_t samples_per_buffer = std::numeric_limits::max()); + WeightedRandomSamplerRT(int64_t num_samples, const std::vector &weights, bool replacement, + int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor. - ~WeightedRandomSampler() = default; + ~WeightedRandomSamplerRT() = default; // Initialize the sampler. 
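
A hedged example for the weighted sampler declared above; the element type of the weights vector is assumed to be double, and the weights are relative, so they need not sum to one:

    // Draw 10 ids from rows 0..3 with replacement; row 3 is ten times as likely as row 0.
    std::vector<double> weights = {0.1, 0.2, 0.3, 1.0};
    auto weighted_sampler = std::make_shared<WeightedRandomSamplerRT>(
        /*num_samples=*/10, weights, /*replacement=*/true);
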
// @param op (Not used in this sampler) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 52d7e7745a..935150c361 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -84,7 +84,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr schema, std::vector text_files_list, int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, - std::shared_ptr sampler) + std::shared_ptr sampler) : ParallelOp(num_workers, op_connector_size, std::move(sampler)), device_id_(device_id), num_devices_(num_device), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h index 131d3accb9..7084eae332 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -115,7 +115,7 @@ class TextFileOp : public ParallelOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -131,7 +131,7 @@ class TextFileOp : public ParallelOp { std::vector builder_text_files_list_; bool builder_shuffle_files_; std::unique_ptr builder_schema_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; }; // Constructor of TextFileOp @@ -148,7 +148,7 @@ class TextFileOp : public ParallelOp { // @param sampler - allow a sampler. 
Only valid if a cache exists in ascendent tree nodes TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); + bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); // Default destructor ~TextFileOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index 1dfbd9a3ad..206bc4f50a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -58,7 +58,7 @@ TFReaderOp::Builder::Builder() builder_data_schema_ = std::make_unique(); } -bool ValidateFirstRowCrc(const std::string &filename) { +bool TFReaderOp::ValidateFirstRowCrc(const std::string &filename) { std::ifstream reader; reader.open(filename); if (!reader) { @@ -134,7 +134,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, int32_t num_device, - int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler) + int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler) : ParallelOp(num_workers, op_connector_size, std::move(sampler)), device_id_(device_id), num_devices_(num_device), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h index 217ed4787b..a3dadd7df9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h @@ -156,14 +156,14 @@ class TFReaderOp : public ParallelOp { // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } private: std::unique_ptr builder_data_schema_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; int32_t builder_device_id_; int32_t builder_num_devices_; int32_t builder_num_workers_; @@ -193,7 +193,7 @@ class TFReaderOp : public ParallelOp { TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, - int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler); + int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler); // Default destructor ~TFReaderOp() = default; @@ -262,6 +262,8 @@ class TFReaderOp : public ParallelOp { /// \return Status of the function Status GetDatasetSize(int64_t *dataset_size) override; + static bool ValidateFirstRowCrc(const std::string &filename); + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. 
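
With ValidateFirstRowCrc() promoted from a file-local free function to a public static member of TFReaderOp in the hunks above, callers can run the TFRecord sanity check directly; the file path below is purely illustrative:

    // Checks whether the first record of the file parses as a valid TFRecord row.
    const std::string tfrecord_path = "/path/to/train.tfrecord";  // hypothetical path
    bool is_tfrecord = TFReaderOp::ValidateFirstRowCrc(tfrecord_path);
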
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index 2acd467862..c9589e7380 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -62,7 +62,7 @@ Status VOCOp::Builder::Build(std::shared_ptr *ptr) { if (builder_sampler_ == nullptr) { const int64_t num_samples = 0; const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); + builder_sampler_ = std::make_shared(start_index, num_samples); } builder_schema_ = std::make_unique(); if (builder_task_type_ == TaskType::Segmentation) { @@ -102,7 +102,8 @@ Status VOCOp::Builder::SanityCheck() { VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, - int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler) + int32_t queue_size, bool decode, std::unique_ptr data_schema, + std::shared_ptr sampler) : ParallelOp(num_workers, queue_size, std::move(sampler)), decode_(decode), row_cnt_(0), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h index b648068a1e..35aef73df2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h @@ -118,7 +118,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { // Setter method. // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { + Builder &SetSampler(std::shared_ptr sampler) { builder_sampler_ = std::move(sampler); return *this; } @@ -148,7 +148,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { int32_t builder_num_workers_; int32_t builder_op_connector_size_; int32_t builder_rows_per_buffer_; - std::shared_ptr builder_sampler_; + std::shared_ptr builder_sampler_; std::unique_ptr builder_schema_; std::map builder_labels_to_read_; }; @@ -166,7 +166,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { // @param std::shared_ptr sampler - sampler tells VOCOp what to read VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, - int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler); + int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler); // Destructor ~VOCOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h index 9096824d1b..0922aba403 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h @@ -21,7 +21,7 @@ #include "minddata/dataset/util/status.h" #include "minddata/dataset/engine/datasetops/dataset_op.h" -namespace mindspore::dataset::api { +namespace mindspore::dataset { class DatasetCache { public: @@ -29,6 +29,6 @@ class DatasetCache { virtual Status ValidateParams() = 0; virtual Status CreateCacheOp(int num_workers, std::shared_ptr *ds_op) = 0; }; -} // namespace mindspore::dataset::api +} // namespace mindspore::dataset #endif // 
MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc index ffacd02e8d..30cb8fc144 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc @@ -18,7 +18,7 @@ #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" #include "minddata/dataset/engine/datasetops/cache_op.h" -namespace mindspore::dataset::api { +namespace mindspore::dataset { /// Method to initialize the DatasetCache by creating an instance of a CacheClient /// \return Status Error code @@ -41,4 +41,4 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr num_connections_; std::optional prefetch_sz_; }; -} // namespace mindspore::dataset::api +} // namespace mindspore::dataset #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc index 6c74716149..9af28833de 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc @@ -26,7 +26,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { #ifdef ENABLE_PYTHON // constructor #1, called by Pybind @@ -96,6 +95,5 @@ std::vector> BatchNode::Build() { return node_ops; } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h index cedb05c4b4..9156cc6684 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h @@ -27,7 +27,6 @@ namespace mindspore { namespace dataset { -namespace api { class BatchNode : public DatasetNode { public: @@ -66,7 +65,6 @@ class BatchNode : public DatasetNode { std::map>> pad_map_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc index 81be43eac1..9d5ec38b47 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc @@ -27,7 +27,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + BucketBatchByLengthNode::BucketBatchByLengthNode( std::shared_ptr child, const std::vector &column_names, const std::vector &bucket_boundaries, const std::vector &bucket_batch_sizes, @@ -121,6 +121,5 @@ Status BucketBatchByLengthNode::ValidateParams() { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h index e9f395c363..2dd0bc04d8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h +++ 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h @@ -27,7 +27,7 @@ namespace mindspore { namespace dataset { -namespace api { + class BucketBatchByLengthNode : public DatasetNode { public: /// \brief Constructor @@ -58,7 +58,6 @@ class BucketBatchByLengthNode : public DatasetNode { bool drop_remainder_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc index 1dda6410cd..9d3a17b559 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc @@ -26,7 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr child, std::shared_ptr vocab, @@ -77,6 +76,6 @@ Status BuildSentenceVocabNode::ValidateParams() { return Status::OK(); } -} // namespace api + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h index 10eb7e99d0..01b36a8e6f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h @@ -27,7 +27,6 @@ namespace mindspore { namespace dataset { -namespace api { class BuildSentenceVocabNode : public DatasetNode { public: @@ -56,7 +55,6 @@ class BuildSentenceVocabNode : public DatasetNode { std::unordered_map params_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc index bad52db138..623eccb86a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc @@ -26,7 +26,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { BuildVocabNode::BuildVocabNode(std::shared_ptr child, std::shared_ptr vocab, const std::vector &columns, const std::pair &freq_range, @@ -78,6 +77,6 @@ Status BuildVocabNode::ValidateParams() { return Status::OK(); } -} // namespace api + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h index 9f0de2d133..408115a4aa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h @@ -26,7 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { class BuildVocabNode : public DatasetNode { public: @@ -55,7 +54,6 @@ class BuildVocabNode : public DatasetNode { bool special_first_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // 
MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc index 36ce94a732..fc3fc0f14e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc @@ -25,7 +25,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + // Function to build ConcatOp ConcatNode::ConcatNode(const std::vector> &datasets) { this->children = datasets; } @@ -53,6 +53,5 @@ std::vector> ConcatNode::Build() { return node_ops; } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h index 61822b1283..1a496c76c9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h @@ -25,7 +25,6 @@ namespace mindspore { namespace dataset { -namespace api { class ConcatNode : public DatasetNode { public: @@ -44,7 +43,6 @@ class ConcatNode : public DatasetNode { Status ValidateParams() override; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 895fc5fd03..547e2e47fe 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -16,11 +16,187 @@ #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" +#include #include +#include + +#include "minddata/dataset/util/random.h" namespace mindspore { namespace dataset { -namespace api { + +// Helper function to compute a default shuffle size +Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int64_t *shuffle_size) { + const int64_t average_files_multiplier = 4; + const int64_t shuffle_max = 10000; + int64_t avg_rows_per_file = 0; + + // Adjust the num rows per shard if sharding was given + if (num_devices > 0) { + if (num_rows % num_devices == 0) { + num_rows = num_rows / num_devices; + } else { + num_rows = (num_rows / num_devices) + 1; + } + } + + // Cap based on total rows directive. Some ops do not have this and give value of 0. 
+ if (total_rows > 0) { + num_rows = std::min(num_rows, total_rows); + } + + // get the average per file + CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0."); + avg_rows_per_file = num_rows / num_files; + + *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); + return Status::OK(); +} + +// Helper function to inject a shuffle operator over top of current operator being built +Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr *shuffle_op) { + std::shared_ptr new_shuffle_op = nullptr; + int64_t shuffle_size = 0; + RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size)); + MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size; + // Add the shuffle op + *shuffle_op = std::make_shared(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer); + return Status::OK(); +} + +// Helper function to validate dataset directory parameter +Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { + if (dataset_dir.empty()) { + std::string err_msg = dataset_name + ": dataset_dir is not specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + Path dir(dataset_dir); + if (!dir.IsDirectory()) { + std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + if (access(dataset_dir.c_str(), R_OK) == -1) { + std::string err_msg = dataset_name + ": No access to specified dataset path: " + dataset_dir; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + return Status::OK(); +} + +// Helper function to validate dataset files parameter +Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector &dataset_files) { + if (dataset_files.empty()) { + std::string err_msg = dataset_name + ": dataset_files is not specified."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + for (auto f : dataset_files) { + Path dataset_file(f); + if (!dataset_file.Exists()) { + std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + if (access(dataset_file.toString().c_str(), R_OK) == -1) { + std::string err_msg = dataset_name + ": No access to specified dataset file: " + f; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + } + + return Status::OK(); +} + +// Helper function to validate dataset num_shards and shard_id parameters +Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { + if (num_shards <= 0) { + std::string err_msg = dataset_name + ": Invalid num_shards: " + std::to_string(num_shards); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + if (shard_id < 0 || shard_id >= num_shards) { + // num_shards; + std::string err_msg = dataset_name + ": Invalid input, shard_id: " + std::to_string(shard_id) + + ", num_shards: " + std::to_string(num_shards); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + return Status::OK(); +} + +// Helper function to validate dataset sampler parameter +Status ValidateDatasetSampler(const std::string &dataset_name, const
std::shared_ptr &sampler) { + if (sampler == nullptr) { + std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + return Status::OK(); +} + +Status ValidateStringValue(const std::string &dataset_name, const std::string &str, + const std::unordered_set &valid_strings) { + if (valid_strings.find(str) == valid_strings.end()) { + std::string mode; + mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode, + [](std::string a, std::string b) { return std::move(a) + " " + std::move(b); }); + std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + return Status::OK(); +} + +// Helper function to validate dataset input/output column parameter +Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, + const std::vector &columns) { + if (columns.empty()) { + std::string err_msg = dataset_name + ":" + column_param + " should not be an empty string"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + for (uint32_t i = 0; i < columns.size(); ++i) { + if (columns[i].empty()) { + std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + } + std::set columns_set(columns.begin(), columns.end()); + if (columns_set.size() != columns.size()) { + std::string err_msg = dataset_name + ":" + column_param + ": Every column name must be unique"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + return Status::OK(); +} + +std::shared_ptr SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) { + if (shuffle) { + if (num_shards > 1) { + // If shuffle enabled, sharding enabled, use distributed random sampler + return DistributedSampler(num_shards, shard_id, shuffle, num_samples); + } + // If shuffle enabled, sharding disabled, use random sampler + return RandomSampler(num_samples >= 0, num_samples); + } + if (num_shards > 1) { + // If shuffle disabled, sharding enabled, use distributed sequential sampler + return DistributedSampler(num_shards, shard_id, shuffle, num_samples); + } + // If shuffle disabled, sharding disabled, use sequential sampler + return SequentialSampler(0, num_samples); +} Status DatasetNode::AddCacheOp(std::vector> *node_ops) { if (cache_ != nullptr) { @@ -60,6 +236,5 @@ DatasetNode::DatasetNode() { worker_connector_size_ = cfg->worker_connector_size(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 89766d31a7..0e92b5547d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -28,7 +28,6 @@ namespace mindspore { namespace dataset { -namespace api { class Dataset; class SamplerObj; @@ -120,7 +119,6 @@ class DatasetNode : public std::enable_shared_from_this { int32_t worker_connector_size_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index c7dc98a9e5..91cc1fe98b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -26,7 +26,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { MapNode::MapNode(std::shared_ptr child, std::vector> operations, std::vector input_columns, std::vector output_columns, @@ -86,6 +85,5 @@ Status MapNode::ValidateParams() { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h index 9ee5d1b8b8..aca9a19187 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -namespace api { + class MapNode : public DatasetNode { public: /// \brief Constructor @@ -51,7 +51,6 @@ class MapNode : public DatasetNode { std::vector project_columns_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc index 9fa7234c58..3ea08f2b05 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc @@ -25,7 +25,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { // Function to build ProjectOp ProjectNode::ProjectNode(std::shared_ptr child, const std::vector &columns) @@ -53,6 +52,5 @@ std::vector> ProjectNode::Build() { return node_ops; } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h index 7a6fb52869..e90f6d68d9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h @@ -26,8 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { - class ProjectNode : public DatasetNode { public: /// \brief Constructor @@ -48,7 +46,6 @@ class ProjectNode : public DatasetNode { std::vector columns_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc index 4d29b8e030..8761102356 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc @@ -25,7 +25,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + // Function to build RenameOp RenameNode::RenameNode(std::shared_ptr child, const std::vector &input_columns, const std::vector &output_columns) @@ -54,6 +54,6 @@ std::vector> RenameNode::Build() { node_ops.push_back(std::make_shared(input_columns_, output_columns_, connector_que_size_)); return node_ops; } -} // namespace api 
+ } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h index 379c74beae..8a8faf2a4a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h @@ -26,8 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { - class RenameNode : public DatasetNode { public: /// \brief Constructor @@ -50,7 +48,6 @@ class RenameNode : public DatasetNode { std::vector output_columns_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc index 071d92c816..7fe738a20d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc @@ -25,7 +25,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { RepeatNode::RepeatNode(std::shared_ptr child, int32_t count) : repeat_count_(count) { this->children.push_back(child); @@ -49,6 +48,6 @@ Status RepeatNode::ValidateParams() { return Status::OK(); } -} // namespace api + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h index 3385b33db0..b582dcb326 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h @@ -28,8 +28,6 @@ namespace mindspore { namespace dataset { -namespace api { - class RepeatNode : public DatasetNode { public: /// \brief Constructor @@ -50,7 +48,6 @@ class RepeatNode : public DatasetNode { int32_t repeat_count_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc index a82f3367fe..e722547aae 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc @@ -25,7 +25,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { // Constructor for ShuffleNode ShuffleNode::ShuffleNode(std::shared_ptr child, int32_t shuffle_size, bool reset_every_epoch) @@ -54,6 +53,5 @@ Status ShuffleNode::ValidateParams() { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h index 0274cf8b69..0b81684e61 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h @@ -28,8 +28,6 @@ namespace mindspore { namespace dataset { -namespace api { - class ShuffleNode : public DatasetNode { public: ShuffleNode(std::shared_ptr child, int32_t shuffle_size, bool reset_every_epoch); @@ -46,7 +44,6 @@ class ShuffleNode : public DatasetNode 
{ bool reset_every_epoch_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc index c2e5618106..8590c47c42 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc @@ -25,7 +25,6 @@ namespace mindspore { namespace dataset { -namespace api { // Constructor for SkipNode SkipNode::SkipNode(std::shared_ptr child, int32_t count) : skip_count_(count) { @@ -52,6 +51,5 @@ Status SkipNode::ValidateParams() { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h index 438eb54f99..19e7cc9031 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h @@ -26,7 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { class SkipNode : public DatasetNode { public: /// \brief Constructor @@ -46,7 +45,7 @@ class SkipNode : public DatasetNode { private: int32_t skip_count_; }; -} // namespace api + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc index 1e9cdd9c4d..9fe3bac265 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc @@ -27,7 +27,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + // Constructor for AlbumNode AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, const std::vector &column_names, bool decode, @@ -78,6 +78,5 @@ Status AlbumNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h index fb50353df2..21fce849c0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h @@ -25,7 +25,6 @@ namespace mindspore { namespace dataset { -namespace api { class AlbumNode : public DatasetNode { public: @@ -57,7 +56,6 @@ class AlbumNode : public DatasetNode { std::shared_ptr sampler_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index a9eaa3442d..24638592b0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -26,7 +26,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace 
api { + // Constructor for CelebANode CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr &sampler, const bool &decode, @@ -76,6 +76,5 @@ Status CelebANode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h index 0e90a72b2c..ed30adfd9c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h @@ -27,7 +27,6 @@ namespace mindspore { namespace dataset { -namespace api { class CelebANode : public DatasetNode { public: @@ -58,7 +57,6 @@ class CelebANode : public DatasetNode { std::shared_ptr sampler_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index 104c00ee3c..cf9b73ad36 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -26,7 +26,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { // Constructor for Cifar100Node Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, @@ -73,6 +72,5 @@ Status Cifar100Node::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h index 79dd35486b..fe24f8f3a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h @@ -25,7 +25,6 @@ namespace mindspore { namespace dataset { -namespace api { class Cifar100Node : public DatasetNode { public: @@ -54,7 +53,6 @@ class Cifar100Node : public DatasetNode { std::shared_ptr sampler_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 3d19d8fd79..3e7faeb965 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -26,7 +26,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { // Constructor for Cifar10Node Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, @@ -71,6 +70,5 @@ Status Cifar10Node::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h index 
3037caefc0..716474ae2e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h @@ -25,7 +25,6 @@ namespace mindspore { namespace dataset { -namespace api { class Cifar10Node : public DatasetNode { public: @@ -54,7 +53,6 @@ class Cifar10Node : public DatasetNode { std::shared_ptr sampler_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index 17cd4769e4..45d8b0f43f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -28,7 +28,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { // Constructor for CLUENode CLUENode::CLUENode(const std::vector clue_files, std::string task, std::string usage, int64_t num_samples, @@ -226,6 +225,5 @@ Status CLUENode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h index eba34dfab3..76da0501b8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -namespace api { + /// \class CLUENode /// \brief A Dataset derived class to represent CLUE dataset class CLUENode : public DatasetNode { @@ -63,7 +63,6 @@ class CLUENode : public DatasetNode { int32_t shard_id_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index 3e0729f146..085e82d5b3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -26,7 +26,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + // Constructor for CocoNode CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, const bool &decode, const std::shared_ptr &sampler, std::shared_ptr cache) @@ -125,6 +125,5 @@ Status CocoNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h index 2593534509..d3b4275d7f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -namespace api { + class CocoNode : public DatasetNode { public: /// \brief Constructor @@ -55,7 +55,6 @@ class CocoNode : public DatasetNode { std::shared_ptr 
sampler_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_COCO_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index 1909b1cfe6..dc72bb168a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -27,7 +27,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + // Constructor for CSVNode CSVNode::CSVNode(const std::vector &csv_files, char field_delim, const std::vector> &column_defaults, @@ -137,6 +137,5 @@ Status CSVNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h index 9828d5d03f..e006b237e4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -namespace api { /// \brief Base class of CSV Record +/// \brief Base class of CSV Record /// \brief Record type for CSV enum CsvType : uint8_t { INT = 0, FLOAT, STRING }; @@ -80,7 +80,7 @@ class CSVNode : public DatasetNode { int32_t num_shards_; int32_t shard_id_; }; -} // namespace api + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc index c60eff1486..25b0e678e7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -namespace api { + GeneratorNode::GeneratorNode(py::function generator_function, const std::vector &column_names, const std::vector &column_types) : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {} @@ -55,6 +55,6 @@ std::vector> GeneratorNode::Build() { // no validation is needed for generator op. 
Status GeneratorNode::ValidateParams() { return Status::OK(); } -} // namespace api + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h index 6237346f11..7b29eea7bd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h @@ -27,8 +27,6 @@ namespace mindspore { namespace dataset { -namespace api { - /// \class GeneratorNode /// \brief A Dataset derived class to represent GeneratorNode dataset class GeneratorNode : public DatasetNode { @@ -53,7 +51,7 @@ class GeneratorNode : public DatasetNode { std::vector column_names_; std::vector column_types_; }; -} // namespace api + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GENERATOR_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index 5c4159435d..cba92e0b97 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -28,7 +28,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, std::set extensions, @@ -78,6 +77,5 @@ Status ImageFolderNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h index 7922eda8bd..22045ed791 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h @@ -29,8 +29,6 @@ namespace mindspore { namespace dataset { -namespace api { - /// \class ImageFolderNode /// \brief A Dataset derived class to represent ImageFolder dataset class ImageFolderNode : public DatasetNode { @@ -63,7 +61,7 @@ class ImageFolderNode : public DatasetNode { std::map class_indexing_; std::set exts_; }; -} // namespace api + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMAGE_FOLDER_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index cbb01a9cf2..051a56aabd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -27,7 +27,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr &sampler, const std::map &class_indexing, bool decode, @@ -93,6 +93,5 @@ Status ManifestNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h index b8da1555d5..b623868bcd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -namespace api { + class ManifestNode : public DatasetNode { public: /// \brief Constructor @@ -55,7 +55,7 @@ class ManifestNode : public DatasetNode { std::map class_index_; std::shared_ptr sampler_; }; -} // namespace api + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MANIFEST_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index 9e05ec0ac7..ef7422c625 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -27,7 +27,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + MindDataNode::MindDataNode(const std::vector &dataset_files, const std::vector &columns_list, const std::shared_ptr &sampler, nlohmann::json padded_sample, int64_t num_padded) : dataset_file_(std::string()), @@ -167,6 +167,5 @@ Status MindDataNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h index ea10456bcf..850137fcb2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -namespace api { + class MindDataNode : public DatasetNode { public: /// \brief Constructor @@ -73,7 +73,6 @@ class MindDataNode : public DatasetNode { int64_t num_padded_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MINDDATA_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 5feee17998..0bdc2725fe 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -26,7 +26,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler, std::shared_ptr cache) @@ -67,6 +66,5 @@ Status MnistNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h index 663e2ede97..5e614ad335 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h @@ -25,7 +25,6 @@ namespace 
mindspore { namespace dataset { -namespace api { class MnistNode : public DatasetNode { public: @@ -54,7 +53,6 @@ class MnistNode : public DatasetNode { std::shared_ptr sampler_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MNIST_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index 71839e43dc..83afdbdd2a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -26,7 +26,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + // ValidateParams for RandomNode Status RandomNode::ValidateParams() { if (total_rows_ < 0) { @@ -106,6 +106,5 @@ Status RandomNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h index 79d995438a..4b798e7002 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h @@ -26,7 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { class RandomNode : public DatasetNode { public: @@ -84,7 +83,6 @@ class RandomNode : public DatasetNode { std::mt19937 rand_gen_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index 2d4841f9e7..32ff1498f9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -27,7 +27,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + // Constructor for TextFileNode TextFileNode::TextFileNode(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr cache) @@ -108,6 +108,5 @@ Status TextFileNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h index 9011aa1603..96a76cef28 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -namespace api { + /// \class TextFileNode /// \brief A Dataset derived class to represent TextFile dataset class TextFileNode : public DatasetNode { @@ -57,7 +57,6 @@ class TextFileNode : public DatasetNode { ShuffleMode shuffle_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEXT_FILE_NODE_H_ diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index 7492ec1131..c5619eaa6b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -29,36 +29,13 @@ namespace mindspore { namespace dataset { -namespace api { - -bool ValidateFirstRowCrc(const std::string &filename) { - std::ifstream reader; - reader.open(filename); - if (!reader) { - return false; - } - - // read data - int64_t record_length = 0; - (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); - - // read crc from file - uint32_t masked_crc = 0; - (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); - - // generate crc from data - uint32_t generated_crc = - system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); - - return masked_crc == generated_crc; -} // Validator for TFRecordNode Status TFRecordNode::ValidateParams() { if (dataset_files_.empty()) { std::string err_msg = "TFRecordNode: dataset_files is not specified."; MS_LOG(ERROR) << err_msg; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); } for (const auto &f : dataset_files_) { @@ -67,7 +44,7 @@ Status TFRecordNode::ValidateParams() { std::string err_msg = "TFRecordNode: dataset file: [" + f + "] is invalid or does not exist."; MS_LOG(ERROR) << err_msg; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); } } @@ -75,14 +52,14 @@ Status TFRecordNode::ValidateParams() { std::string err_msg = "TFRecordNode: Invalid number of samples: " + std::to_string(num_samples_); MS_LOG(ERROR) << err_msg; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); } if (num_shards_ <= 0) { std::string err_msg = "TFRecordNode: Invalid num_shards: " + std::to_string(num_shards_); MS_LOG(ERROR) << err_msg; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); } if (shard_id_ < 0 || shard_id_ >= num_shards_) { @@ -90,7 +67,7 @@ Status TFRecordNode::ValidateParams() { ", num_shards: " + std::to_string(num_shards_); MS_LOG(ERROR) << err_msg; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); } if (cache_ == nullptr && !shard_equal_rows_ && dataset_files_.size() < num_shards_) { @@ -99,12 +76,12 @@ Status TFRecordNode::ValidateParams() { std::string err_msg = "TFRecordNode: Invalid number of dataset files, should at least be " + std::to_string(num_shards_); MS_LOG(ERROR) << err_msg; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); } std::vector invalid_files(dataset_files_.size()); auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(), - [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); + [](const std::string &filename) { return !TFReaderOp::ValidateFirstRowCrc(filename); }); invalid_files.resize(std::distance(invalid_files.begin(), it)); std::string 
err_msg; if (!invalid_files.empty()) { @@ -115,7 +92,7 @@ Status TFRecordNode::ValidateParams() { [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); err_msg += accumulated_filenames; } - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + return err_msg.empty() ? Status::OK() : Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); } // Function to build TFRecordNode @@ -180,6 +157,5 @@ Status TFRecordNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h index 08e4d094c4..6f12b0a64e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -namespace api { + /// \class TFRecordNode /// \brief A Dataset derived class to represent TFRecord dataset class TFRecordNode : public DatasetNode { @@ -88,7 +88,6 @@ class TFRecordNode : public DatasetNode { bool shard_equal_rows_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index 68ade8aa07..0d413e8a68 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -27,7 +27,7 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { + // Constructor for VOCNode VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, const std::map &class_indexing, bool decode, std::shared_ptr sampler, @@ -119,6 +119,5 @@ Status VOCNode::GetShardId(int32_t *shard_id) { return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h index ed3656397c..4102e3189c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h @@ -26,7 +26,7 @@ namespace mindspore { namespace dataset { -namespace api { + class VOCNode : public DatasetNode { public: /// \brief Constructor @@ -63,7 +63,7 @@ class VOCNode : public DatasetNode { bool decode_; std::shared_ptr sampler_; }; -} // namespace api + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_VOC_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc index 8642cdb79f..ca6bd228fb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -namespace api { + // Constructor for SyncWaitNode 
SyncWaitNode::SyncWaitNode(std::shared_ptr child, const std::string &condition_name, int32_t num_batch, py::function callback) @@ -58,6 +58,6 @@ Status SyncWaitNode::ValidateParams() { return Status::OK(); } -} // namespace api + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h index b108e257af..8e193c96f6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h @@ -26,8 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { - /// \class SyncWaitNode /// \brief A Dataset derived class to represent SyncWaitNode dataset class SyncWaitNode : public DatasetNode { @@ -52,7 +50,7 @@ class SyncWaitNode : public DatasetNode { int32_t num_batch_; py::function callback_; }; -} // namespace api + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc index 9a3fed7b87..917df6a781 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace dataset { -namespace api { + // Constructor for TakeNode TakeNode::TakeNode(std::shared_ptr child, int32_t count) : take_count_(count) { this->children.push_back(child); @@ -50,6 +50,6 @@ Status TakeNode::ValidateParams() { } return Status::OK(); } -} // namespace api + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h index dfc7199384..93d735d15a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h @@ -26,8 +26,6 @@ namespace mindspore { namespace dataset { -namespace api { - class TakeNode : public DatasetNode { public: /// \brief Constructor @@ -48,7 +46,6 @@ class TakeNode : public DatasetNode { int32_t take_count_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc index d566d30d3a..dcace2b176 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc @@ -25,7 +25,6 @@ namespace mindspore { namespace dataset { -namespace api { // Constructor for TransferNode TransferNode::TransferNode(std::shared_ptr child, bool send_epoch_end) @@ -88,6 +87,5 @@ Status TransferNode::get_distribution(std::shared_ptr ds, int32_t * return Status::OK(); } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h index 34f00800e5..5fec5b51cd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h @@ -26,8 
+26,6 @@ namespace mindspore { namespace dataset { -namespace api { - class TransferNode : public DatasetNode { public: /// \brief Constructor @@ -55,7 +53,6 @@ class TransferNode : public DatasetNode { int32_t total_batch_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc index 2099cd1035..0d2c068635 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc @@ -25,7 +25,6 @@ #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { -namespace api { ZipNode::ZipNode(const std::vector> &datasets) : datasets_(datasets) { for (auto dataset : datasets_) { @@ -57,6 +56,5 @@ std::vector> ZipNode::Build() { return node_ops; } -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h index f7046842a9..27f92e0da5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h @@ -25,7 +25,6 @@ namespace mindspore { namespace dataset { -namespace api { class ZipNode : public DatasetNode { public: @@ -47,7 +46,6 @@ class ZipNode : public DatasetNode { std::vector> datasets_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ZIP_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc index 30a2e33cd1..85891afc02 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc @@ -85,7 +85,7 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, b // We're a cache op but no sampler was saved from leaf, so create a default sampler const int64_t num_samples = 0; const int64_t start_index = 0; - sampler_ = std::make_shared(num_samples, start_index); + sampler_ = std::make_shared(num_samples, start_index); node->SetSampler(std::move(sampler_)); MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; } @@ -128,7 +128,7 @@ Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr< } else { // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) - std::shared_ptr sampler_from_leaf; + std::shared_ptr sampler_from_leaf; RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); } return Status::OK(); @@ -278,7 +278,7 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later. 
- std::shared_ptr leaf_sampler = leaf_op->sampler(); + std::shared_ptr leaf_sampler = leaf_op->sampler(); // Construct the merge op with defaults std::shared_ptr merge_op; diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h index 346f4dd62d..0c6d288f37 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h @@ -176,7 +176,7 @@ class CacheTransformPass : public TreePass { bool is_caching_; std::shared_ptr leaf_op_; - std::shared_ptr sampler_; + std::shared_ptr sampler_; // The two operators that work together to establish the cache transform std::vector, std::shared_ptr>> cache_pairs_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h b/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h index 63353b7efd..05116fed63 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h +++ b/mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h @@ -26,13 +26,14 @@ namespace mindspore::dataset { class RuntimeContext; -/// Class the represents single runtime instance which can consume data from a data pipeline +/// Class that represents a single runtime instance which can consume data from a data pipeline class PythonRuntimeContext : public RuntimeContext { public: /// Method to terminate the runtime, this will not release the resources /// \return Status error code Status Terminate() override; + // Safely destruct the tree, which includes Python objects ~PythonRuntimeContext() { Terminate(); { diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index 07d01ceec1..81613579ca 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace dataset { -Status TreeAdapter::BuildAndPrepare(std::shared_ptr root_ir, int32_t num_epoch) { +Status TreeAdapter::BuildAndPrepare(std::shared_ptr root_ir, int32_t num_epoch) { // Check whether this function has been called before.
If so, return failure CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built."); RETURN_UNEXPECTED_IF_NULL(root_ir); @@ -65,7 +65,7 @@ Status TreeAdapter::GetNext(TensorRow *row) { return Status::OK(); } -Status TreeAdapter::DFSBuildTree(std::shared_ptr ir, std::shared_ptr *op) { +Status TreeAdapter::DFSBuildTree(std::shared_ptr ir, std::shared_ptr *op) { // validate the op can be built first before building the DatasetOp RETURN_IF_NOT_OK(ir->ValidateParams()); std::vector> ops = ir->Build(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index a1f8201cee..cfe25f4e70 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -28,9 +28,8 @@ namespace mindspore { namespace dataset { -namespace api { class DatasetNode; -} + class TreeAdapter { public: TreeAdapter() = default; @@ -40,7 +39,7 @@ class TreeAdapter { // This will construct an ExeTree from a Dataset root and Prepare() the ExeTree // This function is only meant to be called once and needs to be called before GetNext // ExeTree will be launched when the first GetNext is called - Status BuildAndPrepare(std::shared_ptr root, int32_t num_epoch = -1); + Status BuildAndPrepare(std::shared_ptr root, int32_t num_epoch = -1); // This is the main method TreeConsumer uses to interact with TreeAdapter // 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared) @@ -62,7 +61,7 @@ class TreeAdapter { private: // This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In // such case, the first node is returned. Op is added as child when the current function returns. - Status DFSBuildTree(std::shared_ptr ir, std::shared_ptr *op); + Status DFSBuildTree(std::shared_ptr ir, std::shared_ptr *op); std::unique_ptr cur_db_; std::unordered_map column_name_map_; diff --git a/mindspore/ccsrc/minddata/dataset/include/config.h b/mindspore/ccsrc/minddata/dataset/include/config.h index 95bf271a04..2f159c0287 100644 --- a/mindspore/ccsrc/minddata/dataset/include/config.h +++ b/mindspore/ccsrc/minddata/dataset/include/config.h @@ -24,7 +24,6 @@ namespace mindspore { namespace dataset { -namespace api { // Config operations for setting and getting the configuration. 
namespace config { @@ -76,7 +75,6 @@ int32_t get_callback_timeout(); bool load(std::string file); } // namespace config -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 1cb9426bb0..9858388580 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -30,7 +30,6 @@ #include "minddata/dataset/core/constants.h" #include "minddata/dataset/engine/consumers/tree_consumer.h" -#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/include/iterator.h" #include "minddata/dataset/include/samplers.h" @@ -47,8 +46,6 @@ namespace mindspore { namespace dataset { -class DatasetOp; -class DataSchema; class Tensor; class TensorShape; class TreeAdapter; @@ -57,10 +54,8 @@ class TreeGetters; class Vocab; #endif -namespace api { -// Forward declare class DatasetNode; -class Dataset; + class Iterator; class TensorOperation; @@ -91,7 +86,6 @@ class Dataset : public std::enable_shared_from_this { // need friend class so they can access the children_ field friend class Iterator; friend class TransferNode; - friend class mindspore::dataset::TreeAdapter; /// \brief Constructor Dataset(); @@ -108,14 +102,14 @@ class Dataset : public std::enable_shared_from_this { std::vector GetOutputTypes(); /// \brief Gets the output shape - /// \return a vector of TensorShape. If failed, return am empty vector + /// \return a vector of TensorShape. If failed, return an empty vector std::vector GetOutputShapes(); /// \brief Gets the batch size /// \return int64_t int64_t GetBatchSize(); - /// \brief Gets the the repeat count + /// \brief Gets the repeat count /// \return int64_t int64_t GetRepeatCount(); @@ -136,7 +130,7 @@ class Dataset : public std::enable_shared_from_this { /// \brief Function to transfer data through a device. /// \notes If device is Ascend, features of data will be transferred one by one. The limitation /// of data transmission per time is 256M. - /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=True). + /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true). /// \return Returns true if no error encountered else false. 
bool DeviceQueue(bool send_epoch_end = true); @@ -164,9 +158,7 @@ class Dataset : public std::enable_shared_from_this { /// available to make the last batch, then those rows will /// be dropped and not propagated to the next node /// \return Shared pointer to the current BatchDataset - std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false) { - return std::make_shared(shared_from_this(), batch_size, drop_remainder); - } + std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false); #ifndef ENABLE_ANDROID /// \brief Function to create a BucketBatchByLengthDataset @@ -965,7 +957,6 @@ std::shared_ptr CreateDatasetCache(session_id_type id, uint64_t me /// \param[in] datasets List of shared pointers to the datasets that we want to zip /// \return Shared pointer to the current Dataset std::shared_ptr Zip(const std::vector> &datasets); -} // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/include/execute.h b/mindspore/ccsrc/minddata/dataset/include/execute.h index 53d6ee5572..7ff158efc5 100644 --- a/mindspore/ccsrc/minddata/dataset/include/execute.h +++ b/mindspore/ccsrc/minddata/dataset/include/execute.h @@ -28,8 +28,6 @@ namespace dataset { class TensorOp; -namespace api { - // class to run tensor operations in eager mode class Execute { public: @@ -45,7 +43,6 @@ class Execute { std::shared_ptr op_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // DATASET_API_EXECUTE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/iterator.h b/mindspore/ccsrc/minddata/dataset/include/iterator.h index 58781ecefb..959a9c5191 100644 --- a/mindspore/ccsrc/minddata/dataset/include/iterator.h +++ b/mindspore/ccsrc/minddata/dataset/include/iterator.h @@ -35,7 +35,6 @@ class Tensor; class RuntimeContext; class IteratorConsumer; -namespace api { class Dataset; @@ -114,10 +113,9 @@ class Iterator { _Iterator end() { return _Iterator(nullptr); } private: - std::unique_ptr runtime_context; + std::unique_ptr runtime_context_; IteratorConsumer *consumer_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 197729258e..96c71f486f 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -29,9 +29,7 @@ namespace mindspore { namespace dataset { // Internal Sampler class forward declaration -class Sampler; - -namespace api { +class SamplerRT; class SamplerObj : public std::enable_shared_from_this { public: @@ -47,7 +45,7 @@ class SamplerObj : public std::enable_shared_from_this { /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object /// \return Shared pointers to the newly created Sampler - virtual std::shared_ptr Build() = 0; + virtual std::shared_ptr Build() = 0; /// \brief Function for derived class to get the shard id of sampler /// \return The shard id of the derived sampler @@ -131,7 +129,7 @@ class DistributedSamplerObj : public SamplerObj { ~DistributedSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr Build() override; #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; @@ -159,7 +157,7 @@ class PKSamplerObj : public SamplerObj { ~PKSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr Build() override; #ifndef 
ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; @@ -179,7 +177,7 @@ class RandomSamplerObj : public SamplerObj { ~RandomSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr Build() override; #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; @@ -198,7 +196,7 @@ class SequentialSamplerObj : public SamplerObj { ~SequentialSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr Build() override; #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; @@ -217,7 +215,7 @@ class SubsetRandomSamplerObj : public SamplerObj { ~SubsetRandomSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr Build() override; #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; @@ -236,7 +234,7 @@ class WeightedRandomSamplerObj : public SamplerObj { ~WeightedRandomSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr Build() override; bool ValidateParams() override; @@ -245,7 +243,6 @@ class WeightedRandomSamplerObj : public SamplerObj { int64_t num_samples_; bool replacement_; }; -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/text.h b/mindspore/ccsrc/minddata/dataset/include/text.h index 7e465dc983..a03cba422f 100644 --- a/mindspore/ccsrc/minddata/dataset/include/text.h +++ b/mindspore/ccsrc/minddata/dataset/include/text.h @@ -32,7 +32,6 @@ namespace mindspore { namespace dataset { -namespace api { // Transform operations for text namespace text { @@ -103,7 +102,6 @@ class SentencePieceTokenizerOperation : public TensorOperation { SPieceTokenizerOutType out_type_; }; } // namespace text -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TEXT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index a5f4d3025b..2df71e40be 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -28,7 +28,6 @@ namespace dataset { class TensorOp; -namespace api { // Abstract class to represent a dataset in the data pipeline. class TensorOperation : public std::enable_shared_from_this { public: @@ -94,7 +93,6 @@ class TypeCastOperation : public TensorOperation { std::string data_type_; }; } // namespace transforms -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/vision.h b/mindspore/ccsrc/minddata/dataset/include/vision.h index 831ea6e5a9..348a194281 100644 --- a/mindspore/ccsrc/minddata/dataset/include/vision.h +++ b/mindspore/ccsrc/minddata/dataset/include/vision.h @@ -25,7 +25,6 @@ namespace mindspore { namespace dataset { -namespace api { // Transform operations for performing computer vision. 
namespace vision { @@ -880,7 +879,6 @@ class UniformAugOperation : public TensorOperation { #endif } // namespace vision -} // namespace api } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_VISION_H_ diff --git a/mindspore/lite/minddata/example/jni-example.cc b/mindspore/lite/minddata/example/jni-example.cc index 469a63c1e6..0b46470c7d 100644 --- a/mindspore/lite/minddata/example/jni-example.cc +++ b/mindspore/lite/minddata/example/jni-example.cc @@ -34,12 +34,12 @@ extern "C" JNIEXPORT jstring JNICALL Java_com_example_mindsporepredict_MainActiv return env->NewStringUTF(hello.c_str()); } -using Dataset = mindspore::dataset::api::Dataset; -using Iterator = mindspore::dataset::api::Iterator; +using Dataset = mindspore::dataset::Dataset; +using Iterator = mindspore::dataset::Iterator; +using mindspore::dataset::Cifar10; using mindspore::dataset::Path; +using mindspore::dataset::RandomSampler; using mindspore::dataset::Tensor; -using mindspore::dataset::api::Cifar10; -using mindspore::dataset::api::RandomSampler; extern "C" JNIEXPORT void JNICALL Java_com_example_mindsporepredict_MainActivity_pathTest(JNIEnv *env, jobject /* this */, diff --git a/mindspore/lite/minddata/example/x86-example.cc b/mindspore/lite/minddata/example/x86-example.cc index 440e3c9bf7..e0ac1f9f69 100644 --- a/mindspore/lite/minddata/example/x86-example.cc +++ b/mindspore/lite/minddata/example/x86-example.cc @@ -22,11 +22,11 @@ #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" #include "minddata/dataset/util/path.h" -using Dataset = mindspore::dataset::api::Dataset; -using Iterator = mindspore::dataset::api::Iterator; +using Dataset = mindspore::dataset::Dataset; +using Iterator = mindspore::dataset::Iterator; +using mindspore::dataset::Cifar10; +using mindspore::dataset::RandomSampler; using mindspore::dataset::Tensor; -using mindspore::dataset::api::Cifar10; -using mindspore::dataset::api::RandomSampler; int main() { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10Dataset."; diff --git a/tests/ut/cpp/dataset/album_op_test.cc b/tests/ut/cpp/dataset/album_op_test.cc index b921dd04ea..dadfe7d9a6 100644 --- a/tests/ut/cpp/dataset/album_op_test.cc +++ b/tests/ut/cpp/dataset/album_op_test.cc @@ -47,9 +47,8 @@ std::shared_ptr Repeat(int repeat_cnt); std::shared_ptr Build(std::vector> ops); -std::shared_ptr Album(int64_t num_works, int64_t rows, int64_t conns, std::string path, - bool shuf = false, std::unique_ptr sampler = nullptr, - bool decode = false) { +std::shared_ptr Album(int64_t num_works, int64_t rows, int64_t conns, std::string path, bool shuf = false, + std::unique_ptr sampler = nullptr, bool decode = false) { std::shared_ptr so; AlbumOp::Builder builder; Status rc = builder.SetNumWorkers(num_works) @@ -64,9 +63,9 @@ std::shared_ptr Album(int64_t num_works, int64_t rows, int64_t conns, s } std::shared_ptr AlbumSchema(int64_t num_works, int64_t rows, int64_t conns, std::string path, - std::string schema_file, std::vector column_names = {}, - bool shuf = false, std::unique_ptr sampler = nullptr, - bool decode = false) { + std::string schema_file, std::vector column_names = {}, + bool shuf = false, std::unique_ptr sampler = nullptr, + bool decode = false) { std::shared_ptr so; AlbumOp::Builder builder; Status rc = builder.SetNumWorkers(num_works) diff --git a/tests/ut/cpp/dataset/c_api_cache_test.cc b/tests/ut/cpp/dataset/c_api_cache_test.cc index 6c34250d99..0da852f335 100644 --- a/tests/ut/cpp/dataset/c_api_cache_test.cc +++ 
b/tests/ut/cpp/dataset/c_api_cache_test.cc @@ -33,7 +33,6 @@ #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" using namespace mindspore::dataset; -using namespace mindspore::dataset::api; // Helper function to get the session id from SESSION_ID env variable Status GetSessionFromEnv(session_id_type *session_id); @@ -744,4 +743,3 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheClueCApi) { // Manually terminate the pipeline iter->Stop(); } - diff --git a/tests/ut/cpp/dataset/c_api_dataset_album_test.cc b/tests/ut/cpp/dataset/c_api_dataset_album_test.cc index 068ba02e1a..5e3dec81c0 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_album_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_album_test.cc @@ -29,8 +29,7 @@ // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc index 59d9ce408b..9c6fd74944 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc @@ -27,15 +27,11 @@ #include "minddata/dataset/engine/ir/datasetops/skip_node.h" #include "minddata/dataset/engine/ir/datasetops/zip_node.h" - // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; -using mindspore::dataset::DataType; -using mindspore::dataset::TensorShape; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc b/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc index a2d721cc23..c3758cfc7d 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_clue_test.cc @@ -32,10 +32,7 @@ // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::GlobalContext; -using mindspore::dataset::ShuffleMode; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc b/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc index 2a022abdec..ae64483616 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_coco_test.cc @@ -46,10 +46,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::dsize_t; -using mindspore::dataset::Tensor; -using mindspore::dataset::TensorShape; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_config_test.cc b/tests/ut/cpp/dataset/c_api_dataset_config_test.cc index 1ab089f406..54a3e4e565 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_config_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_config_test.cc @@ -29,9 +29,7 @@ // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" -using namespace mindspore::dataset::api; -using 
mindspore::dataset::ShuffleMode; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc b/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc index f90fe02652..98aada926b 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_csv_test.cc @@ -48,10 +48,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::GlobalContext; -using mindspore::dataset::ShuffleMode; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc b/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc index 06c14a3013..3d3f503cbe 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc @@ -33,9 +33,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; -using mindspore::dataset::TensorShape; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc b/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc index dcca936ea2..bc67747d89 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_manifest_test.cc @@ -29,8 +29,7 @@ // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc b/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc index b9b61a9fdf..b7bde3da13 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_minddata_test.cc @@ -30,8 +30,7 @@ // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index 6b06259e85..74aa2b0cb6 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -35,8 +35,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc index 090c4f7352..42429097e5 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc @@ -34,10 +34,6 @@ #include 
"minddata/dataset/engine/ir/datasetops/source/random_node.h" using namespace mindspore::dataset; -using namespace mindspore::dataset::api; -using mindspore::dataset::DataType; -using mindspore::dataset::Tensor; -using mindspore::dataset::TensorShape; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_save.cc b/tests/ut/cpp/dataset/c_api_dataset_save.cc index f15d0989be..61bbb1930a 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_save.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_save.cc @@ -34,8 +34,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc index ba34175a81..3da9e45708 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc @@ -32,7 +32,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" using namespace mindspore::dataset; -using namespace mindspore::dataset::api; + using mindspore::dataset::ShuffleMode; using mindspore::dataset::Tensor; diff --git a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc index 62eb1efd8c..a590bc3270 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc @@ -33,7 +33,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" using namespace mindspore::dataset; -using namespace mindspore::dataset::api; + using mindspore::dataset::DataType; using mindspore::dataset::ShuffleMode; using mindspore::dataset::Tensor; @@ -492,10 +492,10 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) { TEST_F(MindDataTestPipeline, TestIncorrectTFSchemaObject) { std::string path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; - std::shared_ptr schema = api::Schema(); + std::shared_ptr schema = Schema(); schema->add_column("image", "uint8", {1}); schema->add_column("label", "int64", {1}); - std::shared_ptr ds = api::TFRecord({path}, schema); + std::shared_ptr ds = TFRecord({path}, schema); EXPECT_NE(ds, nullptr); auto itr = ds->CreateIterator(); EXPECT_NE(itr, nullptr); @@ -506,7 +506,7 @@ TEST_F(MindDataTestPipeline, TestIncorrectTFSchemaObject) { TEST_F(MindDataTestPipeline, TestIncorrectTFrecordFile) { std::string path = datasets_root_path_ + "/test_tf_file_3_images2/datasetSchema.json"; - std::shared_ptr ds = api::TFRecord({path}); + std::shared_ptr ds = TFRecord({path}); EXPECT_NE(ds, nullptr); // the tf record file is incorrect, hence validate param will fail auto itr = ds->CreateIterator(); diff --git a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc index 191447527d..b15058beeb 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_voc_test.cc @@ -30,10 +30,7 @@ // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::DataType; -using mindspore::dataset::Tensor; -using mindspore::dataset::TensorShape; +using namespace mindspore::dataset; class 
MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_datasets_test.cc b/tests/ut/cpp/dataset/c_api_datasets_test.cc index 8d40d8f081..6079d2b977 100644 --- a/tests/ut/cpp/dataset/c_api_datasets_test.cc +++ b/tests/ut/cpp/dataset/c_api_datasets_test.cc @@ -33,9 +33,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; -using mindspore::dataset::TensorShape; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_samplers_test.cc b/tests/ut/cpp/dataset/c_api_samplers_test.cc index 2b246f19a3..47a7af349a 100644 --- a/tests/ut/cpp/dataset/c_api_samplers_test.cc +++ b/tests/ut/cpp/dataset/c_api_samplers_test.cc @@ -24,8 +24,7 @@ #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_text_sentence_piece_vocab_test.cc b/tests/ut/cpp/dataset/c_api_text_sentence_piece_vocab_test.cc index 2d344be1a0..2c89ae4a1f 100644 --- a/tests/ut/cpp/dataset/c_api_text_sentence_piece_vocab_test.cc +++ b/tests/ut/cpp/dataset/c_api_text_sentence_piece_vocab_test.cc @@ -29,11 +29,7 @@ // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::Tensor; -using mindspore::dataset::ShuffleMode; -using mindspore::dataset::SentencePieceModel; -using mindspore::dataset::SentencePieceVocab; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_text_vocab_test.cc b/tests/ut/cpp/dataset/c_api_text_vocab_test.cc index 1ab29a0c28..9dd7afb55f 100644 --- a/tests/ut/cpp/dataset/c_api_text_vocab_test.cc +++ b/tests/ut/cpp/dataset/c_api_text_vocab_test.cc @@ -37,12 +37,7 @@ // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::DataType; -using mindspore::dataset::ShuffleMode; -using mindspore::dataset::Status; -using mindspore::dataset::Tensor; -using mindspore::dataset::Vocab; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index f8a729173f..86b860549d 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -48,9 +48,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::BorderType; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: diff --git a/tests/ut/cpp/dataset/c_api_vision_test.cc b/tests/ut/cpp/dataset/c_api_vision_test.cc index 9642ac8652..4420c67152 100644 --- a/tests/ut/cpp/dataset/c_api_vision_test.cc +++ b/tests/ut/cpp/dataset/c_api_vision_test.cc @@ 
-35,10 +35,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" -using namespace mindspore::dataset::api; -using mindspore::dataset::BorderType; -using mindspore::dataset::InterpolationMode; -using mindspore::dataset::Tensor; +using namespace mindspore::dataset; class MindDataTestPipeline : public UT::DatasetOpTesting { protected: @@ -207,16 +204,16 @@ TEST_F(MindDataTestPipeline, TestCenterCropFail) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCenterCrop with invalid parameters."; // center crop height value negative - std::shared_ptr center_crop = mindspore::dataset::api::vision::CenterCrop({-32, 32}); + std::shared_ptr center_crop = mindspore::dataset::vision::CenterCrop({-32, 32}); EXPECT_EQ(center_crop, nullptr); // center crop width value negative - center_crop = mindspore::dataset::api::vision::CenterCrop({32, -32}); + center_crop = mindspore::dataset::vision::CenterCrop({32, -32}); EXPECT_EQ(center_crop, nullptr); // 0 value would result in nullptr - center_crop = mindspore::dataset::api::vision::CenterCrop({0, 32}); + center_crop = mindspore::dataset::vision::CenterCrop({0, 32}); EXPECT_EQ(center_crop, nullptr); // center crop with 3 values - center_crop = mindspore::dataset::api::vision::CenterCrop({10, 20, 30}); + center_crop = mindspore::dataset::vision::CenterCrop({10, 20, 30}); EXPECT_EQ(center_crop, nullptr); } @@ -224,13 +221,13 @@ TEST_F(MindDataTestPipeline, TestCropFail) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCrop with invalid parameters."; // wrong width - std::shared_ptr crop = mindspore::dataset::api::vision::Crop({0, 0}, {32, -32}); + std::shared_ptr crop = mindspore::dataset::vision::Crop({0, 0}, {32, -32}); EXPECT_EQ(crop, nullptr); // wrong height - crop = mindspore::dataset::api::vision::Crop({0, 0}, {-32, -32}); + crop = mindspore::dataset::vision::Crop({0, 0}, {-32, -32}); EXPECT_EQ(crop, nullptr); // zero height - crop = mindspore::dataset::api::vision::Crop({0, 0}, {0, 32}); + crop = mindspore::dataset::vision::Crop({0, 0}, {0, 32}); EXPECT_EQ(crop, nullptr); } @@ -889,13 +886,13 @@ TEST_F(MindDataTestPipeline, TestNormalizeFail) { // std value at 0.0 std::shared_ptr normalize = - mindspore::dataset::api::vision::Normalize({121.0, 115.0, 100.0}, {0.0, 68.0, 71.0}); + mindspore::dataset::vision::Normalize({121.0, 115.0, 100.0}, {0.0, 68.0, 71.0}); EXPECT_EQ(normalize, nullptr); // normalize with 2 values (not 3 values) for mean - normalize = mindspore::dataset::api::vision::Normalize({121.0, 115.0}, {70.0, 68.0, 71.0}); + normalize = mindspore::dataset::vision::Normalize({121.0, 115.0}, {70.0, 68.0, 71.0}); EXPECT_EQ(normalize, nullptr); // normalize with 2 values (not 3 values) for standard deviation - normalize = mindspore::dataset::api::vision::Normalize({121.0, 115.0, 100.0}, {68.0, 71.0}); + normalize = mindspore::dataset::vision::Normalize({121.0, 115.0, 100.0}, {68.0, 71.0}); EXPECT_EQ(normalize, nullptr); } @@ -1308,7 +1305,7 @@ TEST_F(MindDataTestPipeline, TestRandomCropWithBboxSuccess) { EXPECT_NE(ds, nullptr); // Create objects for the tensor ops - std::shared_ptr random_crop = mindspore::dataset::api::vision::RandomCropWithBBox({128, 128}); + std::shared_ptr random_crop = mindspore::dataset::vision::RandomCropWithBBox({128, 128}); EXPECT_NE(random_crop, nullptr); // Create a Map operation on ds @@ -1903,7 +1900,7 @@ TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess1) { // Create objects for the tensor ops std::vector threshold = {10, 100}; - 
std::shared_ptr random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); + std::shared_ptr random_solarize = mindspore::dataset::vision::RandomSolarize(threshold); EXPECT_NE(random_solarize, nullptr); // Create a Map operation on ds @@ -1942,7 +1939,7 @@ TEST_F(MindDataTestPipeline, TestRandomSolarizeSucess2) { EXPECT_NE(ds, nullptr); // Create objects for the tensor ops - std::shared_ptr random_solarize = mindspore::dataset::api::vision::RandomSolarize(); + std::shared_ptr random_solarize = mindspore::dataset::vision::RandomSolarize(); EXPECT_NE(random_solarize, nullptr); // Create a Map operation on ds @@ -1976,19 +1973,19 @@ TEST_F(MindDataTestPipeline, TestRandomSolarizeFail) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomSolarizeFail with invalid parameters."; std::vector threshold = {13, 1}; - std::shared_ptr random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); + std::shared_ptr random_solarize = mindspore::dataset::vision::RandomSolarize(threshold); EXPECT_EQ(random_solarize, nullptr); threshold = {1, 2, 3}; - random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); + random_solarize = mindspore::dataset::vision::RandomSolarize(threshold); EXPECT_EQ(random_solarize, nullptr); threshold = {1}; - random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); + random_solarize = mindspore::dataset::vision::RandomSolarize(threshold); EXPECT_EQ(random_solarize, nullptr); threshold = {}; - random_solarize = mindspore::dataset::api::vision::RandomSolarize(threshold); + random_solarize = mindspore::dataset::vision::RandomSolarize(threshold); EXPECT_EQ(random_solarize, nullptr); } @@ -2007,13 +2004,13 @@ TEST_F(MindDataTestPipeline, DISABLED_TestRandomVerticalFlipFail) { TEST_F(MindDataTestPipeline, TestResizeFail) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestResize with invalid parameters."; // negative resize value - std::shared_ptr resize_op = mindspore::dataset::api::vision::Resize({30, -30}); + std::shared_ptr resize_op = mindspore::dataset::vision::Resize({30, -30}); EXPECT_EQ(resize_op, nullptr); // zero resize value - resize_op = mindspore::dataset::api::vision::Resize({0, 30}); + resize_op = mindspore::dataset::vision::Resize({0, 30}); EXPECT_EQ(resize_op, nullptr); // resize with 3 values - resize_op = mindspore::dataset::api::vision::Resize({30, 20, 10}); + resize_op = mindspore::dataset::vision::Resize({30, 20, 10}); EXPECT_EQ(resize_op, nullptr); } @@ -2137,7 +2134,7 @@ TEST_F(MindDataTestPipeline, TestRescaleSucess1) { auto image = row["image"]; // Create objects for the tensor ops - std::shared_ptr rescale = mindspore::dataset::api::vision::Rescale(1.0, 0.0); + std::shared_ptr rescale = mindspore::dataset::vision::Rescale(1.0, 0.0); EXPECT_NE(rescale, nullptr); // Convert to the same type @@ -2172,7 +2169,7 @@ TEST_F(MindDataTestPipeline, TestRescaleSucess2) { EXPECT_NE(ds, nullptr); // Create objects for the tensor ops - std::shared_ptr rescale = mindspore::dataset::api::vision::Rescale(1.0 / 255, 1.0); + std::shared_ptr rescale = mindspore::dataset::vision::Rescale(1.0 / 255, 1.0); EXPECT_NE(rescale, nullptr); ds = ds->Map({rescale}, {"image"}); @@ -2204,7 +2201,7 @@ TEST_F(MindDataTestPipeline, TestRescaleSucess2) { TEST_F(MindDataTestPipeline, TestRescaleFail) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRescaleFail with invalid params."; // incorrect negative rescale parameter - std::shared_ptr rescale = mindspore::dataset::api::vision::Rescale(-1.0, 0.0); + 
std::shared_ptr rescale = mindspore::dataset::vision::Rescale(-1.0, 0.0); EXPECT_EQ(rescale, nullptr); } diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index 1cd4328c5f..7926d11607 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -147,11 +147,11 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { (void)TaskManager::GetMasterThreadRc(); TaskGroup vg; Status rc; - + session_id_type env_session; rc = GetSessionFromEnv(&env_session); ASSERT_TRUE(rc.IsOk()); - + // use arbitrary session of 1, size 1, spilling is true CacheClient::Builder builder; // use arbitrary session of 1, size of 0, spilling// is true @@ -273,7 +273,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { int64_t num_samples = 0; int64_t start_index = 0; - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); rc = CacheOp::Builder() .SetNumWorkers(5) .SetClient(myClient) @@ -386,7 +386,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { // CacheOp int64_t num_samples = 0; int64_t start_index = 0; - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); CacheClient::Builder builder; builder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); std::shared_ptr myClient; @@ -457,7 +457,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { rc = GetSessionFromEnv(&env_session); ASSERT_TRUE(rc.IsOk()); - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); CacheClient::Builder ccbuilder; ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); @@ -559,7 +559,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { int64_t num_samples = 0; int64_t start_index = 0; - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); // Start with an empty execution tree auto myTree = std::make_shared(); diff --git a/tests/ut/cpp/dataset/celeba_op_test.cc b/tests/ut/cpp/dataset/celeba_op_test.cc index ded5490fd5..988ad82f74 100644 --- a/tests/ut/cpp/dataset/celeba_op_test.cc +++ b/tests/ut/cpp/dataset/celeba_op_test.cc @@ -38,8 +38,8 @@ std::shared_ptr Repeat(int repeat_cnt); std::shared_ptr Build(std::vector> ops); std::shared_ptr Celeba(int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, - const std::string &dir, std::shared_ptr sampler = nullptr, - bool decode = false, const std::string &dataset_type="all") { + const std::string &dir, std::shared_ptr sampler = nullptr, + bool decode = false, const std::string &dataset_type = "all") { std::shared_ptr so; CelebAOp::Builder builder; Status rc = builder.SetNumWorkers(num_workers) @@ -125,8 +125,9 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) { std::vector indices({1}); int64_t num_samples = 0; - std::shared_ptr sampler = std::make_shared(num_samples, indices); - uint32_t expect_labels[1][40] = {{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; + std::shared_ptr sampler = std::make_shared(num_samples, indices); + uint32_t expect_labels[1][40] = {{0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1}}; std::string dir = datasets_root_path_ + 
"/testCelebAData/"; uint32_t count = 0; auto tree = Build({Celeba(16, 2, 32, dir, std::move(sampler))}); diff --git a/tests/ut/cpp/dataset/cifar_op_test.cc b/tests/ut/cpp/dataset/cifar_op_test.cc index 0b6b4099a4..dc26e8d64f 100644 --- a/tests/ut/cpp/dataset/cifar_op_test.cc +++ b/tests/ut/cpp/dataset/cifar_op_test.cc @@ -44,12 +44,16 @@ std::shared_ptr Repeat(int repeatCnt); std::shared_ptr Build(std::vector> ops); std::shared_ptr Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path, - std::shared_ptr sampler = nullptr, bool cifar10 = true) { + std::shared_ptr sampler = nullptr, bool cifar10 = true) { std::shared_ptr so; CifarOp::Builder builder; - Status rc = builder.SetNumWorkers(num_works).SetCifarDir(path).SetRowsPerBuffer(rows) - .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetCifarType(cifar10) - .Build(&so); + Status rc = builder.SetNumWorkers(num_works) + .SetCifarDir(path) + .SetRowsPerBuffer(rows) + .SetOpConnectorSize(conns) + .SetSampler(std::move(sampler)) + .SetCifarType(cifar10) + .Build(&so); return so; } @@ -91,7 +95,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) { TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) { uint32_t original_seed = GlobalContext::config_manager()->seed(); GlobalContext::config_manager()->set_seed(0); - std::shared_ptr sampler = std::make_unique(12, true, true); + std::shared_ptr sampler = std::make_unique(12, true, true); std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler))}); tree->Prepare(); diff --git a/tests/ut/cpp/dataset/distributed_sampler_test.cc b/tests/ut/cpp/dataset/distributed_sampler_test.cc index 5fe9a46327..90dfe2b7d8 100644 --- a/tests/ut/cpp/dataset/distributed_sampler_test.cc +++ b/tests/ut/cpp/dataset/distributed_sampler_test.cc @@ -48,7 +48,7 @@ TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) { uint64_t num_samples = 7; // create sampler with replacement = true - DistributedSampler m_sampler(num_samples, 2, 0, false, 0, -1, false); + DistributedSamplerRT m_sampler(num_samples, 2, 0, false, 0, -1, false); DummyRandomAccessOp dummyRandomAccessOp(num_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -74,7 +74,7 @@ TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) { uint64_t num_samples = 7; // create sampler with replacement = true - DistributedSampler m_sampler(num_samples, 2, 1, false, 0, -1, false); + DistributedSamplerRT m_sampler(num_samples, 2, 1, false, 0, -1, false); DummyRandomAccessOp dummyRandomAccessOp(num_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -100,7 +100,7 @@ TEST_F(MindDataTestDistributedSampler, TestThreeShards) { uint64_t num_samples = 2; // create sampler with replacement = true - DistributedSampler m_sampler(num_samples, 3, 2, false, 0, -1, false); + DistributedSamplerRT m_sampler(num_samples, 3, 2, false, 0, -1, false); DummyRandomAccessOp dummyRandomAccessOp(num_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); diff --git a/tests/ut/cpp/dataset/epoch_ctrl_op_test.cc b/tests/ut/cpp/dataset/epoch_ctrl_op_test.cc index 2fc5f3c047..dba29151ef 100644 --- a/tests/ut/cpp/dataset/epoch_ctrl_op_test.cc +++ b/tests/ut/cpp/dataset/epoch_ctrl_op_test.cc @@ -26,7 +26,7 @@ using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, - bool shuf = false, std::shared_ptr sampler = nullptr, + 
bool shuf = false, std::shared_ptr sampler = nullptr, std::map map = {}, bool decode = false); std::shared_ptr Build(std::vector> ops); diff --git a/tests/ut/cpp/dataset/image_folder_op_test.cc b/tests/ut/cpp/dataset/image_folder_op_test.cc index 2cce023dcf..332343d45f 100644 --- a/tests/ut/cpp/dataset/image_folder_op_test.cc +++ b/tests/ut/cpp/dataset/image_folder_op_test.cc @@ -49,7 +49,7 @@ std::shared_ptr Repeat(int repeat_cnt); std::shared_ptr Build(std::vector> ops); std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, - bool shuf = false, std::shared_ptr sampler = nullptr, + bool shuf = false, std::shared_ptr sampler = nullptr, std::map map = {}, bool decode = false) { std::shared_ptr so; ImageFolderOp::Builder builder; @@ -133,7 +133,7 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) { int32_t original_seed = GlobalContext::config_manager()->seed(); GlobalContext::config_manager()->set_seed(0); int64_t num_samples = 12; - std::shared_ptr sampler = std::make_unique(num_samples, true, true); + std::shared_ptr sampler = std::make_unique(num_samples, true, true); int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label std::string folder_path = datasets_root_path_ + "/testPK/data"; auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); @@ -196,7 +196,7 @@ TEST_F(MindDataTestImageFolderSampler, TestSubsetRandomSamplerImageFolder) { // id range 0 - 10 is label 0, and id range 11 - 21 is label 1 std::vector indices({0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11}); int64_t num_samples = 0; - std::shared_ptr sampler = std::make_shared(num_samples, indices); + std::shared_ptr sampler = std::make_shared(num_samples, indices); std::string folder_path = datasets_root_path_ + "/testPK/data"; // Expect 6 samples for label 0 and 1 int res[2] = {6, 6}; @@ -233,8 +233,8 @@ TEST_F(MindDataTestImageFolderSampler, TestWeightedRandomSamplerImageFolder) { std::vector weights(total_samples, std::rand() % 100); // create sampler with replacement = replacement - std::shared_ptr sampler = - std::make_shared(num_samples, weights, true, samples_per_buffer); + std::shared_ptr sampler = + std::make_shared(num_samples, weights, true, samples_per_buffer); std::string folder_path = datasets_root_path_ + "/testPK/data"; auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); @@ -292,7 +292,7 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderClassIndex) { TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) { int64_t num_samples = 0; - std::shared_ptr sampler = std::make_shared(num_samples, 11, 10, false); + std::shared_ptr sampler = std::make_shared(num_samples, 11, 10, false); std::string folder_path = datasets_root_path_ + "/testPK/data"; auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)}); tree->Prepare(); @@ -320,7 +320,7 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) { TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) { int64_t num_samples = 0; - std::shared_ptr sampler = std::make_shared(num_samples, 3, false, 4); + std::shared_ptr sampler = std::make_shared(num_samples, 3, false, 4); int32_t res[] = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3}; // ground truth label std::string folder_path = datasets_root_path_ + "/testPK/data"; auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); @@ -355,7 +355,7 @@ TEST_F(MindDataTestImageFolderSampler, 
TestImageFolderDecode) { map["wrong folder name"] = 1234; // this is skipped int64_t num_samples = 20; int64_t start_index = 0; - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(seq_sampler), map, true)}); int64_t res[2] = {111, 333}; tree->Prepare(); @@ -385,7 +385,7 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) { TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) { int64_t num_samples = 5; - std::shared_ptr sampler = std::make_shared(num_samples, 4, 0, false); + std::shared_ptr sampler = std::make_shared(num_samples, 4, 0, false); std::string folder_path = datasets_root_path_ + "/testPK/data"; // numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {})}); @@ -415,7 +415,7 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) { TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding2) { int64_t num_samples = 12; - std::shared_ptr sampler = std::make_shared(num_samples, 4, 3, false); + std::shared_ptr sampler = std::make_shared(num_samples, 4, 3, false); std::string folder_path = datasets_root_path_ + "/testPK/data"; // numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {})}); diff --git a/tests/ut/cpp/dataset/manifest_op_test.cc b/tests/ut/cpp/dataset/manifest_op_test.cc index 0d6621bfa2..af84b2cbcd 100644 --- a/tests/ut/cpp/dataset/manifest_op_test.cc +++ b/tests/ut/cpp/dataset/manifest_op_test.cc @@ -42,7 +42,7 @@ std::shared_ptr Repeat(int repeatCnt); std::shared_ptr Build(std::vector> ops); std::shared_ptr Manifest(int32_t num_works, int32_t rows, int32_t conns, const std::string &file, - std::string usage = "train", std::shared_ptr sampler = nullptr, + std::string usage = "train", std::shared_ptr sampler = nullptr, std::map map = {}, bool decode = false) { std::shared_ptr so; ManifestOp::Builder builder; @@ -86,7 +86,7 @@ TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) { TEST_F(MindDataTestManifest, TestSubsetRandomSamplerManifest) { std::vector indices({1}); int64_t num_samples = 0; - std::shared_ptr sampler = std::make_shared(num_samples, indices); + std::shared_ptr sampler = std::make_shared(num_samples, indices); std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; // Expect 6 samples for label 0 and 1 auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(sampler))}); @@ -147,7 +147,7 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) { std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; int64_t num_samples = 1; int64_t start_index = 0; - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)}); tree->Prepare(); Status rc = tree->Launch(); @@ -176,7 +176,7 @@ TEST_F(MindDataTestManifest, MindDataTestManifestEval) { std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; int64_t num_samples = 1; int64_t start_index = 0; - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); auto tree = Build({Manifest(16, 2, 32, file, "eval", std::move(seq_sampler), {})}); 
   tree->Prepare();
   Status rc = tree->Launch();
diff --git a/tests/ut/cpp/dataset/map_op_test.cc b/tests/ut/cpp/dataset/map_op_test.cc
index 0cee75b264..aa1bd29add 100644
--- a/tests/ut/cpp/dataset/map_op_test.cc
+++ b/tests/ut/cpp/dataset/map_op_test.cc
@@ -132,7 +132,7 @@ class MindDataTestMapOp : public UT::DatasetOpTesting {
 };
 std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
-                            bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
+                            bool shuf = false, std::shared_ptr<SamplerRT> sampler = nullptr,
                             std::map map = {}, bool decode = false);
 std::shared_ptr Build(std::vector> ops);
diff --git a/tests/ut/cpp/dataset/mnist_op_test.cc b/tests/ut/cpp/dataset/mnist_op_test.cc
index c40086e20f..a6b03c288b 100644
--- a/tests/ut/cpp/dataset/mnist_op_test.cc
+++ b/tests/ut/cpp/dataset/mnist_op_test.cc
@@ -51,12 +51,16 @@ std::shared_ptr Build(std::vector> ops
 Status Create1DTensor(std::shared_ptr *sample_ids, int64_t num_elements, unsigned char *data = nullptr,
                       DataType::Type data_type = DataType::DE_UINT32);
-std::shared_ptr CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path,
-                            bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr) {
+std::shared_ptr CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path, bool shuf = false,
+                            std::shared_ptr<SamplerRT> sampler = nullptr) {
   std::shared_ptr so;
   MnistOp::Builder builder;
-  Status rc = builder.SetNumWorkers(num_wrks).SetDir(path).SetRowsPerBuffer(rows)
-                 .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).Build(&so);
+  Status rc = builder.SetNumWorkers(num_wrks)
+                .SetDir(path)
+                .SetRowsPerBuffer(rows)
+                .SetOpConnectorSize(conns)
+                .SetSampler(std::move(sampler))
+                .Build(&so);
   return so;
 }
@@ -73,7 +77,7 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
   std::string folder_path = datasets_root_path_ + "/testMnistData/";
   int64_t num_samples = 10;
   int64_t start_index = 0;
-  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+  auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
   auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)});
   tree->Prepare();
   uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
@@ -103,7 +107,7 @@ TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) {
   std::string folder_path = datasets_root_path_ + "/testMnistData/";
   int64_t num_samples = 10;
   int64_t start_index = 0;
-  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+  auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
   auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)});
   tree->Prepare();
   uint32_t res[4][5] = { {0, 0, 0, 0, 0 },
diff --git a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc
index 79464b732b..27002a2734 100644
--- a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc
+++ b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc
@@ -61,7 +61,8 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
   std::shared_ptr tensor;
   int64_t num_samples = 0;
   for (int i = 0; i < 6; i++) {
-    std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true));
+    std::shared_ptr<SamplerRT> sampler =
+      std::make_shared<DistributedSamplerRT>(num_samples, 3, i % 3, (i < 3 ? false : true));
     sampler->HandshakeRandomAccessOp(&mock);
     sampler->GetNextSample(&db);
     db->GetTensor(&tensor, 0, 0);
@@ -81,7 +82,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
   CreateINT64Tensor(&label2, 2, reinterpret_cast(res + 3));
   int64_t num_samples = 0;
   int64_t start_index = 0;
-  std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(num_samples, start_index, 3);
+  std::shared_ptr<SamplerRT> sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index, 3);
   std::unique_ptr db;
   std::shared_ptr tensor;
   sampler->HandshakeRandomAccessOp(&mock);
diff --git a/tests/ut/cpp/dataset/subset_random_sampler_test.cc b/tests/ut/cpp/dataset/subset_random_sampler_test.cc
index c389686014..3ea8078a9b 100644
--- a/tests/ut/cpp/dataset/subset_random_sampler_test.cc
+++ b/tests/ut/cpp/dataset/subset_random_sampler_test.cc
@@ -41,7 +41,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
   std::vector in({0, 1, 2, 3, 4});
   std::unordered_set in_set(in.begin(), in.end());
   int64_t num_samples = 0;
-  SubsetRandomSampler sampler(num_samples, in);
+  SubsetRandomSamplerRT sampler(num_samples, in);
   DummyRandomAccessOp dummyRandomAccessOp(5);
   sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@@ -70,7 +70,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
   int64_t samples_per_buffer = 10;
   int64_t num_samples = 0;
   std::vector input(total_samples, 1);
-  SubsetRandomSampler sampler(num_samples, input, samples_per_buffer);
+  SubsetRandomSamplerRT sampler(num_samples, input, samples_per_buffer);
   DummyRandomAccessOp dummyRandomAccessOp(total_samples);
   sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@@ -102,7 +102,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
   std::vector in({0, 1, 2, 3, 4});
   std::unordered_set in_set(in.begin(), in.end());
   int64_t num_samples = 0;
-  SubsetRandomSampler sampler(num_samples, in);
+  SubsetRandomSamplerRT sampler(num_samples, in);
   DummyRandomAccessOp dummyRandomAccessOp(5);
   sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
diff --git a/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc b/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc
index 70832c04b5..14f215ecf0 100644
--- a/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc
+++ b/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc
@@ -38,7 +38,7 @@ class MindDataTestTensorOpFusionPass : public UT::DatasetOpTesting {
 TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_disabled) {
   MS_LOG(INFO) << "Doing RandomCropDecodeResize_fusion";
   std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
-                              bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
+                              bool shuf = false, std::shared_ptr<SamplerRT> sampler = nullptr,
                               std::map map = {}, bool decode = false);
   std::shared_ptr Build(std::vector> ops);
   auto rcar_op = std::make_shared();
@@ -73,7 +73,7 @@ TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_disabled) {
 TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_enabled) {
   MS_LOG(INFO) << "Doing RandomCropDecodeResize_fusion";
   std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
-                              bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
+                              bool shuf = false, std::shared_ptr<SamplerRT> sampler = nullptr,
                               std::map map = {}, bool decode = false);
   std::shared_ptr Build(std::vector> ops);
   auto rcar_op = std::make_shared();
diff --git a/tests/ut/cpp/dataset/tree_adapter_test.cc b/tests/ut/cpp/dataset/tree_adapter_test.cc
index c0f336232b..a51224e9a3 100644
--- a/tests/ut/cpp/dataset/tree_adapter_test.cc
+++ b/tests/ut/cpp/dataset/tree_adapter_test.cc
@@ -48,7 +48,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
   // Create a Mnist Dataset
   std::string folder_path = datasets_root_path_ + "/testMnistData/";
-  std::shared_ptr ds = Mnist(folder_path, "all", api::SequentialSampler(0, 4));
+  std::shared_ptr ds = Mnist(folder_path, "all", SequentialSampler(0, 4));
   EXPECT_NE(ds, nullptr);
   ds = ds->Batch(2);
@@ -83,7 +83,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
   // Create a Mnist Dataset
   std::string folder_path = datasets_root_path_ + "/testMnistData/";
-  std::shared_ptr ds = Mnist(folder_path, "all", api::SequentialSampler(0, 3));
+  std::shared_ptr ds = Mnist(folder_path, "all", SequentialSampler(0, 3));
   EXPECT_NE(ds, nullptr);
   ds = ds->Batch(2, false);
@@ -115,11 +115,11 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
   // Create an ImageFolder Dataset
   std::string folder_path = datasets_root_path_ + "/testPK/data/";
-  std::shared_ptr ds = ImageFolder(folder_path, true, api::SequentialSampler(0, 2));
+  std::shared_ptr ds = ImageFolder(folder_path, true, SequentialSampler(0, 2));
   EXPECT_NE(ds, nullptr);
   // Create objects for the tensor ops
-  std::shared_ptr one_hot = api::transforms::OneHot(10);
+  std::shared_ptr one_hot = transforms::OneHot(10);
   EXPECT_NE(one_hot, nullptr);
   // Create a Map operation, this will automatically add a project after map
diff --git a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc
index bb3079aec8..70505644b5 100644
--- a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc
+++ b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc
@@ -51,7 +51,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
   std::vector freq(total_samples, 0);
   // create sampler with replacement = true
-  WeightedRandomSampler m_sampler(num_samples, weights, true);
+  WeightedRandomSamplerRT m_sampler(num_samples, weights, true);
   DummyRandomAccessOp dummyRandomAccessOp(total_samples);
   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@@ -81,7 +81,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
   std::vector freq(total_samples, 0);
   // create sampler with replacement = replacement
-  WeightedRandomSampler m_sampler(num_samples, weights, false);
+  WeightedRandomSamplerRT m_sampler(num_samples, weights, false);
   DummyRandomAccessOp dummyRandomAccessOp(total_samples);
   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@@ -117,7 +117,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
   std::vector weights(total_samples, std::rand() % 100);
   // create sampler with replacement = replacement
-  WeightedRandomSampler m_sampler(num_samples, weights, true, samples_per_buffer);
+  WeightedRandomSamplerRT m_sampler(num_samples, weights, true, samples_per_buffer);
   DummyRandomAccessOp dummyRandomAccessOp(total_samples);
   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@@ -153,7 +153,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
   std::vector freq(total_samples, 0);
   // create sampler with replacement = replacement
-  WeightedRandomSampler m_sampler(num_samples, weights, false, samples_per_buffer);
+  WeightedRandomSamplerRT m_sampler(num_samples, weights, false, samples_per_buffer);
   DummyRandomAccessOp dummyRandomAccessOp(total_samples);
   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@@ -194,7 +194,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
   std::vector freq(total_samples, 0);
   // create sampler with replacement = true
-  WeightedRandomSampler m_sampler(num_samples, weights, true);
+  WeightedRandomSamplerRT m_sampler(num_samples, weights, true);
   DummyRandomAccessOp dummyRandomAccessOp(total_samples);
   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@@ -239,7 +239,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
   std::vector freq(total_samples, 0);
   // create sampler with replacement = true
-  WeightedRandomSampler m_sampler(num_samples, weights, false);
+  WeightedRandomSamplerRT m_sampler(num_samples, weights, false);
   DummyRandomAccessOp dummyRandomAccessOp(total_samples);
   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
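
Note on the pattern behind these test edits: with the api:: namespace removed, the front-end factory functions (SequentialSampler, transforms::OneHot, Mnist, ImageFolder, ...) now live directly in mindspore::dataset, so the internal runtime samplers driven by these DatasetOp tests take an RT suffix (SubsetRandomSamplerRT, WeightedRandomSamplerRT, and so on) to keep the two layers from sharing a name. The self-contained C++ sketch below only illustrates that two-layer layout; SequentialSamplerObj and its Build() method here are illustrative stand-ins, not code taken from this patch.

// Standalone sketch (not MindSpore code): why dropping api:: pairs with the RT rename.
// With both layers in namespace `dataset`, the old runtime name `SequentialSampler`
// would collide with the front-end factory of the same name; the RT suffix avoids that.
#include <cstdint>
#include <iostream>
#include <memory>

namespace dataset {

// Runtime sampler, previously the class named SequentialSampler.
class SequentialSamplerRT {
 public:
  SequentialSamplerRT(int64_t num_samples, int64_t start_index)
      : num_samples_(num_samples), start_index_(start_index) {}
  int64_t start_index() const { return start_index_; }

 private:
  int64_t num_samples_;
  int64_t start_index_;
};

// Front-end sampler object (stand-in), previously reached through api::.
class SequentialSamplerObj {
 public:
  SequentialSamplerObj(int64_t start_index, int64_t num_samples)
      : start_index_(start_index), num_samples_(num_samples) {}
  // Builds the runtime sampler that the execution tree actually uses.
  std::shared_ptr<SequentialSamplerRT> Build() const {
    return std::make_shared<SequentialSamplerRT>(num_samples_, start_index_);
  }

 private:
  int64_t start_index_;
  int64_t num_samples_;
};

// Factory keeps the old public name; no api:: qualifier is needed any more.
inline std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0,
                                                               int64_t num_samples = 0) {
  return std::make_shared<SequentialSamplerObj>(start_index, num_samples);
}

}  // namespace dataset

int main() {
  auto sampler_obj = dataset::SequentialSampler(0, 4);  // front-end call, as in the tree_adapter tests
  auto runtime = sampler_obj->Build();                  // runtime object, now carrying the RT suffix
  std::cout << "runtime sampler starts at index " << runtime->start_index() << "\n";
  return 0;
}

The same split is what lets the updated tests call SequentialSampler(0, 4) with no namespace qualifier while constructing *RT samplers directly when they drive a DatasetOp by hand.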