Merge pull request !7609 from h.farahat/cache_c++_api
@@ -80,6 +80,8 @@ add_dependencies(text-kernels core)
 add_dependencies(cpp-API core)
 add_dependencies(engine-ir-datasetops core)
 add_dependencies(engine-ir-datasetops-source core)
+add_dependencies(engine-ir-cache core)
 if (ENABLE_PYTHON)
     add_dependencies(APItoPython core)
 endif()
@@ -102,8 +104,9 @@ set(submodules
         $<TARGET_OBJECTS:kernels-data>
         $<TARGET_OBJECTS:cpp-API>
         $<TARGET_OBJECTS:engine-ir-datasetops>
         $<TARGET_OBJECTS:engine-ir-datasetops-source>
+        $<TARGET_OBJECTS:engine-ir-cache>
         $<TARGET_OBJECTS:kernels-soft-dvpp-image>
         $<TARGET_OBJECTS:soft-dvpp-utils>
         $<TARGET_OBJECTS:engine-datasetops-source>
         $<TARGET_OBJECTS:engine-datasetops-source-sampler>
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <fstream>
 #include <unordered_set>
+#include <utility>
 #include "minddata/dataset/include/samplers.h"
 #include "minddata/dataset/include/transforms.h"
 // Source dataset headers (in alphabetical order)
@@ -32,6 +33,7 @@
 #ifndef ENABLE_ANDROID
 #include "minddata/dataset/engine/datasetops/source/manifest_op.h"
 #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
+#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
 #endif
 #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
 #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
@@ -113,6 +115,9 @@ Dataset::Dataset() {
   worker_connector_size_ = cfg->worker_connector_size();
 }
 
+// Constructor to initialize the cache
+Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }
+
 /// \brief Function to create a SchemaObj
 /// \param[in] schema_file Path of schema file
 /// \return Shared pointer to the current schema
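
Every leaf IR node in this change forwards its optional cache argument to this new base-class
constructor rather than storing the pointer itself. A minimal sketch of the pattern, with a
hypothetical FooNode standing in for the real nodes that follow below:

  // FooNode is illustrative only, not part of this diff. It hands the cache to the
  // Dataset base, which stores it in cache_ for AddCacheOp() to use at Build() time.
  class FooNode : public Dataset {
   public:
    FooNode(std::string dataset_dir, std::shared_ptr<DatasetCache> cache)
        : Dataset(std::move(cache)), dataset_dir_(std::move(dataset_dir)) {}

   private:
    std::string dataset_dir_;
  };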
@@ -137,8 +142,9 @@ std::shared_ptr<AlbumNode> Album(const std::string &dataset_dir, const std::stri
 // Function to create a CelebANode.
 std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage,
                                    const std::shared_ptr<SamplerObj> &sampler, bool decode,
-                                   const std::set<std::string> &extensions) {
-  auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions);
+                                   const std::set<std::string> &extensions,
+                                   const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -146,8 +152,9 @@ std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::st
 // Function to create a Cifar10Node.
 std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage,
-                                     const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler);
+                                     const std::shared_ptr<SamplerObj> &sampler,
+                                     const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -155,8 +162,9 @@ std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::
 // Function to create a Cifar100Node.
 std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage,
-                                       const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler);
+                                       const std::shared_ptr<SamplerObj> &sampler,
+                                       const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -165,8 +173,8 @@ std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std
 // Function to create a CLUENode.
 std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
                                const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
-                               int32_t shard_id) {
-  auto ds = std::make_shared<CLUENode>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id);
+                               int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CLUENode>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -174,9 +182,9 @@ std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &clue_files, const
 // Function to create a CocoNode.
 std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file,
-                               const std::string &task, const bool &decode,
-                               const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler);
+                               const std::string &task, const bool &decode, const std::shared_ptr<SamplerObj> &sampler,
+                               const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -186,9 +194,9 @@ std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string
 std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim,
                              const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
                              const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
-                             int32_t num_shards, int32_t shard_id) {
+                             int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
   auto ds = std::make_shared<CSVNode>(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle,
-                                      num_shards, shard_id);
+                                      num_shards, shard_id, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -198,12 +206,14 @@ std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char
 std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode,
                                              const std::shared_ptr<SamplerObj> &sampler,
                                              const std::set<std::string> &extensions,
-                                             const std::map<std::string, int32_t> &class_indexing) {
+                                             const std::map<std::string, int32_t> &class_indexing,
+                                             const std::shared_ptr<DatasetCache> &cache) {
   // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
   bool recursive = false;
 
   // Create logical representation of ImageFolderNode.
-  auto ds = std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing);
+  auto ds =
+    std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -213,8 +223,9 @@ std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, boo
 // Function to create a ManifestNode.
 std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage,
                                        const std::shared_ptr<SamplerObj> &sampler,
-                                       const std::map<std::string, int32_t> &class_indexing, bool decode) {
-  auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode);
+                                       const std::map<std::string, int32_t> &class_indexing, bool decode,
+                                       const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -244,8 +255,9 @@ std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_f
 // Function to create a MnistNode.
 std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage,
-                                 const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler);
+                                 const std::shared_ptr<SamplerObj> &sampler,
+                                 const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -262,8 +274,9 @@ std::shared_ptr<ConcatNode> operator+(const std::shared_ptr<Dataset> &datasets1,
 // Function to create a TextFileNode.
 std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples,
-                                       ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
-  auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id);
+                                       ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
+                                       const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -273,8 +286,8 @@ std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_f
 // Function to create a VOCNode.
 std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
                              const std::map<std::string, int32_t> &class_indexing, bool decode,
-                             const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler);
+                             const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);
 
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -365,8 +378,10 @@ std::shared_ptr<ConcatNode> Dataset::Concat(const std::vector<std::shared_ptr<Da
 // Function to create a Map dataset.
 std::shared_ptr<MapNode> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
                                       std::vector<std::string> input_columns, std::vector<std::string> output_columns,
-                                      const std::vector<std::string> &project_columns) {
-  auto ds = std::make_shared<MapNode>(shared_from_this(), operations, input_columns, output_columns, project_columns);
+                                      const std::vector<std::string> &project_columns,
+                                      const std::shared_ptr<DatasetCache> &cache) {
+  auto ds =
+    std::make_shared<MapNode>(shared_from_this(), operations, input_columns, output_columns, project_columns, cache);
 
   if (!ds->ValidateParams()) {
     return nullptr;
@@ -464,6 +479,14 @@ std::shared_ptr<ZipNode> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>
   return ds->ValidateParams() ? ds : nullptr;
 }
 
+Status Dataset::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
+  if (cache_ != nullptr) {
+    std::shared_ptr<DatasetOp> cache_op;
+    RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
+    node_ops->push_back(cache_op);
+  }
+  return Status::OK();
+}
 
 SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
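
AddCacheOp() is the single hook every Build() below relies on: when a cache was attached, it asks the
cache to materialize a CacheOp sized for num_workers_ and pushes it into node_ops just before the leaf
op is pushed, so the CacheOp ends up above the reader once the execution tree is assembled; with no
cache attached it is a no-op. A sketch of the expected call site (BarNode and bar_reader_op_ are
hypothetical; the real call sites appear throughout this diff):

  std::vector<std::shared_ptr<DatasetOp>> BarNode::Build() {
    std::vector<std::shared_ptr<DatasetOp>> node_ops;
    // The CacheOp, if any, is pushed first so that it sits above the reader op.
    RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
    node_ops.push_back(bar_reader_op_);
    return node_ops;
  }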
@@ -831,8 +854,13 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
 // Constructor for CelebANode
 CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
                        const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
-                       const std::set<std::string> &extensions)
-    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {}
+                       const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache)
+    : Dataset(cache),
+      dataset_dir_(dataset_dir),
+      usage_(usage),
+      sampler_(sampler),
+      decode_(decode),
+      extensions_(extensions) {}
 
 Status CelebANode::ValidateParams() {
   RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));
@@ -855,15 +883,18 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
   // label is like this:0 1 0 0 1......
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
 
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
                                                 decode_, usage_, extensions_, std::move(schema),
                                                 std::move(sampler_->Build())));
 
   return node_ops;
 }
 // Constructor for Cifar10Node
-Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
+Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
+                         std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
 
 Status Cifar10Node::ValidateParams() {
   RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));
@@ -887,16 +918,19 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
 
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
                                                dataset_dir_, connector_que_size_, std::move(schema),
                                                std::move(sampler_->Build())));
 
   return node_ops;
 }
 // Constructor for Cifar100Node
 Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage,
-                           std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
+                           std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
 
 Status Cifar100Node::ValidateParams() {
   RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));
@@ -922,16 +956,20 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
 
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
                                                dataset_dir_, connector_que_size_, std::move(schema),
                                                std::move(sampler_->Build())));
 
   return node_ops;
 }
 // Constructor for CLUENode
 CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples,
-                   ShuffleMode shuffle, int32_t num_shards, int32_t shard_id)
-    : dataset_files_(clue_files),
+                   ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)),
+      dataset_files_(clue_files),
       task_(task),
       usage_(usage),
       num_samples_(num_samples),
@@ -973,6 +1011,7 @@ std::vector<std::string> CLUENode::split(const std::string &s, char delim) {
 std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
   // A vector containing shared pointer to the Dataset Ops that this object will create
   std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
   std::map<std::string, std::string> key_map;
   if (task_ == "AFQMC") {
     if (usage_ == "train") {
@@ -1102,15 +1141,22 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }
+
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(clue_op);
 
   return node_ops;
 }
 // Constructor for CocoNode
 CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
-                   const bool &decode, const std::shared_ptr<SamplerObj> &sampler)
-    : dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
+                   const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)),
+      dataset_dir_(dataset_dir),
+      annotation_file_(annotation_file),
+      task_(task),
+      decode_(decode),
+      sampler_(sampler) {}
 
 Status CocoNode::ValidateParams() {
   RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_));
@@ -1186,7 +1232,10 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
   std::shared_ptr<CocoOp> op =
     std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
                              connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
+
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(op);
 
   return node_ops;
 }
@@ -1194,8 +1243,9 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
 CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
                  const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
                  const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
-                 int32_t num_shards, int32_t shard_id)
-    : dataset_files_(csv_files),
+                 int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)),
+      dataset_files_(csv_files),
       field_delim_(field_delim),
       column_defaults_(column_defaults),
       column_names_(column_names),
@@ -1274,17 +1324,26 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
     // Add the shuffle op after this op
     RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }
+
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(csv_op);
 
   return node_ops;
 }
 #ifndef ENABLE_ANDROID
 ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &usage,
                            const std::shared_ptr<SamplerObj> &sampler,
-                           const std::map<std::string, int32_t> &class_indexing, bool decode)
-    : dataset_file_(dataset_file), usage_(usage), decode_(decode), class_index_(class_indexing), sampler_(sampler) {}
+                           const std::map<std::string, int32_t> &class_indexing, bool decode,
+                           std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)),
+      dataset_file_(dataset_file),
+      usage_(usage),
+      decode_(decode),
+      class_index_(class_indexing),
+      sampler_(sampler) {}
 
 Status ManifestNode::ValidateParams() {
   std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
@@ -1326,8 +1385,10 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
   manifest_op =
     std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
                                  class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
+
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
 
   node_ops.push_back(manifest_op);
 
   return node_ops;
 }
 #endif
@@ -1466,8 +1527,9 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
 }
 #endif
 
-MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
+MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
+                     std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
 
 Status MnistNode::ValidateParams() {
   RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));
@@ -1489,9 +1551,11 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
   TensorShape scalar = TensorShape::CreateScalar();
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
 
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
                                                connector_que_size_, std::move(schema), std::move(sampler_->Build())));
 
   return node_ops;
 }
@@ -1560,14 +1624,18 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
   std::shared_ptr<RandomDataOp> op;
   op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
                                       std::move(data_schema), std::move(sampler_->Build()));
+
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(op);
 
   return node_ops;
 }
 // Constructor for TextFileNode
 TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
-                           int32_t num_shards, int32_t shard_id)
-    : dataset_files_(dataset_files),
+                           int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)),
+      dataset_files_(dataset_files),
       num_samples_(num_samples),
       shuffle_(shuffle),
       num_shards_(num_shards),
@@ -1622,9 +1690,11 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }
+
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
 
   // Add TextFileOp
   node_ops.push_back(text_file_op);
 
   return node_ops;
 }
@@ -1673,6 +1743,7 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
 
   // Add TFReaderOp
   node_ops.push_back(tf_reader_op);
@@ -1681,8 +1752,10 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
 // Constructor for VOCNode
 VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
-                 const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir),
+                 const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
+                 std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)),
+      dataset_dir_(dataset_dir),
       task_(task),
       usage_(usage),
       class_index_(class_indexing),
@@ -1755,9 +1828,18 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
   std::shared_ptr<VOCOp> voc_op;
   voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
                                    connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
   node_ops.push_back(voc_op);
 
   return node_ops;
 }
 
+std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
+                                                 std::optional<std::string> hostname, std::optional<int32_t> port,
+                                                 std::optional<int32_t> num_connections,
+                                                 std::optional<int32_t> prefetch_sz) {
+  auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
+  return cache->ValidateParams() ? cache : nullptr;
+}
+
 #endif
 
 #ifndef ENABLE_ANDROID
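
Taken together with the cache parameters added to the dataset factory functions, the intended
end-to-end usage is roughly the following (a hedged sketch: my_session_id and the dataset path are
placeholders, and all server settings are left at their defaults via std::nullopt):

  // Create the cache handle; CreateDatasetCache returns nullptr if validation fails.
  std::shared_ptr<DatasetCache> some_cache =
    CreateDatasetCache(my_session_id, 0 /* mem_sz: unlimited */, true /* spill to disk */,
                       std::nullopt, std::nullopt, std::nullopt, std::nullopt);

  // Attach it to a leaf dataset; rows produced by the CIFAR-10 reader are then cached.
  std::shared_ptr<Dataset> ds = Cifar10("/path/to/cifar10", "all", RandomSampler(), some_cache);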
@@ -1766,11 +1848,12 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
 MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
                  std::vector<std::string> input_columns, std::vector<std::string> output_columns,
-                 const std::vector<std::string> &project_columns)
-    : operations_(operations),
+                 const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache)
+    : Dataset(std::move(cache)),
+      operations_(operations),
       input_columns_(input_columns),
       output_columns_(output_columns),
       project_columns_(project_columns) {
   this->children.push_back(child);
 }
@@ -1793,6 +1876,7 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
     auto project_op = std::make_shared<ProjectOp>(project_columns_);
     node_ops.push_back(project_op);
   }
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
 
   node_ops.push_back(map_op);
 
   return node_ops;
@@ -1,3 +1,4 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_subdirectory(datasetops)
+add_subdirectory(cache)
@@ -0,0 +1,4 @@
+file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
+set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
+add_library(engine-ir-cache OBJECT
+    dataset_cache_impl.cc)
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
+
+#include <memory>
+
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+
+namespace mindspore::dataset::api {
+
+class DatasetCache {
+ public:
+  virtual Status Build() = 0;
+  virtual Status ValidateParams() = 0;
+  virtual Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
+};
+
+}  // namespace mindspore::dataset::api
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
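
The interface splits cache setup into three phases: ValidateParams() at IR-construction time, Build()
to create the underlying client, and CreateCacheOp() to materialize the tree op. A sketch of a minimal
alternative implementation (NoopDatasetCache is hypothetical, e.g. for tests; DatasetCacheImpl below is
the real one):

  class NoopDatasetCache : public DatasetCache {
   public:
    Status Build() override { return Status::OK(); }           // no client to construct
    Status ValidateParams() override { return Status::OK(); }  // nothing to validate
    Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds_op) override {
      RETURN_STATUS_UNEXPECTED("NoopDatasetCache does not create a CacheOp.");
    }
  };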
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+
+#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+
+namespace mindspore::dataset::api {
+
+/// Method to initialize the DatasetCache by creating an instance of a CacheClient
+/// \return Status Error code
+Status DatasetCacheImpl::Build() {
+  CacheClient::Builder builder;
+  builder.SetSessionId(session_id_).SetCacheMemSz(cache_mem_sz_).SetSpill(spill_);
+  if (hostname_) builder.SetHostname(hostname_.value());
+  if (port_) builder.SetPort(port_.value());
+  if (num_connections_) builder.SetNumConnections(num_connections_.value());
+  if (prefetch_sz_) builder.SetPrefetchSize(prefetch_sz_.value());
+  return builder.Build(&cache_client_);
+}
+
+Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
+  CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
+  std::shared_ptr<CacheOp> cache_op = nullptr;
+  RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
+
+  *ds = cache_op;
+  return Status::OK();
+}
+}  // namespace mindspore::dataset::api
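
Note the ordering contract encoded in the CHECK_FAIL_RETURN_UNEXPECTED above: CreateCacheOp() fails
unless Build() has populated cache_client_ first, so whoever consumes a DatasetCacheImpl (the caller is
outside this diff) must run Build() before requesting the op. A hedged sketch of the expected sequence
(session_id and num_workers are placeholders):

  DatasetCacheImpl cache(session_id, 0, /*spill=*/true, std::nullopt, std::nullopt, std::nullopt,
                         std::nullopt);
  RETURN_IF_NOT_OK(cache.Build());  // constructs the CacheClient through its Builder
  std::shared_ptr<DatasetOp> cache_op;
  RETURN_IF_NOT_OK(cache.CreateCacheOp(num_workers, &cache_op));  // wraps the client in a CacheOp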
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
+
+#include <memory>
+#include <string>
+#include <optional>
+#include <utility>
+
+#include "minddata/dataset/engine/cache/cache_client.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
+
+namespace mindspore::dataset::api {
+
+/// DatasetCacheImpl is the IR of CacheClient
+class DatasetCacheImpl : public DatasetCache {
+ public:
+  ///
+  /// \brief Constructor
+  /// \param id A user assigned session id for the current pipeline
+  /// \param mem_sz Size of the memory set aside for the row caching. 0 for unlimited
+  /// \param spill Spill to disk if out of memory
+  /// \param hostname Optional host name
+  /// \param port Optional port
+  /// \param num_connections Optional number of connections
+  /// \param prefetch_sz Optional prefetch size
+  DatasetCacheImpl(session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
+                   std::optional<int32_t> port, std::optional<int32_t> num_connections,
+                   std::optional<int32_t> prefetch_sz)
+      : session_id_(id),
+        cache_mem_sz_(mem_sz),
+        spill_(spill),
+        hostname_(std::move(hostname)),
+        port_(std::move(port)),
+        num_connections_(std::move(num_connections)),
+        prefetch_sz_(std::move(prefetch_sz)) {}
+
+  /// Method to initialize the DatasetCache by creating an instance of a CacheClient
+  /// \return Status Error code
+  Status Build() override;
+
+  Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
+
+  Status ValidateParams() override { return Status::OK(); }
+
+ private:
+  std::shared_ptr<CacheClient> cache_client_;
+  session_id_type session_id_;
+  uint64_t cache_mem_sz_;
+  bool spill_;
+  std::optional<std::string> hostname_;
+  std::optional<int32_t> port_;
+  std::optional<int32_t> num_connections_;
+  std::optional<int32_t> prefetch_sz_;
+};
+
+}  // namespace mindspore::dataset::api
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
@@ -32,13 +32,15 @@ namespace api {
 ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
                                  bool recursive, std::set<std::string> extensions,
-                                 std::map<std::string, int32_t> class_indexing)
-    : dataset_dir_(dataset_dir),
+                                 std::map<std::string, int32_t> class_indexing,
+                                 std::shared_ptr<DatasetCache> cache = nullptr)
+    : Dataset(std::move(cache)),
+      dataset_dir_(dataset_dir),
       decode_(decode),
       sampler_(sampler),
       recursive_(recursive),
       class_indexing_(class_indexing),
       exts_(extensions) {}
 
 Status ImageFolderNode::ValidateParams() {
   RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));
@@ -60,6 +62,9 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
     schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
+
+  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+
   node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
                                                      recursive_, decode_, exts_, class_indexing_, std::move(schema),
                                                      std::move(sampler_->Build())));
@@ -23,6 +23,7 @@
 #include <string>
 #include <vector>
 
+#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
 #include "minddata/dataset/include/datasets.h"
 
 namespace mindspore {
@@ -36,7 +37,8 @@ class ImageFolderNode : public Dataset {
  public:
   /// \brief Constructor
   ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
-                  std::set<std::string> extensions, std::map<std::string, int32_t> class_indexing);
+                  std::set<std::string> extensions, std::map<std::string, int32_t> class_indexing,
+                  std::shared_ptr<DatasetCache> cache);
 
   /// \brief Destructor
   ~ImageFolderNode() = default;
@@ -25,7 +25,9 @@
 #include <unordered_set>
 #include <utility>
 #include <vector>
+
+#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/engine/data_schema.h"
 #include "minddata/dataset/include/iterator.h"
 #include "minddata/dataset/include/samplers.h"
@@ -147,10 +149,13 @@ std::shared_ptr<AlbumNode> Album(const std::string &dataset_dir, const std::stri
 ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
 /// \param[in] decode Decode the images after reading (default=false).
 /// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
+/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
+///     The cache feature is under development and is not recommended.
 /// \return Shared pointer to the current Dataset
 std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage = "all",
                                    const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false,
-                                   const std::set<std::string> &extensions = {});
+                                   const std::set<std::string> &extensions = {},
+                                   const std::shared_ptr<DatasetCache> &cache = nullptr);
 
 /// \brief Function to create a Cifar10 Dataset
 /// \notes The generated dataset has two columns ["image", "label"]
@@ -158,9 +163,12 @@ std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::st
 /// \param[in] usage of CIFAR10, can be "train", "test" or "all" (default = "all").
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
 ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
+/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
+///     The cache feature is under development and is not recommended.
 /// \return Shared pointer to the current Dataset
 std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage = "all",
-                                     const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
+                                     const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
+                                     const std::shared_ptr<DatasetCache> &cache = nullptr);
 
 /// \brief Function to create a Cifar100 Dataset
 /// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"]
@@ -168,9 +176,12 @@ std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::
 /// \param[in] usage of CIFAR100, can be "train", "test" or "all" (default = "all").
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
 ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
+/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
+///     The cache feature is under development and is not recommended.
 /// \return Shared pointer to the current Dataset
 std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage = "all",
-                                       const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
+                                       const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
+                                       const std::shared_ptr<DatasetCache> &cache = nullptr);
 
 /// \brief Function to create a CLUENode
 /// \notes The generated dataset has a variable number of columns depending on the task and usage
| @@ -188,11 +199,13 @@ std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std | |||||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) | /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) | ||||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | /// \param[in] shard_id The shard ID within num_shards. This argument should be | ||||
| /// specified only when num_shards is also specified. (Default = 0) | /// specified only when num_shards is also specified. (Default = 0) | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current CLUENode | /// \return Shared pointer to the current CLUENode | ||||
| std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC", | std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC", | ||||
| const std::string &usage = "train", int64_t num_samples = 0, | const std::string &usage = "train", int64_t num_samples = 0, | ||||
| ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, | |||||
| int32_t shard_id = 0); | |||||
| ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0, | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
| /// \brief Function to create a CocoNode | /// \brief Function to create a CocoNode | ||||
| /// \notes The generated dataset has multiple columns: | /// \notes The generated dataset has multiple columns: | ||||
| @@ -209,10 +222,13 @@ std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &dataset_files, co | |||||
| /// \param[in] decode Decode the images after reading | /// \param[in] decode Decode the images after reading | ||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current Dataset | /// \return Shared pointer to the current Dataset | ||||
| std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file, | std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file, | ||||
| const std::string &task = "Detection", const bool &decode = false, | const std::string &task = "Detection", const bool &decode = false, | ||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | |||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
| /// \brief Function to create a CSVNode | /// \brief Function to create a CSVNode | ||||
| /// \notes The generated dataset has a variable number of columns | /// \notes The generated dataset has a variable number of columns | ||||
| @@ -233,11 +249,14 @@ std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string | |||||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) | /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) | ||||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | /// \param[in] shard_id The shard ID within num_shards. This argument should be | ||||
| /// specified only when num_shards is also specified. (Default = 0) | /// specified only when num_shards is also specified. (Default = 0) | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current Dataset | /// \return Shared pointer to the current Dataset | ||||
| std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',', | std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',', | ||||
| const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {}, | const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {}, | ||||
| const std::vector<std::string> &column_names = {}, int64_t num_samples = 0, | const std::vector<std::string> &column_names = {}, int64_t num_samples = 0, | ||||
| ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0); | |||||
| ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0, | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
| /// \brief Function to create an ImageFolderNode | /// \brief Function to create an ImageFolderNode | ||||
| /// \notes A source dataset that reads images from a tree of directories | /// \notes A source dataset that reads images from a tree of directories | ||||
| @@ -249,11 +268,14 @@ std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char | |||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \param[in] extensions File extensions to be read | /// \param[in] extensions File extensions to be read | ||||
| /// \param[in] class_indexing A map from class name to label | /// \param[in] class_indexing A map from class name to label | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current ImageFolderNode | /// \return Shared pointer to the current ImageFolderNode | ||||
| std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode = false, | std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode = false, | ||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), | const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), | ||||
| const std::set<std::string> &extensions = {}, | const std::set<std::string> &extensions = {}, | ||||
| const std::map<std::string, int32_t> &class_indexing = {}); | |||||
| const std::map<std::string, int32_t> &class_indexing = {}, | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Function to create a ManifestNode | /// \brief Function to create a ManifestNode | ||||
| @@ -265,10 +287,13 @@ std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, boo | |||||
| /// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder | /// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder | ||||
| /// names will be sorted alphabetically and each class will be given a unique index starting from 0). | /// names will be sorted alphabetically and each class will be given a unique index starting from 0). | ||||
| /// \param[in] decode Decode the images after reading (default=false). | /// \param[in] decode Decode the images after reading (default=false). | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current ManifestNode | /// \return Shared pointer to the current ManifestNode | ||||
| std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage = "train", | std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage = "train", | ||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), | const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), | ||||
| const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false); | |||||
| const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false, | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
| #endif | #endif | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -308,9 +333,12 @@ std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_f | |||||
| /// \param[in] usage Usage of MNIST, can be "train", "test" or "all" (default = "all"). | /// \param[in] usage Usage of MNIST, can be "train", "test" or "all" (default = "all"). | ||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current MnistNode | /// \return Shared pointer to the current MnistNode | ||||
| std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage = "all", | std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage = "all", | ||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | |||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
| /// \brief Function to create a ConcatNode | /// \brief Function to create a ConcatNode | ||||
| /// \notes Reload "+" operator to concat two datasets | /// \notes Reload "+" operator to concat two datasets | ||||
| @@ -326,11 +354,14 @@ std::shared_ptr<ConcatNode> operator+(const std::shared_ptr<Dataset> &datasets1, | |||||
| /// \param[in] columns_list List of columns to be read (default={}, read all columns) | /// \param[in] columns_list List of columns to be read (default={}, read all columns) | ||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current Dataset | /// \return Shared pointer to the current Dataset | ||||
| template <typename T = std::shared_ptr<SchemaObj>> | template <typename T = std::shared_ptr<SchemaObj>> | ||||
| std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr, | std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr, | ||||
| const std::vector<std::string> &columns_list = {}, | const std::vector<std::string> &columns_list = {}, | ||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler()) { | |||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||||
| if (total_rows < 0) { | if (total_rows < 0) { | ||||
| MS_LOG(ERROR) << "RandomNode: total_rows must be greater than or equal 0, now get " << total_rows; | MS_LOG(ERROR) << "RandomNode: total_rows must be greater than or equal 0, now get " << total_rows; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -356,9 +387,11 @@ std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &s | |||||
| std::shared_ptr<RandomNode> ds; | std::shared_ptr<RandomNode> ds; | ||||
| if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) { | if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) { | ||||
| std::shared_ptr<SchemaObj> schema_obj = schema; | std::shared_ptr<SchemaObj> schema_obj = schema; | ||||
| ds = std::make_shared<RandomNode>(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler)); | |||||
| ds = std::make_shared<RandomNode>(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler), | |||||
| cache); | |||||
| } else { | } else { | ||||
| ds = std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler)); | |||||
| ds = | |||||
| std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler), cache); | |||||
| } | } | ||||
| return ds; | return ds; | ||||
| } | } | ||||
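A sketch of the two call forms the `if constexpr` dispatch above accepts; the schema path is a placeholder:

```cpp
// T = std::shared_ptr<SchemaObj>: first branch, schema passed as an object.
std::shared_ptr<SchemaObj> schema_obj = Schema("/path/to/schema.json");
auto ds1 = RandomData(50, schema_obj);

// T = std::string: else branch, schema passed as a file path.
auto ds2 = RandomData(50, std::string("/path/to/schema.json"));
```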
| @@ -377,10 +410,12 @@ std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &s | |||||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) | /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) | ||||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | /// \param[in] shard_id The shard ID within num_shards. This argument should be | ||||
| /// specified only when num_shards is also specified. (Default = 0) | /// specified only when num_shards is also specified. (Default = 0) | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current TextFileNode | /// \return Shared pointer to the current TextFileNode | ||||
| std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0, | std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0, | ||||
| ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, | ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, | ||||
| int32_t shard_id = 0); | |||||
| int32_t shard_id = 0, const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Function to create a TFRecordNode | /// \brief Function to create a TFRecordNode | ||||
| @@ -404,12 +439,15 @@ std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_f | |||||
| /// when num_shards is also specified. (Default = 0) | /// when num_shards is also specified. (Default = 0) | ||||
| /// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of | /// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of | ||||
| /// each shard may not be equal) | /// each shard may not be equal) | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current TFRecordNode | /// \return Shared pointer to the current TFRecordNode | ||||
| template <typename T = std::shared_ptr<SchemaObj>> | template <typename T = std::shared_ptr<SchemaObj>> | ||||
| std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr, | std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr, | ||||
| const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0, | const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0, | ||||
| ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, | ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, | ||||
| int32_t shard_id = 0, bool shard_equal_rows = false) { | |||||
| int32_t shard_id = 0, bool shard_equal_rows = false, | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||||
| if (dataset_files.empty()) { | if (dataset_files.empty()) { | ||||
| MS_LOG(ERROR) << "TFRecordNode: dataset_files is not specified."; | MS_LOG(ERROR) << "TFRecordNode: dataset_files is not specified."; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -441,7 +479,7 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f | |||||
| if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) { | if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) { | ||||
| std::shared_ptr<SchemaObj> schema_obj = schema; | std::shared_ptr<SchemaObj> schema_obj = schema; | ||||
| ds = std::make_shared<TFRecordNode>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards, | ds = std::make_shared<TFRecordNode>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards, | ||||
| shard_id, shard_equal_rows); | |||||
| shard_id, shard_equal_rows, cache); | |||||
| } else { | } else { | ||||
| std::string schema_path = schema; | std::string schema_path = schema; | ||||
| if (!schema_path.empty()) { | if (!schema_path.empty()) { | ||||
| @@ -452,7 +490,7 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f | |||||
| } | } | ||||
| } | } | ||||
| ds = std::make_shared<TFRecordNode>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards, | ds = std::make_shared<TFRecordNode>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards, | ||||
| shard_id, shard_equal_rows); | |||||
| shard_id, shard_equal_rows, cache); | |||||
| } | } | ||||
| return ds; | return ds; | ||||
| } | } | ||||
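The same compile-time dispatch applies here; a sketch of both schema forms, reusing the hypothetical `some_cache` from the earlier sketch (file names are placeholders):

```cpp
// Schema as a file path (T = std::string); cache defaults to nullptr (no caching).
auto ds1 = TFRecord({"/path/to/data.tfrecord"}, std::string("/path/to/schema.json"),
                    {"image", "label"});

// Schema as a SchemaObj (T = std::shared_ptr<SchemaObj>), with a cache attached.
auto ds2 = TFRecord({"/path/to/data.tfrecord"}, Schema("/path/to/schema.json"), {}, 0,
                    ShuffleMode::kGlobal, 1, 0, false, some_cache);
```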
| @@ -469,11 +507,28 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f | |||||
| /// \param[in] decode Decode the images after reading | /// \param[in] decode Decode the images after reading | ||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current Dataset | /// \return Shared pointer to the current Dataset | ||||
| std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", | std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", | ||||
| const std::string &usage = "train", | const std::string &usage = "train", | ||||
| const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false, | const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false, | ||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | |||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
| /// \brief Function to create a cache to be attached to a dataset | |||||
| /// \param[in] id A user assigned session id for the current pipeline | |||||
| /// \param[in] mem_sz Size of the memory set aside for row caching. 0 for unlimited | |||||
| /// \param[in] spill Spill to disk if out of memory | |||||
| /// \param[in] hostname Optional host name | |||||
| /// \param[in] port Optional port | |||||
| /// \param[in] num_connections Optional number of connections | |||||
| /// \param[in] prefetch_sz Optional prefetch size | |||||
| /// \return Shared pointer to DatasetCache. On error, nullptr is returned. | |||||
| std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill, | |||||
| std::optional<std::string> hostname, std::optional<int32_t> port, | |||||
| std::optional<int32_t> num_connections, | |||||
| std::optional<int32_t> prefetch_sz); | |||||
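A sketch of `CreateDatasetCache` with the optional settings supplied explicitly; the address, port, and tuning values are placeholders rather than recommended defaults, and `nullptr` signals failure:

```cpp
std::shared_ptr<DatasetCache> tuned_cache =
    CreateDatasetCache(1 /*id*/, 4ULL * 1024 * 1024 * 1024 /*mem_sz: 4 GB*/,
                       false /*spill*/, std::string("127.0.0.1") /*hostname*/,
                       50052 /*port*/, 12 /*num_connections*/, 20 /*prefetch_sz*/);
if (tuned_cache == nullptr) {
  MS_LOG(ERROR) << "Failed to create a DatasetCache";
}
```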
| #endif | #endif | ||||
| /// \brief Function to create a ZipNode | /// \brief Function to create a ZipNode | ||||
| @@ -493,6 +548,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Dataset(); | Dataset(); | ||||
| /// \brief Constructor that initializes the cache | |||||
| /// \param[in] dataset_cache The DatasetCache to attach to this dataset | |||||
| explicit Dataset(const std::shared_ptr<DatasetCache> &dataset_cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Dataset() = default; | ~Dataset() = default; | ||||
| @@ -610,11 +669,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| /// last operation. The default output_columns will have the same | /// last operation. The default output_columns will have the same | ||||
| /// name as the input columns, i.e., the columns will be replaced | /// name as the input columns, i.e., the columns will be replaced | ||||
| /// \param[in] project_columns A list of column names to project | /// \param[in] project_columns A list of column names to project | ||||
| /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used). | |||||
| /// The cache feature is under development and is not recommended. | |||||
| /// \return Shared pointer to the current MapNode | /// \return Shared pointer to the current MapNode | ||||
| std::shared_ptr<MapNode> Map(std::vector<std::shared_ptr<TensorOperation>> operations, | std::shared_ptr<MapNode> Map(std::vector<std::shared_ptr<TensorOperation>> operations, | ||||
| std::vector<std::string> input_columns = {}, | std::vector<std::string> input_columns = {}, | ||||
| std::vector<std::string> output_columns = {}, | std::vector<std::string> output_columns = {}, | ||||
| const std::vector<std::string> &project_columns = {}); | |||||
| const std::vector<std::string> &project_columns = {}, | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr); | |||||
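For example, a sketch that caches the Map output (decoded and resized tensors) rather than the raw leaf rows; it assumes the `vision::Decode`/`vision::Resize` factories from the companion transforms API and a previously created `some_cache`:

```cpp
std::shared_ptr<Dataset> ds = ImageFolder("/path/to/images", false, RandomSampler());
// Cache the transformed rows so the decode/resize work is done only once.
ds = ds->Map({vision::Decode(), vision::Resize({224, 224})},
             /*input_columns=*/{"image"}, /*output_columns=*/{},
             /*project_columns=*/{}, some_cache);
```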
| /// \brief Function to create a Project Dataset | /// \brief Function to create a Project Dataset | ||||
| /// \notes Applies project to the dataset | /// \notes Applies project to the dataset | ||||
| @@ -670,6 +732,9 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| int32_t connector_que_size_; | int32_t connector_que_size_; | ||||
| int32_t worker_connector_size_; | int32_t worker_connector_size_; | ||||
| std::shared_ptr<DatasetCache> cache_; | |||||
| Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); | |||||
| }; | }; | ||||
| class SchemaObj { | class SchemaObj { | ||||
| @@ -766,7 +831,7 @@ class CelebANode : public Dataset { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ||||
| const bool &decode, const std::set<std::string> &extensions); | |||||
| const bool &decode, const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CelebANode() = default; | ~CelebANode() = default; | ||||
| @@ -792,7 +857,8 @@ class CelebANode : public Dataset { | |||||
| class Cifar10Node : public Dataset { | class Cifar10Node : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler); | |||||
| Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Cifar10Node() = default; | ~Cifar10Node() = default; | ||||
| @@ -814,7 +880,8 @@ class Cifar10Node : public Dataset { | |||||
| class Cifar100Node : public Dataset { | class Cifar100Node : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler); | |||||
| Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Cifar100Node() = default; | ~Cifar100Node() = default; | ||||
| @@ -839,7 +906,7 @@ class CLUENode : public Dataset { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, | CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, | ||||
| ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); | |||||
| ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CLUENode() = default; | ~CLUENode() = default; | ||||
| @@ -870,7 +937,7 @@ class CocoNode : public Dataset { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | ||||
| const bool &decode, const std::shared_ptr<SamplerObj> &sampler); | |||||
| const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CocoNode() = default; | ~CocoNode() = default; | ||||
| @@ -918,7 +985,8 @@ class CSVNode : public Dataset { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CSVNode(const std::vector<std::string> &dataset_files, char field_delim, | CSVNode(const std::vector<std::string> &dataset_files, char field_delim, | ||||
| const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names, | const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names, | ||||
| int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id); | |||||
| int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CSVNode() = default; | ~CSVNode() = default; | ||||
| @@ -947,7 +1015,7 @@ class ManifestNode : public Dataset { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ||||
| const std::map<std::string, int32_t> &class_indexing, bool decode); | |||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ManifestNode() = default; | ~ManifestNode() = default; | ||||
| @@ -1016,7 +1084,8 @@ class MindDataNode : public Dataset { | |||||
| class MnistNode : public Dataset { | class MnistNode : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler); | |||||
| MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~MnistNode() = default; | ~MnistNode() = default; | ||||
| @@ -1044,8 +1113,9 @@ class RandomNode : public Dataset { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, | RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, | ||||
| const std::shared_ptr<SamplerObj> &sampler) | |||||
| : total_rows_(total_rows), | |||||
| const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| total_rows_(total_rows), | |||||
| schema_path_(""), | schema_path_(""), | ||||
| schema_(std::move(schema)), | schema_(std::move(schema)), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| @@ -1053,8 +1123,12 @@ class RandomNode : public Dataset { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, | RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, | ||||
| const std::shared_ptr<SamplerObj> &sampler) | |||||
| : total_rows_(total_rows), schema_path_(schema_path), columns_list_(columns_list), sampler_(std::move(sampler)) {} | |||||
| const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| total_rows_(total_rows), | |||||
| schema_path_(schema_path), | |||||
| columns_list_(columns_list), | |||||
| sampler_(std::move(sampler)) {} | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~RandomNode() = default; | ~RandomNode() = default; | ||||
| @@ -1088,7 +1162,7 @@ class TextFileNode : public Dataset { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | ||||
| int32_t shard_id); | |||||
| int32_t shard_id, std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TextFileNode() = default; | ~TextFileNode() = default; | ||||
| @@ -1117,8 +1191,9 @@ class TFRecordNode : public Dataset { | |||||
| /// \note Parameter 'schema' is the path to the schema file | /// \note Parameter 'schema' is the path to the schema file | ||||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, | TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, | ||||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows) | |||||
| : dataset_files_(dataset_files), | |||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| dataset_files_(dataset_files), | |||||
| schema_path_(schema), | schema_path_(schema), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| num_samples_(num_samples), | num_samples_(num_samples), | ||||
| @@ -1131,8 +1206,9 @@ class TFRecordNode : public Dataset { | |||||
| /// \note Parameter 'schema' is shared pointer to Schema object | /// \note Parameter 'schema' is shared pointer to Schema object | ||||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, | TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, | ||||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows) | |||||
| : dataset_files_(dataset_files), | |||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| dataset_files_(dataset_files), | |||||
| schema_obj_(schema), | schema_obj_(schema), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| num_samples_(num_samples), | num_samples_(num_samples), | ||||
| @@ -1169,7 +1245,8 @@ class VOCNode : public Dataset { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | ||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler); | |||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~VOCNode() = default; | ~VOCNode() = default; | ||||
| @@ -1206,7 +1283,7 @@ class MapNode : public Dataset { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations, | MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations, | ||||
| std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, | std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, | ||||
| const std::vector<std::string> &columns = {}); | |||||
| const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~MapNode() = default; | ~MapNode() = default; | ||||