diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index a1a1402d9f..dfa3e8f1df 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -80,6 +80,8 @@ add_dependencies(text-kernels core) add_dependencies(cpp-API core) add_dependencies(engine-ir-datasetops core) add_dependencies(engine-ir-datasetops-source core) +add_dependencies(engine-ir-cache core) + if (ENABLE_PYTHON) add_dependencies(APItoPython core) endif() @@ -102,8 +104,9 @@ set(submodules $ $ $ - $ - $ + $ + $ + $ $ $ $ diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 3f3cce764e..0fb12d4272 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/transforms.h" // Source dataset headers (in alphabetical order) @@ -32,6 +33,7 @@ #ifndef ENABLE_ANDROID #include "minddata/dataset/engine/datasetops/source/manifest_op.h" #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" #endif #include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/random_data_op.h" @@ -113,6 +115,9 @@ Dataset::Dataset() { worker_connector_size_ = cfg->worker_connector_size(); } +// Constructor to initialize the cache +Dataset::Dataset(const std::shared_ptr &dataset_cache) : Dataset() { cache_ = dataset_cache; } + /// \brief Function to create a SchemaObj /// \param[in] schema_file Path of schema file /// \return Shared pointer to the current schema @@ -137,8 +142,9 @@ std::shared_ptr Album(const std::string &dataset_dir, const std::stri // Function to create a CelebANode. std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr &sampler, bool decode, - const std::set &extensions) { - auto ds = std::make_shared(dataset_dir, usage, sampler, decode, extensions); + const std::set &extensions, + const std::shared_ptr &cache) { + auto ds = std::make_shared(dataset_dir, usage, sampler, decode, extensions, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -146,8 +152,9 @@ std::shared_ptr CelebA(const std::string &dataset_dir, const std::st // Function to create a Cifar10Node. std::shared_ptr Cifar10(const std::string &dataset_dir, const std::string &usage, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, usage, sampler); + const std::shared_ptr &sampler, + const std::shared_ptr &cache) { + auto ds = std::make_shared(dataset_dir, usage, sampler, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -155,8 +162,9 @@ std::shared_ptr Cifar10(const std::string &dataset_dir, const std:: // Function to create a Cifar100Node. std::shared_ptr Cifar100(const std::string &dataset_dir, const std::string &usage, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, usage, sampler); + const std::shared_ptr &sampler, + const std::shared_ptr &cache) { + auto ds = std::make_shared(dataset_dir, usage, sampler, cache); // Call derived class validation method. return ds->ValidateParams() ? 
ds : nullptr; @@ -165,8 +173,8 @@ std::shared_ptr Cifar100(const std::string &dataset_dir, const std // Function to create a CLUENode. std::shared_ptr CLUE(const std::vector &clue_files, const std::string &task, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, - int32_t shard_id) { - auto ds = std::make_shared(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id); + int32_t shard_id, const std::shared_ptr &cache) { + auto ds = std::make_shared(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -174,9 +182,9 @@ std::shared_ptr CLUE(const std::vector &clue_files, const // Function to create a CocoNode. std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, - const std::string &task, const bool &decode, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, annotation_file, task, decode, sampler); + const std::string &task, const bool &decode, const std::shared_ptr &sampler, + const std::shared_ptr &cache) { + auto ds = std::make_shared(dataset_dir, annotation_file, task, decode, sampler, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -186,9 +194,9 @@ std::shared_ptr Coco(const std::string &dataset_dir, const std::string std::shared_ptr CSV(const std::vector &dataset_files, char field_delim, const std::vector> &column_defaults, const std::vector &column_names, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id) { + int32_t num_shards, int32_t shard_id, const std::shared_ptr &cache) { auto ds = std::make_shared(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle, - num_shards, shard_id); + num_shards, shard_id, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -198,12 +206,14 @@ std::shared_ptr CSV(const std::vector &dataset_files, char std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode, const std::shared_ptr &sampler, const std::set &extensions, - const std::map &class_indexing) { + const std::map &class_indexing, + const std::shared_ptr &cache) { // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false. bool recursive = false; // Create logical representation of ImageFolderNode. - auto ds = std::make_shared(dataset_dir, decode, sampler, recursive, extensions, class_indexing); + auto ds = + std::make_shared(dataset_dir, decode, sampler, recursive, extensions, class_indexing, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -213,8 +223,9 @@ std::shared_ptr ImageFolder(const std::string &dataset_dir, boo // Function to create a ManifestNode. std::shared_ptr Manifest(const std::string &dataset_file, const std::string &usage, const std::shared_ptr &sampler, - const std::map &class_indexing, bool decode) { - auto ds = std::make_shared(dataset_file, usage, sampler, class_indexing, decode); + const std::map &class_indexing, bool decode, + const std::shared_ptr &cache) { + auto ds = std::make_shared(dataset_file, usage, sampler, class_indexing, decode, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -244,8 +255,9 @@ std::shared_ptr MindData(const std::vector &dataset_f // Function to create a MnistNode. 
std::shared_ptr Mnist(const std::string &dataset_dir, const std::string &usage, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, usage, sampler); + const std::shared_ptr &sampler, + const std::shared_ptr &cache) { + auto ds = std::make_shared(dataset_dir, usage, sampler, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -262,8 +274,9 @@ std::shared_ptr operator+(const std::shared_ptr &datasets1, // Function to create a TextFileNode. std::shared_ptr TextFile(const std::vector &dataset_files, int64_t num_samples, - ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) { - auto ds = std::make_shared(dataset_files, num_samples, shuffle, num_shards, shard_id); + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, + const std::shared_ptr &cache) { + auto ds = std::make_shared(dataset_files, num_samples, shuffle, num_shards, shard_id, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -273,8 +286,8 @@ std::shared_ptr TextFile(const std::vector &dataset_f // Function to create a VOCNode. std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage, const std::map &class_indexing, bool decode, - const std::shared_ptr &sampler) { - auto ds = std::make_shared(dataset_dir, task, usage, class_indexing, decode, sampler); + const std::shared_ptr &sampler, const std::shared_ptr &cache) { + auto ds = std::make_shared(dataset_dir, task, usage, class_indexing, decode, sampler, cache); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -365,8 +378,10 @@ std::shared_ptr Dataset::Concat(const std::vector Dataset::Map(std::vector> operations, std::vector input_columns, std::vector output_columns, - const std::vector &project_columns) { - auto ds = std::make_shared(shared_from_this(), operations, input_columns, output_columns, project_columns); + const std::vector &project_columns, + const std::shared_ptr &cache) { + auto ds = + std::make_shared(shared_from_this(), operations, input_columns, output_columns, project_columns, cache); if (!ds->ValidateParams()) { return nullptr; @@ -464,6 +479,14 @@ std::shared_ptr Dataset::Zip(const std::vector return ds->ValidateParams() ? ds : nullptr; } +Status Dataset::AddCacheOp(std::vector> *node_ops) { + if (cache_ != nullptr) { + std::shared_ptr cache_op; + RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); + node_ops->push_back(cache_op); + } + return Status::OK(); +} SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} @@ -831,8 +854,13 @@ std::vector> AlbumNode::Build() { // Constructor for CelebANode CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr &sampler, const bool &decode, - const std::set &extensions) - : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {} + const std::set &extensions, const std::shared_ptr &cache) + : Dataset(cache), + dataset_dir_(dataset_dir), + usage_(usage), + sampler_(sampler), + decode_(decode), + extensions_(extensions) {} Status CelebANode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); @@ -855,15 +883,18 @@ std::vector> CelebANode::Build() { // label is like this:0 1 0 0 1...... 
RETURN_EMPTY_IF_ERROR( schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + node_ops.push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_, extensions_, std::move(schema), std::move(sampler_->Build()))); + return node_ops; } // Constructor for Cifar10Node -Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler) - : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} +Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, + std::shared_ptr cache) + : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} Status Cifar10Node::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); @@ -887,16 +918,19 @@ std::vector> Cifar10Node::Build() { RETURN_EMPTY_IF_ERROR( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); + node_ops.push_back(std::make_shared(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), std::move(sampler_->Build()))); + return node_ops; } // Constructor for Cifar100Node Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, - std::shared_ptr sampler) - : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} + std::shared_ptr sampler, std::shared_ptr cache) + : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} Status Cifar100Node::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); @@ -922,16 +956,20 @@ std::vector> Cifar100Node::Build() { RETURN_EMPTY_IF_ERROR( schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); + node_ops.push_back(std::make_shared(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), std::move(sampler_->Build()))); + return node_ops; } // Constructor for CLUENode CLUENode::CLUENode(const std::vector clue_files, std::string task, std::string usage, int64_t num_samples, - ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) - : dataset_files_(clue_files), + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr cache) + : Dataset(std::move(cache)), + dataset_files_(clue_files), task_(task), usage_(usage), num_samples_(num_samples), @@ -973,6 +1011,7 @@ std::vector CLUENode::split(const std::string &s, char delim) { std::vector> CLUENode::Build() { // A vector containing shared pointer to the Dataset Ops that this object will create std::vector> node_ops; + std::map key_map; if (task_ == "AFQMC") { if (usage_ == "train") { @@ -1102,15 +1141,22 @@ std::vector> CLUENode::Build() { rows_per_buffer_, &shuffle_op)); node_ops.push_back(shuffle_op); } + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); node_ops.push_back(clue_op); + return node_ops; } // Constructor for CocoNode CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, - const bool &decode, const std::shared_ptr &sampler) - : dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {} + const bool &decode, const 
std::shared_ptr &sampler, std::shared_ptr cache) + : Dataset(std::move(cache)), + dataset_dir_(dataset_dir), + annotation_file_(annotation_file), + task_(task), + decode_(decode), + sampler_(sampler) {} Status CocoNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); @@ -1186,7 +1232,10 @@ std::vector> CocoNode::Build() { std::shared_ptr op = std::make_shared(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); + node_ops.push_back(op); + return node_ops; } @@ -1194,8 +1243,9 @@ std::vector> CocoNode::Build() { CSVNode::CSVNode(const std::vector &csv_files, char field_delim, const std::vector> &column_defaults, const std::vector &column_names, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id) - : dataset_files_(csv_files), + int32_t num_shards, int32_t shard_id, std::shared_ptr cache) + : Dataset(std::move(cache)), + dataset_files_(csv_files), field_delim_(field_delim), column_defaults_(column_defaults), column_names_(column_names), @@ -1274,17 +1324,26 @@ std::vector> CSVNode::Build() { // Add the shuffle op after this op RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, rows_per_buffer_, &shuffle_op)); + node_ops.push_back(shuffle_op); } + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); node_ops.push_back(csv_op); + return node_ops; } #ifndef ENABLE_ANDROID ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr &sampler, - const std::map &class_indexing, bool decode) - : dataset_file_(dataset_file), usage_(usage), decode_(decode), class_index_(class_indexing), sampler_(sampler) {} + const std::map &class_indexing, bool decode, + std::shared_ptr cache) + : Dataset(std::move(cache)), + dataset_file_(dataset_file), + usage_(usage), + decode_(decode), + class_index_(class_indexing), + sampler_(sampler) {} Status ManifestNode::ValidateParams() { std::vector forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; @@ -1326,8 +1385,10 @@ std::vector> ManifestNode::Build() { manifest_op = std::make_shared(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, class_index_, std::move(schema), std::move(sampler_->Build()), usage_); + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); node_ops.push_back(manifest_op); + return node_ops; } #endif @@ -1466,8 +1527,9 @@ std::vector> MindDataNode::Build() { } #endif -MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler) - : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} +MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler, + std::shared_ptr cache) + : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} Status MnistNode::ValidateParams() { RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); @@ -1489,9 +1551,11 @@ std::vector> MnistNode::Build() { TensorShape scalar = TensorShape::CreateScalar(); RETURN_EMPTY_IF_ERROR( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); node_ops.push_back(std::make_shared(usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), std::move(sampler_->Build()))); + return 
node_ops; } @@ -1560,14 +1624,18 @@ std::vector> RandomNode::Build() { std::shared_ptr op; op = std::make_shared(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, std::move(data_schema), std::move(sampler_->Build())); + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); + node_ops.push_back(op); + return node_ops; } // Constructor for TextFileNode TextFileNode::TextFileNode(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id) - : dataset_files_(dataset_files), + int32_t num_shards, int32_t shard_id, std::shared_ptr cache) + : Dataset(std::move(cache)), + dataset_files_(dataset_files), num_samples_(num_samples), shuffle_(shuffle), num_shards_(num_shards), @@ -1622,9 +1690,11 @@ std::vector> TextFileNode::Build() { rows_per_buffer_, &shuffle_op)); node_ops.push_back(shuffle_op); } + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); // Add TextFileOp node_ops.push_back(text_file_op); + return node_ops; } @@ -1673,6 +1743,7 @@ std::vector> TFRecordNode::Build() { rows_per_buffer_, &shuffle_op)); node_ops.push_back(shuffle_op); } + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); // Add TFReaderOp node_ops.push_back(tf_reader_op); @@ -1681,8 +1752,10 @@ std::vector> TFRecordNode::Build() { // Constructor for VOCNode VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, - const std::map &class_indexing, bool decode, std::shared_ptr sampler) - : dataset_dir_(dataset_dir), + const std::map &class_indexing, bool decode, std::shared_ptr sampler, + std::shared_ptr cache) + : Dataset(std::move(cache)), + dataset_dir_(dataset_dir), task_(task), usage_(usage), class_index_(class_indexing), @@ -1755,9 +1828,18 @@ std::vector> VOCNode::Build() { std::shared_ptr voc_op; voc_op = std::make_shared(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); + RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); + node_ops.push_back(voc_op); return node_ops; } +std::shared_ptr CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill, + std::optional hostname, std::optional port, + std::optional num_connections, + std::optional prefetch_sz) { + auto cache = std::make_shared(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz); + return cache->ValidateParams() ? 
cache : nullptr;
+}
 #endif
 #ifndef ENABLE_ANDROID
@@ -1766,11 +1848,12 @@
 MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
                  std::vector<std::string> input_columns, std::vector<std::string> output_columns,
-                 const std::vector<std::string> &project_columns)
+                 const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache)
     : operations_(operations),
       input_columns_(input_columns),
       output_columns_(output_columns),
-      project_columns_(project_columns) {
+      project_columns_(project_columns),
+      Dataset(std::move(cache)) {
   this->children.push_back(child);
 }
@@ -1793,6 +1876,7 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
     auto project_op = std::make_shared<ProjectOp>(project_columns_);
     node_ops.push_back(project_op);
   }
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
   node_ops.push_back(map_op);
   return node_ops;
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/CMakeLists.txt
index 3f7f85780a..ee9ebadb16 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/CMakeLists.txt
@@ -1,3 +1,4 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_subdirectory(datasetops)
\ No newline at end of file
+add_subdirectory(datasetops)
+add_subdirectory(cache)
\ No newline at end of file
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/CMakeLists.txt
new file mode 100644
index 0000000000..b78e0f3d36
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/CMakeLists.txt
@@ -0,0 +1,4 @@
+file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
+set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
+add_library(engine-ir-cache OBJECT
+    dataset_cache_impl.cc)
\ No newline at end of file
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h
new file mode 100644
index 0000000000..9096824d1b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
+
+#include <memory>
+
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+
+namespace mindspore::dataset::api {
+
+class DatasetCache {
+ public:
+  virtual Status Build() = 0;
+  virtual Status ValidateParams() = 0;
+  virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
+};
+}  // namespace mindspore::dataset::api
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc
new file mode 100644
index 0000000000..ffacd02e8d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+
+#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+
+namespace mindspore::dataset::api {
+
+/// Method to initialize the DatasetCache by creating an instance of a CacheClient
+/// \return Status Error code
+Status DatasetCacheImpl::Build() {
+  CacheClient::Builder builder;
+  builder.SetSessionId(session_id_).SetCacheMemSz(cache_mem_sz_).SetSpill(spill_);
+  if (hostname_) builder.SetHostname(hostname_.value());
+  if (port_) builder.SetPort(port_.value());
+  if (num_connections_) builder.SetNumConnections(num_connections_.value());
+  if (prefetch_sz_) builder.SetPrefetchSize(prefetch_sz_.value());
+  return builder.Build(&cache_client_);
+}
+
+Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
+  CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
+  std::shared_ptr<CacheOp> cache_op = nullptr;
+  RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
+  *ds = cache_op;
+
+  return Status::OK();
+}
+
+}  // namespace mindspore::dataset::api
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h
new file mode 100644
index 0000000000..0efef3820a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include "minddata/dataset/engine/cache/cache_client.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
+
+namespace mindspore::dataset::api {
+
+/// DatasetCache is the IR of CacheClient
+class DatasetCacheImpl : public DatasetCache {
+ public:
+  ///
+  /// \brief Constructor
+  /// \param id A user assigned session id for the current pipeline
+  /// \param mem_sz Size of the memory set aside for the row caching. 0 for unlimited
+  /// \param spill Spill to disk if out of memory
+  /// \param hostname optional host name
+  /// \param port optional port
+  /// \param num_connections optional number of connections
+  /// \param prefetch_sz optional prefetch size
+  DatasetCacheImpl(session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
+                   std::optional<int32_t> port, std::optional<int32_t> num_connections,
+                   std::optional<int32_t> prefetch_sz)
+      : session_id_(id),
+        cache_mem_sz_(mem_sz),
+        spill_(spill),
+        hostname_(std::move(hostname)),
+        port_(std::move(port)),
+        num_connections_(std::move(num_connections)),
+        prefetch_sz_(std::move(prefetch_sz)) {}
+
+  /// Method to initialize the DatasetCache by creating an instance of a CacheClient
+  /// \return Status Error code
+  Status Build() override;
+
+  Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
+
+  Status ValidateParams() override { return Status::OK(); }
+
+ private:
+  std::shared_ptr<CacheClient> cache_client_;
+  session_id_type session_id_;
+  uint64_t cache_mem_sz_;
+  bool spill_;
+  std::optional<std::string> hostname_;
+  std::optional<int32_t> port_;
+  std::optional<int32_t> num_connections_;
+  std::optional<int32_t> prefetch_sz_;
+};
+}  // namespace mindspore::dataset::api
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
index a9c58e6bab..714d6f9799 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
@@ -32,13 +32,15 @@ namespace api {
 ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
                                  bool recursive, std::set<std::string> extensions,
-                                 std::map<std::string, int32_t> class_indexing)
+                                 std::map<std::string, int32_t> class_indexing,
+                                 std::shared_ptr<DatasetCache> cache = nullptr)
     : dataset_dir_(dataset_dir),
       decode_(decode),
       sampler_(sampler),
       recursive_(recursive),
       class_indexing_(class_indexing),
-      exts_(extensions) {}
+      exts_(extensions),
+      Dataset(std::move(cache)) {}
 
 Status ImageFolderNode::ValidateParams() {
   RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));
@@ -60,6 +62,9 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
     schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
+
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
                                                      recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->Build()))); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h index e160ca5a79..6f5345472d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h @@ -23,6 +23,7 @@ #include #include +#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" #include "minddata/dataset/include/datasets.h" namespace mindspore { @@ -36,7 +37,8 @@ class ImageFolderNode : public Dataset { public: /// \brief Constructor ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, - std::set extensions, std::map class_indexing); + std::set extensions, std::map class_indexing, + std::shared_ptr cache); /// \brief Destructor ~ImageFolderNode() = default; diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index f221a322f4..23d7e961b0 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -25,7 +25,9 @@ #include #include #include +#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" #include "minddata/dataset/core/constants.h" + #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/include/iterator.h" #include "minddata/dataset/include/samplers.h" @@ -147,10 +149,13 @@ std::shared_ptr Album(const std::string &dataset_dir, const std::stri /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \param[in] decode Decode the images after reading (default=false). /// \param[in] extensions Set of file extensions to be included in the dataset (default={}). +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current Dataset std::shared_ptr CelebA(const std::string &dataset_dir, const std::string &usage = "all", const std::shared_ptr &sampler = RandomSampler(), bool decode = false, - const std::set &extensions = {}); + const std::set &extensions = {}, + const std::shared_ptr &cache = nullptr); /// \brief Function to create a Cifar10 Dataset /// \notes The generated dataset has two columns ["image", "label"] @@ -158,9 +163,12 @@ std::shared_ptr CelebA(const std::string &dataset_dir, const std::st /// \param[in] usage of CIFAR10, can be "train", "test" or "all" (default = "all"). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. 
/// \return Shared pointer to the current Dataset std::shared_ptr Cifar10(const std::string &dataset_dir, const std::string &usage = "all", - const std::shared_ptr &sampler = RandomSampler()); + const std::shared_ptr &sampler = RandomSampler(), + const std::shared_ptr &cache = nullptr); /// \brief Function to create a Cifar100 Dataset /// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"] @@ -168,9 +176,12 @@ std::shared_ptr Cifar10(const std::string &dataset_dir, const std:: /// \param[in] usage of CIFAR100, can be "train", "test" or "all" (default = "all"). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current Dataset std::shared_ptr Cifar100(const std::string &dataset_dir, const std::string &usage = "all", - const std::shared_ptr &sampler = RandomSampler()); + const std::shared_ptr &sampler = RandomSampler(), + const std::shared_ptr &cache = nullptr); /// \brief Function to create a CLUENode /// \notes The generated dataset has a variable number of columns depending on the task and usage @@ -188,11 +199,13 @@ std::shared_ptr Cifar100(const std::string &dataset_dir, const std /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. (Default = 0) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current CLUENode std::shared_ptr CLUE(const std::vector &dataset_files, const std::string &task = "AFQMC", const std::string &usage = "train", int64_t num_samples = 0, - ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, - int32_t shard_id = 0); + ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0, + const std::shared_ptr &cache = nullptr); /// \brief Function to create a CocoNode /// \notes The generated dataset has multi-columns : @@ -209,10 +222,13 @@ std::shared_ptr CLUE(const std::vector &dataset_files, co /// \param[in] decode Decode the images after reading /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current Dataset std::shared_ptr Coco(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task = "Detection", const bool &decode = false, - const std::shared_ptr &sampler = RandomSampler()); + const std::shared_ptr &sampler = RandomSampler(), + const std::shared_ptr &cache = nullptr); /// \brief Function to create a CSVNode /// \notes The generated dataset has a variable number of columns @@ -233,11 +249,14 @@ std::shared_ptr Coco(const std::string &dataset_dir, const std::string /// \param[in] num_shards Number of shards that the dataset should be divided into. 
(Default = 1) /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. (Default = 0) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current Dataset std::shared_ptr CSV(const std::vector &dataset_files, char field_delim = ',', const std::vector> &column_defaults = {}, const std::vector &column_names = {}, int64_t num_samples = 0, - ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); + ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0, + const std::shared_ptr &cache = nullptr); /// \brief Function to create an ImageFolderNode /// \notes A source dataset that reads images from a tree of directories @@ -249,11 +268,14 @@ std::shared_ptr CSV(const std::vector &dataset_files, char /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \param[in] extensions File extensions to be read /// \param[in] class_indexing a class name to label map +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current ImageFolderNode std::shared_ptr ImageFolder(const std::string &dataset_dir, bool decode = false, const std::shared_ptr &sampler = RandomSampler(), const std::set &extensions = {}, - const std::map &class_indexing = {}); + const std::map &class_indexing = {}, + const std::shared_ptr &cache = nullptr); #ifndef ENABLE_ANDROID /// \brief Function to create a ManifestNode @@ -265,10 +287,13 @@ std::shared_ptr ImageFolder(const std::string &dataset_dir, boo /// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder /// names will be sorted alphabetically and each class will be given a unique index starting from 0). /// \param[in] decode Decode the images after reading (default=false). +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current ManifestNode std::shared_ptr Manifest(const std::string &dataset_file, const std::string &usage = "train", const std::shared_ptr &sampler = RandomSampler(), - const std::map &class_indexing = {}, bool decode = false); + const std::map &class_indexing = {}, bool decode = false, + const std::shared_ptr &cache = nullptr); #endif #ifndef ENABLE_ANDROID @@ -308,9 +333,12 @@ std::shared_ptr MindData(const std::vector &dataset_f /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all"). /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. 
/// \return Shared pointer to the current MnistNode std::shared_ptr Mnist(const std::string &dataset_dir, const std::string &usage = "all", - const std::shared_ptr &sampler = RandomSampler()); + const std::shared_ptr &sampler = RandomSampler(), + const std::shared_ptr &cache = nullptr); /// \brief Function to create a ConcatNode /// \notes Reload "+" operator to concat two datasets @@ -326,11 +354,14 @@ std::shared_ptr operator+(const std::shared_ptr &datasets1, /// \param[in] columns_list List of columns to be read (default={}, read all columns) /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current Dataset template > std::shared_ptr RandomData(const int32_t &total_rows = 0, const T &schema = nullptr, const std::vector &columns_list = {}, - const std::shared_ptr &sampler = RandomSampler()) { + const std::shared_ptr &sampler = RandomSampler(), + const std::shared_ptr &cache = nullptr) { if (total_rows < 0) { MS_LOG(ERROR) << "RandomNode: total_rows must be greater than or equal 0, now get " << total_rows; return nullptr; @@ -356,9 +387,11 @@ std::shared_ptr RandomData(const int32_t &total_rows = 0, const T &s std::shared_ptr ds; if constexpr (std::is_same::value || std::is_same>::value) { std::shared_ptr schema_obj = schema; - ds = std::make_shared(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler)); + ds = std::make_shared(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler), + cache); } else { - ds = std::make_shared(total_rows, std::move(schema), std::move(columns_list), std::move(sampler)); + ds = + std::make_shared(total_rows, std::move(schema), std::move(columns_list), std::move(sampler), cache); } return ds; } @@ -377,10 +410,12 @@ std::shared_ptr RandomData(const int32_t &total_rows = 0, const T &s /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] shard_id The shard ID within num_shards. This argument should be /// specified only when num_shards is also specified. (Default = 0) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current TextFileNode std::shared_ptr TextFile(const std::vector &dataset_files, int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, - int32_t shard_id = 0); + int32_t shard_id = 0, const std::shared_ptr &cache = nullptr); #ifndef ENABLE_ANDROID /// \brief Function to create a TFRecordNode @@ -404,12 +439,15 @@ std::shared_ptr TextFile(const std::vector &dataset_f /// when num_shards is also specified. (Default = 0) /// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of /// each shard may be not equal) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. 
/// \return Shared pointer to the current TFRecordNode template > std::shared_ptr TFRecord(const std::vector &dataset_files, const T &schema = nullptr, const std::vector &columns_list = {}, int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, - int32_t shard_id = 0, bool shard_equal_rows = false) { + int32_t shard_id = 0, bool shard_equal_rows = false, + const std::shared_ptr &cache = nullptr) { if (dataset_files.empty()) { MS_LOG(ERROR) << "TFRecordNode: dataset_files is not specified."; return nullptr; @@ -441,7 +479,7 @@ std::shared_ptr TFRecord(const std::vector &dataset_f if constexpr (std::is_same::value || std::is_same>::value) { std::shared_ptr schema_obj = schema; ds = std::make_shared(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards, - shard_id, shard_equal_rows); + shard_id, shard_equal_rows, cache); } else { std::string schema_path = schema; if (!schema_path.empty()) { @@ -452,7 +490,7 @@ std::shared_ptr TFRecord(const std::vector &dataset_f } } ds = std::make_shared(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards, - shard_id, shard_equal_rows); + shard_id, shard_equal_rows, cache); } return ds; } @@ -469,11 +507,28 @@ std::shared_ptr TFRecord(const std::vector &dataset_f /// \param[in] decode Decode the images after reading /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) +/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). +/// The cache feature is under development and is not recommended. /// \return Shared pointer to the current Dataset std::shared_ptr VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", const std::string &usage = "train", const std::map &class_indexing = {}, bool decode = false, - const std::shared_ptr &sampler = RandomSampler()); + const std::shared_ptr &sampler = RandomSampler(), + const std::shared_ptr &cache = nullptr); + +/// \brief Function the create a cache to be attached to a dataset +/// \param id A user assigned session id for the current pipeline +/// \param mem_sz Size of the memory set aside for the row caching. 0 for unlimited +/// \param spill Spill to disk if out of memory +/// \param hostname optional host name +/// \param port optional port +/// \param num_connections optional number of connections +/// \param prefetch_sz optional prefetch size +/// \return Shared pointer to DatasetCache. If error, nullptr is returned. +std::shared_ptr CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill, + std::optional hostname, std::optional port, + std::optional num_connections, + std::optional prefetch_sz); #endif /// \brief Function to create a ZipNode @@ -493,6 +548,10 @@ class Dataset : public std::enable_shared_from_this { /// \brief Constructor Dataset(); + /// \brief Constructor that initializes the cache + /// \param dataset_cache DatasetCache + explicit Dataset(const std::shared_ptr &dataset_cache); + /// \brief Destructor ~Dataset() = default; @@ -610,11 +669,14 @@ class Dataset : public std::enable_shared_from_this { /// last operation. The default output_columns will have the same /// name as the input columns, i.e., the columns will be replaced /// \param[in] project_columns A list of column names to project + /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). 
+ /// The cache feature is under development and is not recommended. /// \return Shared pointer to the current MapNode std::shared_ptr Map(std::vector> operations, std::vector input_columns = {}, std::vector output_columns = {}, - const std::vector &project_columns = {}); + const std::vector &project_columns = {}, + const std::shared_ptr &cache = nullptr); /// \brief Function to create a Project Dataset /// \notes Applies project to the dataset @@ -670,6 +732,9 @@ class Dataset : public std::enable_shared_from_this { int32_t rows_per_buffer_; int32_t connector_que_size_; int32_t worker_connector_size_; + + std::shared_ptr cache_; + Status AddCacheOp(std::vector> *node_ops); }; class SchemaObj { @@ -766,7 +831,7 @@ class CelebANode : public Dataset { public: /// \brief Constructor CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr &sampler, - const bool &decode, const std::set &extensions); + const bool &decode, const std::set &extensions, const std::shared_ptr &cache); /// \brief Destructor ~CelebANode() = default; @@ -792,7 +857,8 @@ class CelebANode : public Dataset { class Cifar10Node : public Dataset { public: /// \brief Constructor - Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler); + Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, + std::shared_ptr cache); /// \brief Destructor ~Cifar10Node() = default; @@ -814,7 +880,8 @@ class Cifar10Node : public Dataset { class Cifar100Node : public Dataset { public: /// \brief Constructor - Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler); + Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr sampler, + std::shared_ptr cache); /// \brief Destructor ~Cifar100Node() = default; @@ -839,7 +906,7 @@ class CLUENode : public Dataset { public: /// \brief Constructor CLUENode(const std::vector dataset_files, std::string task, std::string usage, int64_t num_samples, - ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr cache); /// \brief Destructor ~CLUENode() = default; @@ -870,7 +937,7 @@ class CocoNode : public Dataset { public: /// \brief Constructor CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, - const bool &decode, const std::shared_ptr &sampler); + const bool &decode, const std::shared_ptr &sampler, std::shared_ptr cache); /// \brief Destructor ~CocoNode() = default; @@ -918,7 +985,8 @@ class CSVNode : public Dataset { /// \brief Constructor CSVNode(const std::vector &dataset_files, char field_delim, const std::vector> &column_defaults, const std::vector &column_names, - int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); + int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, + std::shared_ptr cache); /// \brief Destructor ~CSVNode() = default; @@ -947,7 +1015,7 @@ class ManifestNode : public Dataset { public: /// \brief Constructor ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr &sampler, - const std::map &class_indexing, bool decode); + const std::map &class_indexing, bool decode, std::shared_ptr cache); /// \brief Destructor ~ManifestNode() = default; @@ -1016,7 +1084,8 @@ class MindDataNode : public Dataset { class MnistNode : public Dataset { public: /// \brief Constructor - 
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler); + MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr sampler, + std::shared_ptr cache); /// \brief Destructor ~MnistNode() = default; @@ -1044,8 +1113,9 @@ class RandomNode : public Dataset { /// \brief Constructor RandomNode(const int32_t &total_rows, std::shared_ptr schema, const std::vector &columns_list, - const std::shared_ptr &sampler) - : total_rows_(total_rows), + const std::shared_ptr &sampler, std::shared_ptr cache) + : Dataset(std::move(cache)), + total_rows_(total_rows), schema_path_(""), schema_(std::move(schema)), columns_list_(columns_list), @@ -1053,8 +1123,12 @@ class RandomNode : public Dataset { /// \brief Constructor RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector &columns_list, - const std::shared_ptr &sampler) - : total_rows_(total_rows), schema_path_(schema_path), columns_list_(columns_list), sampler_(std::move(sampler)) {} + const std::shared_ptr &sampler, std::shared_ptr cache) + : Dataset(std::move(cache)), + total_rows_(total_rows), + schema_path_(schema_path), + columns_list_(columns_list), + sampler_(std::move(sampler)) {} /// \brief Destructor ~RandomNode() = default; @@ -1088,7 +1162,7 @@ class TextFileNode : public Dataset { public: /// \brief Constructor TextFileNode(std::vector dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, - int32_t shard_id); + int32_t shard_id, std::shared_ptr cache); /// \brief Destructor ~TextFileNode() = default; @@ -1117,8 +1191,9 @@ class TFRecordNode : public Dataset { /// \note Parameter 'schema' is the path to the schema file TFRecordNode(const std::vector &dataset_files, std::string schema, const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, bool shard_equal_rows) - : dataset_files_(dataset_files), + int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr cache) + : Dataset(std::move(cache)), + dataset_files_(dataset_files), schema_path_(schema), columns_list_(columns_list), num_samples_(num_samples), @@ -1131,8 +1206,9 @@ class TFRecordNode : public Dataset { /// \note Parameter 'schema' is shared pointer to Schema object TFRecordNode(const std::vector &dataset_files, std::shared_ptr schema, const std::vector &columns_list, int64_t num_samples, ShuffleMode shuffle, - int32_t num_shards, int32_t shard_id, bool shard_equal_rows) - : dataset_files_(dataset_files), + int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr cache) + : Dataset(std::move(cache)), + dataset_files_(dataset_files), schema_obj_(schema), columns_list_(columns_list), num_samples_(num_samples), @@ -1169,7 +1245,8 @@ class VOCNode : public Dataset { public: /// \brief Constructor VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, - const std::map &class_indexing, bool decode, std::shared_ptr sampler); + const std::map &class_indexing, bool decode, std::shared_ptr sampler, + std::shared_ptr cache); /// \brief Destructor ~VOCNode() = default; @@ -1206,7 +1283,7 @@ class MapNode : public Dataset { /// \brief Constructor MapNode(std::shared_ptr child, std::vector> operations, std::vector input_columns = {}, std::vector output_columns = {}, - const std::vector &columns = {}); + const std::vector &columns = {}, std::shared_ptr cache = nullptr); /// \brief Destructor ~MapNode() = default;
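Reviewer note, not part of the patch: a minimal usage sketch of the C++ cache API this change introduces. It assumes a cache server is already running and that session id 1 is valid on that server; the dataset path, memory size, and other values below are placeholders, the optional host/port/connections/prefetch arguments are left unset, and error handling is reduced to null checks.

// Usage sketch only (assumptions noted above); exercises CreateDatasetCache()
// and the new `cache` parameter on a leaf dataset.
#include <memory>
#include <optional>

#include "minddata/dataset/include/datasets.h"

namespace api = mindspore::dataset::api;

int SketchCachePipeline() {
  // Build the cache IR object (a DatasetCacheImpl behind the DatasetCache interface).
  // mem_sz = 0 means no memory cap, spill = false means do not spill to disk.
  std::shared_ptr<api::DatasetCache> cache = api::CreateDatasetCache(
      1 /* session id, assumed */, 0 /* mem_sz */, false /* spill */, std::nullopt, std::nullopt, std::nullopt,
      std::nullopt);
  if (cache == nullptr) return -1;  // ValidateParams() failed

  // Hand the cache to a leaf dataset; when the tree is built, Dataset::AddCacheOp()
  // asks the cache to create a CacheOp and adds it to this node's operator list.
  std::shared_ptr<api::Dataset> ds =
      api::ImageFolder("/path/to/images" /* placeholder */, true, api::RandomSampler(), {}, {}, cache);
  if (ds == nullptr) return -1;

  // The same cache object could instead be passed to Map() through its new last parameter.
  std::shared_ptr<api::Iterator> iter = ds->CreateIterator();
  return iter == nullptr ? -1 : 0;
}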