|
|
|
@@ -105,7 +105,7 @@ Dataset::Dataset() { |
|
|
|
|
|
|
|
// Function to create a CelebADataset. |
|
|
|
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type, |
|
|
|
const std::shared_ptr<SamplerObj> &sampler, const bool &decode, |
|
|
|
const std::shared_ptr<SamplerObj> &sampler, bool decode, |
|
|
|
const std::set<std::string> &extensions) { |
|
|
|
auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions); |
|
|
|
|
|
|
|
@@ -114,7 +114,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std: |
|
|
|
} |
|
|
|
|
|
|
|
// Function to create a Cifar10Dataset. |
|
|
|
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) { |
|
|
|
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) { |
|
|
|
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler); |
|
|
|
|
|
|
|
// Call derived class validation method. |
|
|
|
@@ -122,7 +122,7 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha |
|
|
|
} |
|
|
|
|
|
|
|
// Function to create a Cifar100Dataset. |
|
|
|
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) { |
|
|
|
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) { |
|
|
|
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler); |
|
|
|
|
|
|
|
// Call derived class validation method. |
|
|
|
@@ -131,8 +131,8 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::s |
|
|
|
|
|
|
|
// Function to create a CLUEDataset. |
|
|
|
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task, |
|
|
|
const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int num_shards, |
|
|
|
int shard_id) { |
|
|
|
const std::string &usage, int64_t num_samples, ShuffleMode shuffle, |
|
|
|
int32_t num_shards, int32_t shard_id) { |
|
|
|
auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id); |
|
|
|
|
|
|
|
// Call derived class validation method. |
|
|
|
@@ -150,9 +150,10 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str |
|
|
|
} |
|
|
|
|
|
|
|
// Function to create a ImageFolderDataset. |
|
|
|
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode, |
|
|
|
std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions, |
|
|
|
std::map<std::string, int32_t> class_indexing) { |
|
|
|
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode, |
|
|
|
const std::shared_ptr<SamplerObj> &sampler, |
|
|
|
const std::set<std::string> &extensions, |
|
|
|
const std::map<std::string, int32_t> &class_indexing) { |
|
|
|
// This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false. |
|
|
|
bool recursive = false; |
|
|
|
|
|
|
|
@@ -164,7 +165,7 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de |
|
|
|
} |
|
|
|
|
|
|
|
// Function to create a MnistDataset. |
|
|
|
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler) { |
|
|
|
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) { |
|
|
|
auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler); |
|
|
|
|
|
|
|
// Call derived class validation method. |
|
|
|
@@ -181,7 +182,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset |
|
|
|
} |
|
|
|
|
|
|
|
// Function to create a TextFileDataset. |
|
|
|
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples, |
|
|
|
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples, |
|
|
|
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) { |
|
|
|
auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id); |
|
|
|
|
|
|
|
@@ -191,9 +192,9 @@ std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files |
|
|
|
|
|
|
|
// Function to create a VOCDataset. |
|
|
|
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode, |
|
|
|
const std::map<std::string, int32_t> &class_index, bool decode, |
|
|
|
std::shared_ptr<SamplerObj> sampler) { |
|
|
|
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_index, decode, sampler); |
|
|
|
const std::map<std::string, int32_t> &class_indexing, bool decode, |
|
|
|
const std::shared_ptr<SamplerObj> &sampler) { |
|
|
|
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_indexing, decode, sampler); |
|
|
|
|
|
|
|
// Call derived class validation method. |
|
|
|
return ds->ValidateParams() ? ds : nullptr; |
|
|
|
@@ -402,16 +403,57 @@ Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, in |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Helper function to validate dataset params |
|
|
|
bool ValidateCommonDatasetParams(std::string dataset_dir) { |
|
|
|
// Helper function to validate dataset directory parameter |
|
|
|
bool ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) { |
|
|
|
if (dataset_dir.empty()) { |
|
|
|
MS_LOG(ERROR) << "No dataset path is specified"; |
|
|
|
MS_LOG(ERROR) << dataset_name << ": dataset_dir is not specified."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
Path dir(dataset_dir); |
|
|
|
if (!dir.IsDirectory()) { |
|
|
|
MS_LOG(ERROR) << dataset_name << ": dataset_dir: [" << dataset_dir << "] is an invalid directory path."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (access(dataset_dir.c_str(), R_OK) == -1) { |
|
|
|
MS_LOG(ERROR) << "No access to specified dataset path: " << dataset_dir; |
|
|
|
MS_LOG(ERROR) << dataset_name << ": No access to specified dataset path: " << dataset_dir; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// Helper function to validate dataset dataset files parameter |
|
|
|
bool ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) { |
|
|
|
if (dataset_files.empty()) { |
|
|
|
MS_LOG(ERROR) << dataset_name << ": dataset_files is not specified."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto f : dataset_files) { |
|
|
|
Path dataset_file(f); |
|
|
|
if (!dataset_file.Exists()) { |
|
|
|
MS_LOG(ERROR) << dataset_name << ": dataset file: [" << f << "] is invalid or does not exist."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// Helper function to validate dataset num_shards and shard_id parameters |
|
|
|
bool ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) { |
|
|
|
if (num_shards <= 0) { |
|
|
|
MS_LOG(ERROR) << dataset_name << ": Invalid num_shards: " << num_shards; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (shard_id < 0 || shard_id >= num_shards) { |
|
|
|
MS_LOG(ERROR) << dataset_name << ": Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -431,9 +473,7 @@ CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string & |
|
|
|
extensions_(extensions) {} |
|
|
|
|
|
|
|
bool CelebADataset::ValidateParams() { |
|
|
|
Path dir(dataset_dir_); |
|
|
|
if (!dir.IsDirectory()) { |
|
|
|
MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified."; |
|
|
|
if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"}; |
|
|
|
@@ -471,7 +511,7 @@ std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() { |
|
|
|
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) |
|
|
|
: dataset_dir_(dataset_dir), sampler_(sampler) {} |
|
|
|
|
|
|
|
bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } |
|
|
|
bool Cifar10Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_); } |
|
|
|
|
|
|
|
// Function to build CifarOp for Cifar10 |
|
|
|
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() { |
|
|
|
@@ -500,7 +540,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() { |
|
|
|
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) |
|
|
|
: dataset_dir_(dataset_dir), sampler_(sampler) {} |
|
|
|
|
|
|
|
bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } |
|
|
|
bool Cifar100Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_); } |
|
|
|
|
|
|
|
// Function to build CifarOp for Cifar100 |
|
|
|
std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() { |
|
|
|
@@ -529,7 +569,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() { |
|
|
|
|
|
|
|
// Constructor for CLUEDataset |
|
|
|
CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string task, std::string usage, |
|
|
|
int64_t num_samples, ShuffleMode shuffle, int num_shards, int shard_id) |
|
|
|
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) |
|
|
|
: dataset_files_(clue_files), |
|
|
|
task_(task), |
|
|
|
usage_(usage), |
|
|
|
@@ -539,19 +579,10 @@ CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string |
|
|
|
shard_id_(shard_id) {} |
|
|
|
|
|
|
|
bool CLUEDataset::ValidateParams() { |
|
|
|
if (dataset_files_.empty()) { |
|
|
|
MS_LOG(ERROR) << "CLUEDataset: dataset_files is not specified."; |
|
|
|
if (!ValidateDatasetFilesParam("CLUEDataset", dataset_files_)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto f : dataset_files_) { |
|
|
|
Path clue_file(f); |
|
|
|
if (!clue_file.Exists()) { |
|
|
|
MS_LOG(ERROR) << "dataset file: [" << f << "] is invalid or not exist"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}; |
|
|
|
std::vector<std::string> usage_list = {"train", "test", "eval"}; |
|
|
|
|
|
|
|
@@ -570,13 +601,7 @@ bool CLUEDataset::ValidateParams() { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (num_shards_ <= 0) { |
|
|
|
MS_LOG(ERROR) << "CLUEDataset: Invalid num_shards: " << num_shards_; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (shard_id_ < 0 || shard_id_ >= num_shards_) { |
|
|
|
MS_LOG(ERROR) << "CLUEDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; |
|
|
|
if (!ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -734,9 +759,7 @@ CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &anno |
|
|
|
: dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {} |
|
|
|
|
|
|
|
bool CocoDataset::ValidateParams() { |
|
|
|
Path dir(dataset_dir_); |
|
|
|
if (!dir.IsDirectory()) { |
|
|
|
MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified."; |
|
|
|
if (!ValidateDatasetDirParam("CocoDataset", dataset_dir_)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
Path annotation_file(annotation_file_); |
|
|
|
@@ -829,7 +852,7 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std |
|
|
|
class_indexing_(class_indexing), |
|
|
|
exts_(extensions) {} |
|
|
|
|
|
|
|
bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } |
|
|
|
bool ImageFolderDataset::ValidateParams() { return ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_); } |
|
|
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() { |
|
|
|
// A vector containing shared pointer to the Dataset Ops that this object will create |
|
|
|
@@ -857,7 +880,7 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() { |
|
|
|
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler) |
|
|
|
: dataset_dir_(dataset_dir), sampler_(sampler) {} |
|
|
|
|
|
|
|
bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); } |
|
|
|
bool MnistDataset::ValidateParams() { return ValidateDatasetDirParam("MnistDataset", dataset_dir_); } |
|
|
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() { |
|
|
|
// A vector containing shared pointer to the Dataset Ops that this object will create |
|
|
|
@@ -890,31 +913,16 @@ TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t |
|
|
|
shard_id_(shard_id) {} |
|
|
|
|
|
|
|
bool TextFileDataset::ValidateParams() { |
|
|
|
if (dataset_files_.empty()) { |
|
|
|
MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified."; |
|
|
|
if (!ValidateDatasetFilesParam("TextFileDataset", dataset_files_)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto file : dataset_files_) { |
|
|
|
std::ifstream handle(file); |
|
|
|
if (!handle.is_open()) { |
|
|
|
MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (num_samples_ < 0) { |
|
|
|
MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (num_shards_ <= 0) { |
|
|
|
MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (shard_id_ < 0 || shard_id_ >= num_shards_) { |
|
|
|
MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; |
|
|
|
if (!ValidateDatasetShardParams("TextfileDataset", num_shards_, shard_id_)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -960,12 +968,12 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() { |
|
|
|
|
|
|
|
// Constructor for VOCDataset |
|
|
|
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, |
|
|
|
const std::map<std::string, int32_t> &class_index, bool decode, |
|
|
|
const std::map<std::string, int32_t> &class_indexing, bool decode, |
|
|
|
std::shared_ptr<SamplerObj> sampler) |
|
|
|
: dataset_dir_(dataset_dir), |
|
|
|
task_(task), |
|
|
|
mode_(mode), |
|
|
|
class_index_(class_index), |
|
|
|
class_index_(class_indexing), |
|
|
|
decode_(decode), |
|
|
|
sampler_(sampler) {} |
|
|
|
|
|
|
|
|