|
|
|
@@ -82,10 +82,15 @@ namespace dataset { |
|
|
|
|
|
|
|
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) { |
|
|
|
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset") |
|
|
|
.def("SetNumWorkers", |
|
|
|
.def("set_num_workers", |
|
|
|
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) { |
|
|
|
return num_workers ? self->SetNumWorkers(*num_workers) : self; |
|
|
|
}) |
|
|
|
.def("set_cache_client", |
|
|
|
[](std::shared_ptr<DatasetNode> self) { |
|
|
|
std::shared_ptr<DatasetCache> dc = nullptr; |
|
|
|
return self->SetDatasetCache(dc); |
|
|
|
}) |
|
|
|
.def( |
|
|
|
"Zip", |
|
|
|
[](std::shared_ptr<DatasetNode> self, py::list datasets) { |
|
|
|
@@ -109,10 +114,9 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", |
|
|
|
"to create a CelebANode") |
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode, |
|
|
|
py::list extensions, std::shared_ptr<CacheClient> cc) { |
|
|
|
auto celebA = |
|
|
|
std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode, |
|
|
|
toStringSet(extensions), toDatasetCache(std::move(cc))); |
|
|
|
py::list extensions) { |
|
|
|
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode, |
|
|
|
toStringSet(extensions), nullptr); |
|
|
|
THROW_IF_ERROR(celebA->ValidateParams()); |
|
|
|
return celebA; |
|
|
|
})); |
|
|
|
@@ -121,10 +125,8 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) { |
|
|
|
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node", |
|
|
|
"to create a Cifar10Node") |
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, |
|
|
|
std::shared_ptr<CacheClient> cc) { |
|
|
|
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), |
|
|
|
toDatasetCache(std::move(cc))); |
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { |
|
|
|
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr); |
|
|
|
THROW_IF_ERROR(cifar10->ValidateParams()); |
|
|
|
return cifar10; |
|
|
|
})); |
|
|
|
@@ -133,36 +135,34 @@ PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) { |
|
|
|
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node", |
|
|
|
"to create a Cifar100Node") |
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, |
|
|
|
std::shared_ptr<CacheClient> cc) { |
|
|
|
auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), |
|
|
|
toDatasetCache(std::move(cc))); |
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { |
|
|
|
auto cifar100 = |
|
|
|
std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr); |
|
|
|
THROW_IF_ERROR(cifar100->ValidateParams()); |
|
|
|
return cifar100; |
|
|
|
})); |
|
|
|
})); |
|
|
|
|
|
|
|
PYBIND_REGISTER( |
|
|
|
CLUENode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode") |
|
|
|
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle, |
|
|
|
int32_t num_shards, int32_t shard_id, std::shared_ptr<CacheClient> cc) { |
|
|
|
std::shared_ptr<CLUENode> clue_node = |
|
|
|
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle), |
|
|
|
num_shards, shard_id, toDatasetCache(std::move(cc))); |
|
|
|
THROW_IF_ERROR(clue_node->ValidateParams()); |
|
|
|
return clue_node; |
|
|
|
})); |
|
|
|
})); |
|
|
|
PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", |
|
|
|
"to create a CLUENode") |
|
|
|
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, |
|
|
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) { |
|
|
|
std::shared_ptr<CLUENode> clue_node = |
|
|
|
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, |
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, nullptr); |
|
|
|
THROW_IF_ERROR(clue_node->ValidateParams()); |
|
|
|
return clue_node; |
|
|
|
})); |
|
|
|
})); |
|
|
|
|
|
|
|
PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", |
|
|
|
"to create a CocoNode") |
|
|
|
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, |
|
|
|
bool decode, py::handle sampler, std::shared_ptr<CacheClient> cc) { |
|
|
|
std::shared_ptr<CocoNode> coco = |
|
|
|
std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), |
|
|
|
toDatasetCache(std::move(cc))); |
|
|
|
bool decode, py::handle sampler) { |
|
|
|
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>( |
|
|
|
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr); |
|
|
|
THROW_IF_ERROR(coco->ValidateParams()); |
|
|
|
return coco; |
|
|
|
})); |
|
|
|
@@ -172,10 +172,10 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode") |
|
|
|
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults, |
|
|
|
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle, |
|
|
|
int32_t num_shards, int32_t shard_id, std::shared_ptr<CacheClient> cc) { |
|
|
|
auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), |
|
|
|
column_names, num_samples, toShuffleMode(shuffle), |
|
|
|
num_shards, shard_id, toDatasetCache(std::move(cc))); |
|
|
|
int32_t num_shards, int32_t shard_id) { |
|
|
|
auto csv = |
|
|
|
std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names, |
|
|
|
num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); |
|
|
|
THROW_IF_ERROR(csv->ValidateParams()); |
|
|
|
return csv; |
|
|
|
})); |
|
|
|
@@ -205,12 +205,12 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>( |
|
|
|
*m, "ImageFolderNode", "to create an ImageFolderNode") |
|
|
|
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions, |
|
|
|
py::dict class_indexing, std::shared_ptr<CacheClient> cc) { |
|
|
|
py::dict class_indexing) { |
|
|
|
// Don't update recursive to true |
|
|
|
bool recursive = false; // Will be removed in future PR |
|
|
|
auto imagefolder = std::make_shared<ImageFolderNode>( |
|
|
|
dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions), |
|
|
|
toStringMap(class_indexing), toDatasetCache(std::move(cc))); |
|
|
|
auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler), |
|
|
|
recursive, toStringSet(extensions), |
|
|
|
toStringMap(class_indexing), nullptr); |
|
|
|
THROW_IF_ERROR(imagefolder->ValidateParams()); |
|
|
|
return imagefolder; |
|
|
|
})); |
|
|
|
@@ -220,10 +220,9 @@ PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode", |
|
|
|
"to create a ManifestNode") |
|
|
|
.def(py::init([](std::string dataset_file, std::string usage, py::handle sampler, |
|
|
|
py::dict class_indexing, bool decode, std::shared_ptr<CacheClient> cc) { |
|
|
|
py::dict class_indexing, bool decode) { |
|
|
|
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler), |
|
|
|
toStringMap(class_indexing), decode, |
|
|
|
toDatasetCache(std::move(cc))); |
|
|
|
toStringMap(class_indexing), decode, nullptr); |
|
|
|
THROW_IF_ERROR(manifest->ValidateParams()); |
|
|
|
return manifest; |
|
|
|
})); |
|
|
|
@@ -261,41 +260,38 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) { |
|
|
|
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode", |
|
|
|
"to create an MnistNode") |
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, |
|
|
|
std::shared_ptr<CacheClient> cc) { |
|
|
|
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), |
|
|
|
toDatasetCache(std::move(cc))); |
|
|
|
.def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { |
|
|
|
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr); |
|
|
|
THROW_IF_ERROR(mnist->ValidateParams()); |
|
|
|
return mnist; |
|
|
|
})); |
|
|
|
})); |
|
|
|
|
|
|
|
PYBIND_REGISTER( |
|
|
|
RandomNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode") |
|
|
|
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list, |
|
|
|
std::shared_ptr<CacheClient> cc) { |
|
|
|
auto random_node = |
|
|
|
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc))); |
|
|
|
THROW_IF_ERROR(random_node->ValidateParams()); |
|
|
|
return random_node; |
|
|
|
})) |
|
|
|
.def(py::init([](int32_t total_rows, std::string schema, py::list columns_list, std::shared_ptr<CacheClient> cc) { |
|
|
|
auto random_node = |
|
|
|
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc))); |
|
|
|
THROW_IF_ERROR(random_node->ValidateParams()); |
|
|
|
return random_node; |
|
|
|
})); |
|
|
|
})); |
|
|
|
PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", |
|
|
|
"to create a RandomNode") |
|
|
|
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list) { |
|
|
|
auto random_node = |
|
|
|
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr); |
|
|
|
THROW_IF_ERROR(random_node->ValidateParams()); |
|
|
|
return random_node; |
|
|
|
})) |
|
|
|
.def(py::init([](int32_t total_rows, std::string schema, py::list columns_list) { |
|
|
|
auto random_node = |
|
|
|
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr); |
|
|
|
THROW_IF_ERROR(random_node->ValidateParams()); |
|
|
|
return random_node; |
|
|
|
})); |
|
|
|
})); |
|
|
|
|
|
|
|
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode", |
|
|
|
"to create a TextFileNode") |
|
|
|
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards, |
|
|
|
int32_t shard_id, std::shared_ptr<CacheClient> cc) { |
|
|
|
std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>( |
|
|
|
toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id, |
|
|
|
toDatasetCache(std::move(cc))); |
|
|
|
int32_t shard_id) { |
|
|
|
std::shared_ptr<TextFileNode> textfile_node = |
|
|
|
std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples, |
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, nullptr); |
|
|
|
THROW_IF_ERROR(textfile_node->ValidateParams()); |
|
|
|
return textfile_node; |
|
|
|
})); |
|
|
|
@@ -306,19 +302,19 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) { |
|
|
|
"to create a TFRecordNode") |
|
|
|
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, py::list columns_list, |
|
|
|
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id, |
|
|
|
bool shard_equal_rows, std::shared_ptr<CacheClient> cc) { |
|
|
|
bool shard_equal_rows) { |
|
|
|
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>( |
|
|
|
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples, |
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc))); |
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr); |
|
|
|
THROW_IF_ERROR(tfrecord->ValidateParams()); |
|
|
|
return tfrecord; |
|
|
|
})) |
|
|
|
.def(py::init([](py::list dataset_files, std::string schema, py::list columns_list, |
|
|
|
int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id, |
|
|
|
bool shard_equal_rows, std::shared_ptr<CacheClient> cc) { |
|
|
|
bool shard_equal_rows) { |
|
|
|
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>( |
|
|
|
toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples, |
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc))); |
|
|
|
toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr); |
|
|
|
THROW_IF_ERROR(tfrecord->ValidateParams()); |
|
|
|
return tfrecord; |
|
|
|
})); |
|
|
|
@@ -326,15 +322,13 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) { |
|
|
|
|
|
|
|
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode") |
|
|
|
.def( |
|
|
|
py::init([](std::string dataset_dir, std::string task, std::string usage, py::dict class_indexing, |
|
|
|
bool decode, py::handle sampler, std::shared_ptr<CacheClient> cc) { |
|
|
|
std::shared_ptr<VOCNode> voc = |
|
|
|
std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode, |
|
|
|
toSamplerObj(sampler), toDatasetCache(std::move(cc))); |
|
|
|
THROW_IF_ERROR(voc->ValidateParams()); |
|
|
|
return voc; |
|
|
|
})); |
|
|
|
.def(py::init([](std::string dataset_dir, std::string task, std::string usage, |
|
|
|
py::dict class_indexing, bool decode, py::handle sampler) { |
|
|
|
std::shared_ptr<VOCNode> voc = std::make_shared<VOCNode>( |
|
|
|
dataset_dir, task, usage, toStringMap(class_indexing), decode, toSamplerObj(sampler), nullptr); |
|
|
|
THROW_IF_ERROR(voc->ValidateParams()); |
|
|
|
return voc; |
|
|
|
})); |
|
|
|
})); |
|
|
|
|
|
|
|
// PYBIND FOR NON-LEAF NODES |
|
|
|
@@ -439,11 +433,11 @@ PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) { |
|
|
|
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) { |
|
|
|
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode") |
|
|
|
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns, |
|
|
|
py::list output_columns, py::list project_columns, std::shared_ptr<CacheClient> cc, |
|
|
|
py::list output_columns, py::list project_columns, |
|
|
|
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) { |
|
|
|
auto map = std::make_shared<MapNode>( |
|
|
|
self, std::move(toTensorOperations(operations)), toStringVector(input_columns), |
|
|
|
toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)), |
|
|
|
toStringVector(output_columns), toStringVector(project_columns), nullptr, |
|
|
|
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end())); |
|
|
|
THROW_IF_ERROR(map->ValidateParams()); |
|
|
|
return map; |
|
|
|
|