Browse Source

!7609 Add c++ API for CacheOp

Merge pull request !7609 from h.farahat/cache_c++_api
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f6eaeca62c
10 changed files with 417 additions and 91 deletions
  1. +5
    -2
      mindspore/ccsrc/minddata/dataset/CMakeLists.txt
  2. +131
    -47
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  3. +2
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/CMakeLists.txt
  4. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/cache/CMakeLists.txt
  5. +34
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h
  6. +44
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc
  7. +72
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h
  8. +7
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  9. +3
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h
  10. +115
    -38
      mindspore/ccsrc/minddata/dataset/include/datasets.h

+ 5
- 2
mindspore/ccsrc/minddata/dataset/CMakeLists.txt View File

@@ -80,6 +80,8 @@ add_dependencies(text-kernels core)
add_dependencies(cpp-API core) add_dependencies(cpp-API core)
add_dependencies(engine-ir-datasetops core) add_dependencies(engine-ir-datasetops core)
add_dependencies(engine-ir-datasetops-source core) add_dependencies(engine-ir-datasetops-source core)
add_dependencies(engine-ir-cache core)

if (ENABLE_PYTHON) if (ENABLE_PYTHON)
add_dependencies(APItoPython core) add_dependencies(APItoPython core)
endif() endif()
@@ -102,8 +104,9 @@ set(submodules
$<TARGET_OBJECTS:kernels-data> $<TARGET_OBJECTS:kernels-data>
$<TARGET_OBJECTS:cpp-API> $<TARGET_OBJECTS:cpp-API>
$<TARGET_OBJECTS:engine-ir-datasetops> $<TARGET_OBJECTS:engine-ir-datasetops>
$<TARGET_OBJECTS:engine-ir-datasetops-source>
$<TARGET_OBJECTS:kernels-soft-dvpp-image>
$<TARGET_OBJECTS:engine-ir-datasetops-source>
$<TARGET_OBJECTS:engine-ir-cache>
$<TARGET_OBJECTS:kernels-soft-dvpp-image>
$<TARGET_OBJECTS:soft-dvpp-utils> $<TARGET_OBJECTS:soft-dvpp-utils>
$<TARGET_OBJECTS:engine-datasetops-source> $<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler> $<TARGET_OBJECTS:engine-datasetops-source-sampler>


+ 131
- 47
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -18,6 +18,7 @@
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h" #include "minddata/dataset/include/transforms.h"
// Source dataset headers (in alphabetical order) // Source dataset headers (in alphabetical order)
@@ -32,6 +33,7 @@
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/manifest_op.h" #include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#endif #endif
#include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h" #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
@@ -113,6 +115,9 @@ Dataset::Dataset() {
worker_connector_size_ = cfg->worker_connector_size(); worker_connector_size_ = cfg->worker_connector_size();
} }


// Constructor to initialize the cache
Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }

/// \brief Function to create a SchemaObj /// \brief Function to create a SchemaObj
/// \param[in] schema_file Path of schema file /// \param[in] schema_file Path of schema file
/// \return Shared pointer to the current schema /// \return Shared pointer to the current schema
@@ -137,8 +142,9 @@ std::shared_ptr<AlbumNode> Album(const std::string &dataset_dir, const std::stri
// Function to create a CelebANode. // Function to create a CelebANode.
std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, bool decode, const std::shared_ptr<SamplerObj> &sampler, bool decode,
const std::set<std::string> &extensions) {
auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions);
const std::set<std::string> &extensions,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -146,8 +152,9 @@ std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::st


// Function to create a Cifar10Node. // Function to create a Cifar10Node.
std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler);
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -155,8 +162,9 @@ std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::


// Function to create a Cifar100Node. // Function to create a Cifar100Node.
std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler);
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -165,8 +173,8 @@ std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std
// Function to create a CLUENode. // Function to create a CLUENode.
std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &clue_files, const std::string &task, std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id) {
auto ds = std::make_shared<CLUENode>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id);
int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CLUENode>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -174,9 +182,9 @@ std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &clue_files, const


// Function to create a CocoNode. // Function to create a CocoNode.
std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file, std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task, const bool &decode,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler);
const std::string &task, const bool &decode, const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -186,9 +194,9 @@ std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string
std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim, std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle, const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id) {
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<CSVNode>(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle, auto ds = std::make_shared<CSVNode>(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle,
num_shards, shard_id);
num_shards, shard_id, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -198,12 +206,14 @@ std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char
std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode, std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<SamplerObj> &sampler,
const std::set<std::string> &extensions, const std::set<std::string> &extensions,
const std::map<std::string, int32_t> &class_indexing) {
const std::map<std::string, int32_t> &class_indexing,
const std::shared_ptr<DatasetCache> &cache) {
// This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false. // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
bool recursive = false; bool recursive = false;


// Create logical representation of ImageFolderNode. // Create logical representation of ImageFolderNode.
auto ds = std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing);
auto ds =
std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -213,8 +223,9 @@ std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, boo
// Function to create a ManifestNode. // Function to create a ManifestNode.
std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage, std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode) {
auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode);
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -244,8 +255,9 @@ std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_f


// Function to create a MnistNode. // Function to create a MnistNode.
std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler);
const std::shared_ptr<SamplerObj> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -262,8 +274,9 @@ std::shared_ptr<ConcatNode> operator+(const std::shared_ptr<Dataset> &datasets1,


// Function to create a TextFileNode. // Function to create a TextFileNode.
std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples, std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id);
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -273,8 +286,8 @@ std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_f
// Function to create a VOCNode. // Function to create a VOCNode.
std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage, std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode, const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler);
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);


// Call derived class validation method. // Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
@@ -365,8 +378,10 @@ std::shared_ptr<ConcatNode> Dataset::Concat(const std::vector<std::shared_ptr<Da
// Function to create a Map dataset. // Function to create a Map dataset.
std::shared_ptr<MapNode> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations, std::shared_ptr<MapNode> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns, std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns) {
auto ds = std::make_shared<MapNode>(shared_from_this(), operations, input_columns, output_columns, project_columns);
const std::vector<std::string> &project_columns,
const std::shared_ptr<DatasetCache> &cache) {
auto ds =
std::make_shared<MapNode>(shared_from_this(), operations, input_columns, output_columns, project_columns, cache);


if (!ds->ValidateParams()) { if (!ds->ValidateParams()) {
return nullptr; return nullptr;
@@ -464,6 +479,14 @@ std::shared_ptr<ZipNode> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>


return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
} }
Status Dataset::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
if (cache_ != nullptr) {
std::shared_ptr<DatasetOp> cache_op;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
node_ops->push_back(cache_op);
}
return Status::OK();
}


SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}


@@ -831,8 +854,13 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
// Constructor for CelebANode // Constructor for CelebANode
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode, const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
const std::set<std::string> &extensions)
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {}
const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache)
: Dataset(cache),
dataset_dir_(dataset_dir),
usage_(usage),
sampler_(sampler),
decode_(decode),
extensions_(extensions) {}


Status CelebANode::ValidateParams() { Status CelebANode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));
@@ -855,15 +883,18 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
// label is like this:0 1 0 0 1...... // label is like this:0 1 0 0 1......
RETURN_EMPTY_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));

node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, usage_, extensions_, std::move(schema), decode_, usage_, extensions_, std::move(schema),
std::move(sampler_->Build()))); std::move(sampler_->Build())));

return node_ops; return node_ops;
} }


// Constructor for Cifar10Node // Constructor for Cifar10Node
Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}


Status Cifar10Node::ValidateParams() { Status Cifar10Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));
@@ -887,16 +918,19 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
RETURN_EMPTY_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));


RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema), dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build()))); std::move(sampler_->Build())));

return node_ops; return node_ops;
} }


// Constructor for Cifar100Node // Constructor for Cifar100Node
Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage,
std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}


Status Cifar100Node::ValidateParams() { Status Cifar100Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));
@@ -922,16 +956,20 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
RETURN_EMPTY_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));


RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema), dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build()))); std::move(sampler_->Build())));

return node_ops; return node_ops;
} }


// Constructor for CLUENode // Constructor for CLUENode
CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples, CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id)
: dataset_files_(clue_files),
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
dataset_files_(clue_files),
task_(task), task_(task),
usage_(usage), usage_(usage),
num_samples_(num_samples), num_samples_(num_samples),
@@ -973,6 +1011,7 @@ std::vector<std::string> CLUENode::split(const std::string &s, char delim) {
std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() { std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;

std::map<std::string, std::string> key_map; std::map<std::string, std::string> key_map;
if (task_ == "AFQMC") { if (task_ == "AFQMC") {
if (usage_ == "train") { if (usage_ == "train") {
@@ -1102,15 +1141,22 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
rows_per_buffer_, &shuffle_op)); rows_per_buffer_, &shuffle_op));
node_ops.push_back(shuffle_op); node_ops.push_back(shuffle_op);
} }
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));


node_ops.push_back(clue_op); node_ops.push_back(clue_op);

return node_ops; return node_ops;
} }


// Constructor for CocoNode // Constructor for CocoNode
CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler)
: dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
dataset_dir_(dataset_dir),
annotation_file_(annotation_file),
task_(task),
decode_(decode),
sampler_(sampler) {}


Status CocoNode::ValidateParams() { Status CocoNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_));
@@ -1186,7 +1232,10 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
std::shared_ptr<CocoOp> op = std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

node_ops.push_back(op); node_ops.push_back(op);

return node_ops; return node_ops;
} }


@@ -1194,8 +1243,9 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim, CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle, const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id)
: dataset_files_(csv_files),
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
dataset_files_(csv_files),
field_delim_(field_delim), field_delim_(field_delim),
column_defaults_(column_defaults), column_defaults_(column_defaults),
column_names_(column_names), column_names_(column_names),
@@ -1274,17 +1324,26 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
// Add the shuffle op after this op // Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op)); rows_per_buffer_, &shuffle_op));

node_ops.push_back(shuffle_op); node_ops.push_back(shuffle_op);
} }
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));


node_ops.push_back(csv_op); node_ops.push_back(csv_op);

return node_ops; return node_ops;
} }
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &usage, ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode)
: dataset_file_(dataset_file), usage_(usage), decode_(decode), class_index_(class_indexing), sampler_(sampler) {}
const std::map<std::string, int32_t> &class_indexing, bool decode,
std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
dataset_file_(dataset_file),
usage_(usage),
decode_(decode),
class_index_(class_indexing),
sampler_(sampler) {}


Status ManifestNode::ValidateParams() { Status ManifestNode::ValidateParams() {
std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
@@ -1326,8 +1385,10 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
manifest_op = manifest_op =
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
class_index_, std::move(schema), std::move(sampler_->Build()), usage_); class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));


node_ops.push_back(manifest_op); node_ops.push_back(manifest_op);

return node_ops; return node_ops;
} }
#endif #endif
@@ -1466,8 +1527,9 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
} }
#endif #endif


MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}


Status MnistNode::ValidateParams() { Status MnistNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));
@@ -1489,9 +1551,11 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
TensorShape scalar = TensorShape::CreateScalar(); TensorShape scalar = TensorShape::CreateScalar();
RETURN_EMPTY_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));


node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->Build()))); connector_que_size_, std::move(schema), std::move(sampler_->Build())));

return node_ops; return node_ops;
} }


@@ -1560,14 +1624,18 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
std::shared_ptr<RandomDataOp> op; std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema), std::move(sampler_->Build())); std::move(data_schema), std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

node_ops.push_back(op); node_ops.push_back(op);

return node_ops; return node_ops;
} }


// Constructor for TextFileNode // Constructor for TextFileNode
TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id)
: dataset_files_(dataset_files),
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
dataset_files_(dataset_files),
num_samples_(num_samples), num_samples_(num_samples),
shuffle_(shuffle), shuffle_(shuffle),
num_shards_(num_shards), num_shards_(num_shards),
@@ -1622,9 +1690,11 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
rows_per_buffer_, &shuffle_op)); rows_per_buffer_, &shuffle_op));
node_ops.push_back(shuffle_op); node_ops.push_back(shuffle_op);
} }
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));


// Add TextFileOp // Add TextFileOp
node_ops.push_back(text_file_op); node_ops.push_back(text_file_op);

return node_ops; return node_ops;
} }


@@ -1673,6 +1743,7 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
rows_per_buffer_, &shuffle_op)); rows_per_buffer_, &shuffle_op));
node_ops.push_back(shuffle_op); node_ops.push_back(shuffle_op);
} }
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));


// Add TFReaderOp // Add TFReaderOp
node_ops.push_back(tf_reader_op); node_ops.push_back(tf_reader_op);
@@ -1681,8 +1752,10 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {


// Constructor for VOCNode // Constructor for VOCNode
VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir),
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
dataset_dir_(dataset_dir),
task_(task), task_(task),
usage_(usage), usage_(usage),
class_index_(class_indexing), class_index_(class_indexing),
@@ -1755,9 +1828,18 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
std::shared_ptr<VOCOp> voc_op; std::shared_ptr<VOCOp> voc_op;
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

node_ops.push_back(voc_op); node_ops.push_back(voc_op);
return node_ops; return node_ops;
} }
/// \brief Factory for a DatasetCacheImpl handle.
/// \param id A user assigned session id for the current pipeline
/// \param mem_sz Size of the memory set aside for row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
/// \param hostname optional host name of the cache server
/// \param port optional port of the cache server
/// \param num_connections optional number of connections
/// \param prefetch_sz optional prefetch size
/// \return The cache handle, or nullptr if parameter validation failed
std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
                                                 std::optional<std::string> hostname, std::optional<int32_t> port,
                                                 std::optional<int32_t> num_connections,
                                                 std::optional<int32_t> prefetch_sz) {
  // The optionals are taken by value and not used again, so move them into the
  // impl instead of copying (avoids a needless std::string copy for hostname).
  auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, std::move(hostname), std::move(port),
                                                  std::move(num_connections), std::move(prefetch_sz));
  return cache->ValidateParams() ? cache : nullptr;
}
#endif #endif


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
@@ -1766,11 +1848,12 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {


// Constructor for MapNode.
// The base class Dataset is listed first in the mem-initializer list: bases
// are always initialized before members, so listing it last is misleading and
// triggers -Wreorder. By-value parameters are moved, not copied.
MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
                 std::vector<std::string> input_columns, std::vector<std::string> output_columns,
                 const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache)
    : Dataset(std::move(cache)),
      operations_(std::move(operations)),
      input_columns_(std::move(input_columns)),
      output_columns_(std::move(output_columns)),
      project_columns_(project_columns) {
  this->children.push_back(std::move(child));
}


@@ -1793,6 +1876,7 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
auto project_op = std::make_shared<ProjectOp>(project_columns_); auto project_op = std::make_shared<ProjectOp>(project_columns_);
node_ops.push_back(project_op); node_ops.push_back(project_op);
} }
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));


node_ops.push_back(map_op); node_ops.push_back(map_op);
return node_ops; return node_ops;


+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/CMakeLists.txt View File

@@ -1,3 +1,4 @@
# Tag every .cc under the IR tree with the SM_MD (minddata) submodule id used
# by the logging framework, then descend into the IR subcomponents.
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(datasetops)
add_subdirectory(cache)

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/cache/CMakeLists.txt View File

@@ -0,0 +1,4 @@
# Tag the cache IR sources with the SM_MD (minddata) submodule id used by the
# logging framework.
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
# Object library consumed by the top-level dataset target via
# $<TARGET_OBJECTS:engine-ir-cache>.
add_library(engine-ir-cache OBJECT
dataset_cache_impl.cc)

+ 34
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_

#include <memory>

#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"

namespace mindspore::dataset::api {

/// \brief Abstract IR representation of a tensor cache.
///
/// Concrete implementations (e.g. DatasetCacheImpl) hold the cache client and
/// know how to materialize a CacheOp for the execution tree.
class DatasetCache {
 public:
  /// \brief Destructor. Must be virtual: instances are owned and destroyed
  /// through DatasetCache base pointers.
  virtual ~DatasetCache() = default;

  /// \brief Initialize the cache (e.g. create the underlying cache client).
  /// \return Status error code
  virtual Status Build() = 0;

  /// \brief Validate the user-supplied cache parameters.
  /// \return Status error code
  virtual Status ValidateParams() = 0;

  /// \brief Create the CacheOp that injects this cache into an execution tree.
  /// \param[in] num_workers Number of parallel workers for the CacheOp
  /// \param[out] ds_op The created CacheOp
  /// \return Status error code
  virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
};
} // namespace mindspore::dataset::api

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_

+ 44
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc View File

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>

#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"

namespace mindspore::dataset::api {

/// Method to initialize the DatasetCache by creating an instance of a CacheClient
/// \return Status Error code
Status DatasetCacheImpl::Build() {
  CacheClient::Builder client_builder;
  // Mandatory settings always come from the constructor arguments.
  client_builder.SetSessionId(session_id_).SetCacheMemSz(cache_mem_sz_).SetSpill(spill_);
  // Optional settings are forwarded only when the user supplied them;
  // otherwise the builder's own defaults are kept.
  if (hostname_.has_value()) {
    client_builder.SetHostname(hostname_.value());
  }
  if (port_.has_value()) {
    client_builder.SetPort(port_.value());
  }
  if (num_connections_.has_value()) {
    client_builder.SetNumConnections(num_connections_.value());
  }
  if (prefetch_sz_.has_value()) {
    client_builder.SetPrefetchSize(prefetch_sz_.value());
  }
  return client_builder.Build(&cache_client_);
}

/// Wrap the previously built cache client in a CacheOp for the execution tree.
/// Requires Build() to have been called first.
Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
  std::shared_ptr<CacheOp> op;
  RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&op));
  *ds = std::move(op);
  return Status::OK();
}

} // namespace mindspore::dataset::api

+ 72
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h View File

@@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_

#include <memory>
#include <string>
#include <optional>
#include <utility>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"

namespace mindspore::dataset::api {

/// DatasetCache is the IR of CacheClient: it records the user's cache
/// configuration and lazily creates the CacheClient / CacheOp on demand.
class DatasetCacheImpl : public DatasetCache {
 public:
  ///
  /// \brief Constructor. Only records the parameters; no connection to the
  ///        cache server is made until Build() is called.
  /// \param id A user assigned session id for the current pipeline
  /// \param mem_sz Size of the memory set aside for the row caching. 0 for unlimited
  /// \param spill Spill to disk if out of memory
  /// \param hostname optional host name
  /// \param port optional port
  /// \param num_connections optional number of connections
  /// \param prefetch_sz optional prefetch size
  DatasetCacheImpl(session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
                   std::optional<int32_t> port, std::optional<int32_t> num_connections,
                   std::optional<int32_t> prefetch_sz)
      : session_id_(id),
        cache_mem_sz_(mem_sz),
        spill_(spill),
        hostname_(std::move(hostname)),
        port_(std::move(port)),
        num_connections_(std::move(num_connections)),
        prefetch_sz_(std::move(prefetch_sz)) {}

  /// Method to initialize the DatasetCache by creating an instance of a CacheClient
  /// \return Status Error code
  Status Build() override;

  /// Create the CacheOp backed by cache_client_; fails if Build() was not
  /// called first.
  /// NOTE(review): base class declares the first parameter as `int`; this
  /// override uses `int32_t` — identical on supported platforms, but worth
  /// unifying.
  Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;

  /// No constraints on the stored parameters, so validation always succeeds.
  Status ValidateParams() override { return Status::OK(); }

 private:
  std::shared_ptr<CacheClient> cache_client_;  // created by Build(); null until then
  session_id_type session_id_;
  uint64_t cache_mem_sz_;
  bool spill_;
  // Optional overrides; when empty, CacheClient::Builder defaults apply.
  std::optional<std::string> hostname_;
  std::optional<int32_t> port_;
  std::optional<int32_t> num_connections_;
  std::optional<int32_t> prefetch_sz_;
};
} // namespace mindspore::dataset::api

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_

+ 7
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc View File

@@ -32,13 +32,15 @@ namespace api {


// Constructor for ImageFolderNode.
// Note: no default argument for `cache` here — a default on an out-of-line
// member definition is only visible in this translation unit and the header
// declaration carries none; callers pass it explicitly.
// The base class Dataset is listed first in the mem-initializer list since
// bases are initialized before members; by-value parameters are moved.
ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
                                 bool recursive, std::set<std::string> extensions,
                                 std::map<std::string, int32_t> class_indexing,
                                 std::shared_ptr<DatasetCache> cache)
    : Dataset(std::move(cache)),
      dataset_dir_(std::move(dataset_dir)),
      decode_(decode),
      sampler_(std::move(sampler)),
      recursive_(recursive),
      class_indexing_(std::move(class_indexing)),
      exts_(std::move(extensions)) {}


Status ImageFolderNode::ValidateParams() { Status ImageFolderNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));
@@ -60,6 +62,9 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));

RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema), recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->Build()))); std::move(sampler_->Build())));


+ 3
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h View File

@@ -23,6 +23,7 @@
#include <string> #include <string>
#include <vector> #include <vector>


#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/include/datasets.h" #include "minddata/dataset/include/datasets.h"


namespace mindspore { namespace mindspore {
@@ -36,7 +37,8 @@ class ImageFolderNode : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive, ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
std::set<std::string> extensions, std::map<std::string, int32_t> class_indexing);
std::set<std::string> extensions, std::map<std::string, int32_t> class_indexing,
std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~ImageFolderNode() = default; ~ImageFolderNode() = default;


+ 115
- 38
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -25,7 +25,9 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/constants.h"

#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/include/iterator.h" #include "minddata/dataset/include/iterator.h"
#include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/samplers.h"
@@ -147,10 +149,13 @@ std::shared_ptr<AlbumNode> Album(const std::string &dataset_dir, const std::stri
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] decode Decode the images after reading (default=false). /// \param[in] decode Decode the images after reading (default=false).
/// \param[in] extensions Set of file extensions to be included in the dataset (default={}). /// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage = "all", std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage = "all",
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false, const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false,
const std::set<std::string> &extensions = {});
const std::set<std::string> &extensions = {},
const std::shared_ptr<DatasetCache> &cache = nullptr);


/// \brief Function to create a Cifar10 Dataset /// \brief Function to create a Cifar10 Dataset
/// \notes The generated dataset has two columns ["image", "label"] /// \notes The generated dataset has two columns ["image", "label"]
@@ -158,9 +163,12 @@ std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::st
/// \param[in] usage of CIFAR10, can be "train", "test" or "all" (default = "all"). /// \param[in] usage of CIFAR10, can be "train", "test" or "all" (default = "all").
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage = "all", std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage = "all",
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::shared_ptr<DatasetCache> &cache = nullptr);


/// \brief Function to create a Cifar100 Dataset /// \brief Function to create a Cifar100 Dataset
/// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"] /// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"]
@@ -168,9 +176,12 @@ std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::
/// \param[in] usage of CIFAR100, can be "train", "test" or "all" (default = "all"). /// \param[in] usage of CIFAR100, can be "train", "test" or "all" (default = "all").
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage = "all", std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage = "all",
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::shared_ptr<DatasetCache> &cache = nullptr);


/// \brief Function to create a CLUENode /// \brief Function to create a CLUENode
/// \notes The generated dataset has a variable number of columns depending on the task and usage /// \notes The generated dataset has a variable number of columns depending on the task and usage
@@ -188,11 +199,13 @@ std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be /// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0) /// specified only when num_shards is also specified. (Default = 0)
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current CLUENode /// \return Shared pointer to the current CLUENode
std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC", std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
const std::string &usage = "train", int64_t num_samples = 0, const std::string &usage = "train", int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr);


/// \brief Function to create a CocoNode /// \brief Function to create a CocoNode
/// \notes The generated dataset has multi-columns : /// \notes The generated dataset has multi-columns :
@@ -209,10 +222,13 @@ std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &dataset_files, co
/// \param[in] decode Decode the images after reading /// \param[in] decode Decode the images after reading
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file, std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task = "Detection", const bool &decode = false, const std::string &task = "Detection", const bool &decode = false,
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::shared_ptr<DatasetCache> &cache = nullptr);


/// \brief Function to create a CSVNode /// \brief Function to create a CSVNode
/// \notes The generated dataset has a variable number of columns /// \notes The generated dataset has a variable number of columns
@@ -233,11 +249,14 @@ std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be /// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0) /// specified only when num_shards is also specified. (Default = 0)
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',', std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',',
const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {}, const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {},
const std::vector<std::string> &column_names = {}, int64_t num_samples = 0, const std::vector<std::string> &column_names = {}, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0);
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr);


/// \brief Function to create an ImageFolderNode /// \brief Function to create an ImageFolderNode
/// \notes A source dataset that reads images from a tree of directories /// \notes A source dataset that reads images from a tree of directories
@@ -249,11 +268,14 @@ std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] extensions File extensions to be read /// \param[in] extensions File extensions to be read
/// \param[in] class_indexing a class name to label map /// \param[in] class_indexing a class name to label map
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current ImageFolderNode /// \return Shared pointer to the current ImageFolderNode
std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode = false, std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode = false,
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::set<std::string> &extensions = {}, const std::set<std::string> &extensions = {},
const std::map<std::string, int32_t> &class_indexing = {});
const std::map<std::string, int32_t> &class_indexing = {},
const std::shared_ptr<DatasetCache> &cache = nullptr);


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Function to create a ManifestNode /// \brief Function to create a ManifestNode
@@ -265,10 +287,13 @@ std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, boo
/// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder /// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder
/// names will be sorted alphabetically and each class will be given a unique index starting from 0). /// names will be sorted alphabetically and each class will be given a unique index starting from 0).
/// \param[in] decode Decode the images after reading (default=false). /// \param[in] decode Decode the images after reading (default=false).
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current ManifestNode /// \return Shared pointer to the current ManifestNode
std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage = "train", std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage = "train",
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false);
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
const std::shared_ptr<DatasetCache> &cache = nullptr);
#endif #endif


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
@@ -308,9 +333,12 @@ std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_f
/// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all"). /// \param[in] usage of MNIST, can be "train", "test" or "all" (default = "all").
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current MnistNode /// \return Shared pointer to the current MnistNode
std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage = "all", std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage = "all",
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::shared_ptr<DatasetCache> &cache = nullptr);


/// \brief Function to create a ConcatNode /// \brief Function to create a ConcatNode
/// \notes Reload "+" operator to concat two datasets /// \notes Reload "+" operator to concat two datasets
@@ -326,11 +354,14 @@ std::shared_ptr<ConcatNode> operator+(const std::shared_ptr<Dataset> &datasets1,
/// \param[in] columns_list List of columns to be read (default={}, read all columns) /// \param[in] columns_list List of columns to be read (default={}, read all columns)
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
template <typename T = std::shared_ptr<SchemaObj>> template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr, std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
const std::vector<std::string> &columns_list = {}, const std::vector<std::string> &columns_list = {},
const std::shared_ptr<SamplerObj> &sampler = RandomSampler()) {
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
if (total_rows < 0) { if (total_rows < 0) {
MS_LOG(ERROR) << "RandomNode: total_rows must be greater than or equal 0, now get " << total_rows; MS_LOG(ERROR) << "RandomNode: total_rows must be greater than or equal 0, now get " << total_rows;
return nullptr; return nullptr;
@@ -356,9 +387,11 @@ std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &s
std::shared_ptr<RandomNode> ds; std::shared_ptr<RandomNode> ds;
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) { if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema; std::shared_ptr<SchemaObj> schema_obj = schema;
ds = std::make_shared<RandomNode>(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler));
ds = std::make_shared<RandomNode>(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler),
cache);
} else { } else {
ds = std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler));
ds =
std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler), cache);
} }
return ds; return ds;
} }
@@ -377,10 +410,12 @@ std::shared_ptr<RandomNode> RandomData(const int32_t &total_rows = 0, const T &s
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be /// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0) /// specified only when num_shards is also specified. (Default = 0)
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current TextFileNode /// \return Shared pointer to the current TextFileNode
std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0, std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
int32_t shard_id = 0, const std::shared_ptr<DatasetCache> &cache = nullptr);


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Function to create a TFRecordNode /// \brief Function to create a TFRecordNode
@@ -404,12 +439,15 @@ std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_f
/// when num_shards is also specified. (Default = 0) /// when num_shards is also specified. (Default = 0)
/// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of /// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of
/// each shard may be not equal) /// each shard may be not equal)
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current TFRecordNode /// \return Shared pointer to the current TFRecordNode
template <typename T = std::shared_ptr<SchemaObj>> template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr, std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0, const std::vector<std::string> &columns_list = {}, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0, bool shard_equal_rows = false) {
int32_t shard_id = 0, bool shard_equal_rows = false,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
if (dataset_files.empty()) { if (dataset_files.empty()) {
MS_LOG(ERROR) << "TFRecordNode: dataset_files is not specified."; MS_LOG(ERROR) << "TFRecordNode: dataset_files is not specified.";
return nullptr; return nullptr;
@@ -441,7 +479,7 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) { if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema; std::shared_ptr<SchemaObj> schema_obj = schema;
ds = std::make_shared<TFRecordNode>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards, ds = std::make_shared<TFRecordNode>(dataset_files, schema_obj, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
shard_id, shard_equal_rows, cache);
} else { } else {
std::string schema_path = schema; std::string schema_path = schema;
if (!schema_path.empty()) { if (!schema_path.empty()) {
@@ -452,7 +490,7 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f
} }
} }
ds = std::make_shared<TFRecordNode>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards, ds = std::make_shared<TFRecordNode>(dataset_files, schema_path, columns_list, num_samples, shuffle, num_shards,
shard_id, shard_equal_rows);
shard_id, shard_equal_rows, cache);
} }
return ds; return ds;
} }
@@ -469,11 +507,28 @@ std::shared_ptr<TFRecordNode> TFRecord(const std::vector<std::string> &dataset_f
/// \param[in] decode Decode the images after reading /// \param[in] decode Decode the images after reading
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
const std::string &usage = "train", const std::string &usage = "train",
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false, const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::shared_ptr<DatasetCache> &cache = nullptr);

/// \brief Function the create a cache to be attached to a dataset
/// \param id A user assigned session id for the current pipeline
/// \param mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
/// \param hostname optional host name
/// \param port optional port
/// \param num_connections optional number of connections
/// \param prefetch_sz optional prefetch size
/// \return Shared pointer to DatasetCache. If error, nullptr is returned.
std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
std::optional<std::string> hostname, std::optional<int32_t> port,
std::optional<int32_t> num_connections,
std::optional<int32_t> prefetch_sz);
#endif #endif


/// \brief Function to create a ZipNode /// \brief Function to create a ZipNode
@@ -493,6 +548,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \brief Constructor /// \brief Constructor
Dataset(); Dataset();


/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit Dataset(const std::shared_ptr<DatasetCache> &dataset_cache);

/// \brief Destructor /// \brief Destructor
~Dataset() = default; ~Dataset() = default;


@@ -610,11 +669,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// last operation. The default output_columns will have the same /// last operation. The default output_columns will have the same
/// name as the input columns, i.e., the columns will be replaced /// name as the input columns, i.e., the columns will be replaced
/// \param[in] project_columns A list of column names to project /// \param[in] project_columns A list of column names to project
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// The cache feature is under development and is not recommended.
/// \return Shared pointer to the current MapNode /// \return Shared pointer to the current MapNode
std::shared_ptr<MapNode> Map(std::vector<std::shared_ptr<TensorOperation>> operations, std::shared_ptr<MapNode> Map(std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns = {}, std::vector<std::string> input_columns = {},
std::vector<std::string> output_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &project_columns = {});
const std::vector<std::string> &project_columns = {},
const std::shared_ptr<DatasetCache> &cache = nullptr);


/// \brief Function to create a Project Dataset /// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset /// \notes Applies project to the dataset
@@ -670,6 +732,9 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
int32_t rows_per_buffer_; int32_t rows_per_buffer_;
int32_t connector_que_size_; int32_t connector_que_size_;
int32_t worker_connector_size_; int32_t worker_connector_size_;

std::shared_ptr<DatasetCache> cache_;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
}; };


class SchemaObj { class SchemaObj {
@@ -766,7 +831,7 @@ class CelebANode : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
const bool &decode, const std::set<std::string> &extensions);
const bool &decode, const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache);


/// \brief Destructor /// \brief Destructor
~CelebANode() = default; ~CelebANode() = default;
@@ -792,7 +857,8 @@ class CelebANode : public Dataset {
class Cifar10Node : public Dataset { class Cifar10Node : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler);
Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~Cifar10Node() = default; ~Cifar10Node() = default;
@@ -814,7 +880,8 @@ class Cifar10Node : public Dataset {
class Cifar100Node : public Dataset { class Cifar100Node : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler);
Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~Cifar100Node() = default; ~Cifar100Node() = default;
@@ -839,7 +906,7 @@ class CLUENode : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~CLUENode() = default; ~CLUENode() = default;
@@ -870,7 +937,7 @@ class CocoNode : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler);
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~CocoNode() = default; ~CocoNode() = default;
@@ -918,7 +985,8 @@ class CSVNode : public Dataset {
/// \brief Constructor /// \brief Constructor
CSVNode(const std::vector<std::string> &dataset_files, char field_delim, CSVNode(const std::vector<std::string> &dataset_files, char field_delim,
const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names, const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names,
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~CSVNode() = default; ~CSVNode() = default;
@@ -947,7 +1015,7 @@ class ManifestNode : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
const std::map<std::string, int32_t> &class_indexing, bool decode);
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~ManifestNode() = default; ~ManifestNode() = default;
@@ -1016,7 +1084,8 @@ class MindDataNode : public Dataset {
class MnistNode : public Dataset { class MnistNode : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler);
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~MnistNode() = default; ~MnistNode() = default;
@@ -1044,8 +1113,9 @@ class RandomNode : public Dataset {


/// \brief Constructor /// \brief Constructor
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler)
: total_rows_(total_rows),
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
total_rows_(total_rows),
schema_path_(""), schema_path_(""),
schema_(std::move(schema)), schema_(std::move(schema)),
columns_list_(columns_list), columns_list_(columns_list),
@@ -1053,8 +1123,12 @@ class RandomNode : public Dataset {


/// \brief Constructor /// \brief Constructor
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler)
: total_rows_(total_rows), schema_path_(schema_path), columns_list_(columns_list), sampler_(std::move(sampler)) {}
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
total_rows_(total_rows),
schema_path_(schema_path),
columns_list_(columns_list),
sampler_(std::move(sampler)) {}


/// \brief Destructor /// \brief Destructor
~RandomNode() = default; ~RandomNode() = default;
@@ -1088,7 +1162,7 @@ class TextFileNode : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id);
int32_t shard_id, std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~TextFileNode() = default; ~TextFileNode() = default;
@@ -1117,8 +1191,9 @@ class TFRecordNode : public Dataset {
/// \note Parameter 'schema' is the path to the schema file /// \note Parameter 'schema' is the path to the schema file
TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
: dataset_files_(dataset_files),
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
dataset_files_(dataset_files),
schema_path_(schema), schema_path_(schema),
columns_list_(columns_list), columns_list_(columns_list),
num_samples_(num_samples), num_samples_(num_samples),
@@ -1131,8 +1206,9 @@ class TFRecordNode : public Dataset {
/// \note Parameter 'schema' is shared pointer to Schema object /// \note Parameter 'schema' is shared pointer to Schema object
TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, bool shard_equal_rows)
: dataset_files_(dataset_files),
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
: Dataset(std::move(cache)),
dataset_files_(dataset_files),
schema_obj_(schema), schema_obj_(schema),
columns_list_(columns_list), columns_list_(columns_list),
num_samples_(num_samples), num_samples_(num_samples),
@@ -1169,7 +1245,8 @@ class VOCNode : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);


/// \brief Destructor /// \brief Destructor
~VOCNode() = default; ~VOCNode() = default;
@@ -1206,7 +1283,7 @@ class MapNode : public Dataset {
/// \brief Constructor /// \brief Constructor
MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations, MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &columns = {});
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr);


/// \brief Destructor /// \brief Destructor
~MapNode() = default; ~MapNode() = default;


Loading…
Cancel
Save