|
|
|
@@ -424,6 +424,7 @@ class SchemaObj { |
|
|
|
class BatchDataset : public Dataset { |
|
|
|
public: |
|
|
|
BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder = false); |
|
|
|
~BatchDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID |
|
|
|
@@ -435,17 +436,20 @@ class BucketBatchByLengthDataset : public Dataset { |
|
|
|
std::function<TensorRow(TensorRow)> element_length_function = nullptr, |
|
|
|
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, |
|
|
|
bool pad_to_bucket_boundary = false, bool drop_remainder = false); |
|
|
|
~BucketBatchByLengthDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
class ConcatDataset : public Dataset { |
|
|
|
public: |
|
|
|
explicit ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &input); |
|
|
|
~ConcatDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
class FilterDataset : public Dataset { |
|
|
|
public: |
|
|
|
FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate, |
|
|
|
std::vector<std::string> input_columns); |
|
|
|
~FilterDataset() = default; |
|
|
|
}; |
|
|
|
#endif |
|
|
|
|
|
|
|
@@ -455,11 +459,13 @@ class MapDataset : public Dataset { |
|
|
|
std::vector<std::string> input_columns, std::vector<std::string> output_columns, |
|
|
|
const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache, |
|
|
|
std::vector<std::shared_ptr<DSCallback>> callbacks); |
|
|
|
~MapDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
class ProjectDataset : public Dataset { |
|
|
|
public: |
|
|
|
ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::string> &columns); |
|
|
|
~ProjectDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID |
|
|
|
@@ -473,27 +479,32 @@ class RenameDataset : public Dataset { |
|
|
|
class RepeatDataset : public Dataset { |
|
|
|
public: |
|
|
|
RepeatDataset(std::shared_ptr<Dataset> input, int32_t count); |
|
|
|
~RepeatDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
class ShuffleDataset : public Dataset { |
|
|
|
public: |
|
|
|
ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size); |
|
|
|
~ShuffleDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID |
|
|
|
class SkipDataset : public Dataset { |
|
|
|
public: |
|
|
|
SkipDataset(std::shared_ptr<Dataset> input, int32_t count); |
|
|
|
~SkipDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
class TakeDataset : public Dataset { |
|
|
|
public: |
|
|
|
TakeDataset(std::shared_ptr<Dataset> input, int32_t count); |
|
|
|
~TakeDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
class ZipDataset : public Dataset { |
|
|
|
public: |
|
|
|
explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &inputs); |
|
|
|
~ZipDataset() = default; |
|
|
|
}; |
|
|
|
#endif |
|
|
|
/// \brief Function to create a SchemaObj |
|
|
|
@@ -507,6 +518,7 @@ class AlbumDataset : public Dataset { |
|
|
|
const std::vector<std::string> &column_names = {}, bool decode = false, |
|
|
|
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~AlbumDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create an AlbumDataset |
|
|
|
@@ -533,6 +545,7 @@ class CelebADataset : public Dataset { |
|
|
|
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false, |
|
|
|
const std::set<std::string> &extensions = {}, |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~CelebADataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a CelebADataset |
|
|
|
@@ -557,6 +570,7 @@ class Cifar10Dataset : public Dataset { |
|
|
|
explicit Cifar10Dataset(const std::string &dataset_dir, const std::string &usage = "all", |
|
|
|
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~Cifar10Dataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a Cifar10 Dataset |
|
|
|
@@ -577,6 +591,7 @@ class Cifar100Dataset : public Dataset { |
|
|
|
explicit Cifar100Dataset(const std::string &dataset_dir, const std::string &usage = "all", |
|
|
|
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~Cifar100Dataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a Cifar100 Dataset |
|
|
|
@@ -598,6 +613,7 @@ class CLUEDataset : public Dataset { |
|
|
|
const std::string &usage = "train", int64_t num_samples = 0, |
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0, |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~CLUEDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a CLUEDataset |
|
|
|
@@ -629,6 +645,7 @@ class CocoDataset : public Dataset { |
|
|
|
CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task = "Detection", |
|
|
|
const bool &decode = false, const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~CocoDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a CocoDataset |
|
|
|
@@ -661,6 +678,7 @@ class CSVDataset : public Dataset { |
|
|
|
const std::vector<std::string> &column_names = {}, int64_t num_samples = 0, |
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0, |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~CSVDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a CSVDataset |
|
|
|
@@ -698,6 +716,7 @@ class ImageFolderDataset : public Dataset { |
|
|
|
const std::set<std::string> &extensions = {}, |
|
|
|
const std::map<std::string, int32_t> &class_indexing = {}, |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~ImageFolderDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create an ImageFolderDataset |
|
|
|
@@ -725,6 +744,7 @@ class ManifestDataset : public Dataset { |
|
|
|
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), |
|
|
|
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false, |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~ManifestDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a ManifestDataset |
|
|
|
@@ -753,6 +773,7 @@ class MindDataDataset : public Dataset { |
|
|
|
const std::vector<std::string> &columns_list = {}, |
|
|
|
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), |
|
|
|
nlohmann::json padded_sample = nullptr, int64_t num_padded = 0); |
|
|
|
~MindDataDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a MindDataDataset |
|
|
|
@@ -789,6 +810,7 @@ class MnistDataset : public Dataset { |
|
|
|
explicit MnistDataset(const std::string &dataset_dir, const std::string &usage = "all", |
|
|
|
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~MnistDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a MnistDataset |
|
|
|
@@ -819,6 +841,8 @@ class RandomDataDataset : public Dataset { |
|
|
|
|
|
|
|
RandomDataDataset(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, |
|
|
|
std::shared_ptr<DatasetCache> cache); |
|
|
|
|
|
|
|
~RandomDataDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a RandomDataset |
|
|
|
@@ -849,6 +873,7 @@ class TextFileDataset : public Dataset { |
|
|
|
explicit TextFileDataset(const std::vector<std::string> &dataset_files, int64_t num_samples = 0, |
|
|
|
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, int32_t shard_id = 0, |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~TextFileDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a TextFileDataset |
|
|
|
@@ -883,6 +908,8 @@ class TFRecordDataset : public Dataset { |
|
|
|
TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, |
|
|
|
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, |
|
|
|
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache); |
|
|
|
|
|
|
|
~TFRecordDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a TFRecordDataset |
|
|
|
@@ -941,6 +968,7 @@ class VOCDataset : public Dataset { |
|
|
|
const std::string &usage = "train", const std::map<std::string, int32_t> &class_indexing = {}, |
|
|
|
bool decode = false, const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), |
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr); |
|
|
|
~VOCDataset() = default; |
|
|
|
}; |
|
|
|
|
|
|
|
/// \brief Function to create a VOCDataset |
|
|
|
|