Merge pull request !8096 from h.farahat/datasetNode (tag: v1.1.0)
@@ -86,6 +86,7 @@
 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
+#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
 // IR leaf nodes disabled for android
@@ -140,26 +141,11 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
     return false;
   }
-  // Get a uuid for queue name
-  std::string queue_name = Services::GetUniqueID();
-  // TODO(CRC):
-  // Get device type from ms context
-  std::string device_type = "CPU";
-  // Get device ID from children
-  int32_t device_id = 0;
-  rc = TransferNode::get_distribution(shared_from_this(), &device_id);
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "Failed to get shard id. Error status: " << rc;
-    return false;
-  }
   // Add TransferNode IR on top of dataset d
-  auto ds = std::make_shared<TransferNode>(shared_from_this(), queue_name, device_id, device_type, send_epoch_end);
+  auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end);
   // Get ToDevice consumer
-  auto consumer = std::make_unique<ToDevice>(device_type, send_epoch_end, -1);
+  auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1);
   ToDevice *consumer_ = consumer.get();
   rc = consumer->Init(ds);
   if (rc.IsError()) {
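// Usage sketch (illustrative, not from this PR): after this change the queue
// name, device type, and device id are resolved inside TransferNode/ToDevice
// rather than at the API layer, so a caller only chooses send_epoch_end.
// The dataset path and reliance on default trailing arguments are assumptions.
auto ds = ImageFolder("/path/to/images", /*decode=*/true, SequentialSampler(0, 0));
if (!ds->DeviceQueue(/*send_epoch_end=*/true)) {
  MS_LOG(ERROR) << "DeviceQueue failed.";
}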
@@ -199,7 +185,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
     return false;
   }
   SaveToDisk *consumer_ = consumer.get();
-  rc = consumer->Init(ds);
+  rc = consumer->Init(ds->IRNode());
   if (rc.IsError()) {
     MS_LOG(ERROR) << "CreateSaver failed." << rc;
     return false;
@@ -225,19 +211,10 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
 #endif

 // Constructor
-Dataset::Dataset() {
-  // Fetch some default value from config manager
-  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
-  num_workers_ = cfg->num_parallel_workers();
-  rows_per_buffer_ = cfg->rows_per_buffer();
-  connector_que_size_ = cfg->op_connector_size();
-  worker_connector_size_ = cfg->worker_connector_size();
-  tree_getters_ = std::make_shared<TreeGetters>();
-}
+Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }

 int64_t Dataset::GetDatasetSize() {
   int64_t dataset_size;
-  auto ds = shared_from_this();
   Status rc;
   std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
   rc = runtime_context->Init();
@@ -246,7 +223,7 @@ int64_t Dataset::GetDatasetSize() {
     return -1;
   }
   if (!tree_getters_->isInitialized()) {
-    rc = tree_getters_->Init(ds);
+    rc = tree_getters_->Init(this->IRNode());
    if (rc.IsError()) {
      MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
      return -1;
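// Illustrative call pattern for these getters (the dataset path is a
// placeholder and reliance on default trailing arguments is an assumption):
// each getter lazily initializes the shared TreeGetters from the IR node and
// returns -1 (or an empty vector) on failure.
auto ds = Cifar10("/data/cifar10", "train", SequentialSampler(0, 0));
int64_t rows = ds->GetDatasetSize();                 // -1 on error
std::vector<DataType> types = ds->GetOutputTypes();  // empty on error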
@@ -267,7 +244,7 @@ std::vector<DataType> Dataset::GetOutputTypes() {
     return types;
   }
   if (!tree_getters_->isInitialized()) {
-    rc = tree_getters_->Init(shared_from_this());
+    rc = tree_getters_->Init(this->IRNode());
    if (rc.IsError()) {
      MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
      types.clear();
@@ -294,7 +271,7 @@ std::vector<TensorShape> Dataset::GetOutputShapes() {
     return shapes;
   }
   if (!tree_getters_->isInitialized()) {
-    rc = tree_getters_->Init(shared_from_this());
+    rc = tree_getters_->Init(this->IRNode());
    if (rc.IsError()) {
      MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
      shapes.clear();
@@ -321,7 +298,7 @@ int64_t Dataset::GetNumClasses() {
     return -1;
   }
   if (!tree_getters_->isInitialized()) {
-    rc = tree_getters_->Init(ds);
+    rc = tree_getters_->Init(ds->IRNode());
    if (rc.IsError()) {
      MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
      return -1;
@@ -331,9 +308,6 @@ int64_t Dataset::GetNumClasses() {
   return rc.IsError() ? -1 : num_classes;
 }

-// Constructor to initialize the cache
-Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }
-
 /// \brief Function to create a SchemaObj
 /// \param[in] schema_file Path of schema file
 /// \return Shared pointer to the current schema
@@ -346,161 +320,155 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) {
 // FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
 // (In alphabetical order)

-// Function to create a AlbumNode.
-std::shared_ptr<AlbumNode> Album(const std::string &dataset_dir, const std::string &data_schema,
-                                 const std::vector<std::string> &column_names, bool decode,
-                                 const std::shared_ptr<SamplerObj> &sampler,
-                                 const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
+// Function to create a AlbumDataset.
+std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
+                                    const std::vector<std::string> &column_names, bool decode,
+                                    const std::shared_ptr<SamplerObj> &sampler,
+                                    const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<AlbumDataset>(dataset_dir, data_schema, column_names, decode, sampler, cache);
   return ds;
 }

-// Function to create a CelebANode.
-std::shared_ptr<CelebANode> CelebA(const std::string &dataset_dir, const std::string &usage,
-                                   const std::shared_ptr<SamplerObj> &sampler, bool decode,
-                                   const std::set<std::string> &extensions,
-                                   const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions, cache);
+// Function to create a CelebADataset.
+std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
+                                      const std::shared_ptr<SamplerObj> &sampler, bool decode,
+                                      const std::set<std::string> &extensions,
+                                      const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CelebADataset>(dataset_dir, usage, sampler, decode, extensions, cache);
   return ds;
 }

-// Function to create a Cifar10Node.
-std::shared_ptr<Cifar10Node> Cifar10(const std::string &dataset_dir, const std::string &usage,
-                                     const std::shared_ptr<SamplerObj> &sampler,
-                                     const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
+// Function to create a Cifar10Dataset.
+std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
+                                        const std::shared_ptr<SamplerObj> &sampler,
+                                        const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, usage, sampler, cache);
   return ds;
 }

-// Function to create a Cifar100Node.
-std::shared_ptr<Cifar100Node> Cifar100(const std::string &dataset_dir, const std::string &usage,
-                                       const std::shared_ptr<SamplerObj> &sampler,
-                                       const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
+// Function to create a Cifar100Dataset.
+std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
+                                          const std::shared_ptr<SamplerObj> &sampler,
+                                          const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, usage, sampler, cache);
   return ds;
 }

-// Function to create a CLUENode.
-std::shared_ptr<CLUENode> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
-                               const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
-                               int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<CLUENode>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
+// Function to create a CLUEDataset.
+std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
+                                  const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
+                                  int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
   return ds;
 }

-// Function to create a CocoNode.
-std::shared_ptr<CocoNode> Coco(const std::string &dataset_dir, const std::string &annotation_file,
-                               const std::string &task, const bool &decode, const std::shared_ptr<SamplerObj> &sampler,
-                               const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache);
+// Function to create a CocoDataset.
+std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
+                                  const std::string &task, const bool &decode,
+                                  const std::shared_ptr<SamplerObj> &sampler,
+                                  const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CocoDataset>(dataset_dir, annotation_file, task, decode, sampler, cache);
   return ds;
 }

-// Function to create a CSVNode.
-std::shared_ptr<CSVNode> CSV(const std::vector<std::string> &dataset_files, char field_delim,
-                             const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
-                             const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
-                             int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<CSVNode>(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle,
-                                      num_shards, shard_id, cache);
+// Function to create a CSVDataset.
+std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim,
+                                const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
+                                const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
+                                int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CSVDataset>(dataset_files, field_delim, column_defaults, column_names, num_samples,
+                                         shuffle, num_shards, shard_id, cache);
   return ds;
 }

-// Function to create a ImageFolderNode.
-std::shared_ptr<ImageFolderNode> ImageFolder(const std::string &dataset_dir, bool decode,
-                                             const std::shared_ptr<SamplerObj> &sampler,
-                                             const std::set<std::string> &extensions,
-                                             const std::map<std::string, int32_t> &class_indexing,
-                                             const std::shared_ptr<DatasetCache> &cache) {
-  // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
-  bool recursive = false;
-  // Create logical representation of ImageFolderNode.
-  auto ds =
-    std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing, cache);
+// Function to create a ImageFolderDataset.
+std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
+                                                const std::shared_ptr<SamplerObj> &sampler,
+                                                const std::set<std::string> &extensions,
+                                                const std::map<std::string, int32_t> &class_indexing,
+                                                const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<ImageFolderDataset>(dataset_dir, decode, sampler, extensions, class_indexing, cache);
   return ds;
 }

 #ifndef ENABLE_ANDROID
-// Function to create a ManifestNode.
-std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const std::string &usage,
-                                       const std::shared_ptr<SamplerObj> &sampler,
-                                       const std::map<std::string, int32_t> &class_indexing, bool decode,
-                                       const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
+// Function to create a ManifestDataset.
+std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const std::string &usage,
+                                          const std::shared_ptr<SamplerObj> &sampler,
+                                          const std::map<std::string, int32_t> &class_indexing, bool decode,
+                                          const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<ManifestDataset>(dataset_file, usage, sampler, class_indexing, decode, cache);
   return ds;
 }

-// Function to create a MindDataNode.
-std::shared_ptr<MindDataNode> MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list,
-                                       const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
-                                       int64_t num_padded) {
-  auto ds = std::make_shared<MindDataNode>(dataset_file, columns_list, sampler, padded_sample, num_padded);
+// Function to create a MindDataDataset.
+std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list,
+                                          const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
+                                          int64_t num_padded) {
+  auto ds = std::make_shared<MindDataDataset>(dataset_file, columns_list, sampler, padded_sample, num_padded);
   return ds;
 }

-// Function to create a MindDataNode.
-std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_files,
-                                       const std::vector<std::string> &columns_list,
-                                       const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
-                                       int64_t num_padded) {
-  auto ds = std::make_shared<MindDataNode>(dataset_files, columns_list, sampler, padded_sample, num_padded);
+// Function to create a MindDataDataset.
+std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
+                                          const std::vector<std::string> &columns_list,
+                                          const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
+                                          int64_t num_padded) {
+  auto ds = std::make_shared<MindDataDataset>(dataset_files, columns_list, sampler, padded_sample, num_padded);
   return ds;
 }
 #endif

-// Function to create a MnistNode.
-std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage,
-                                 const std::shared_ptr<SamplerObj> &sampler,
-                                 const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);
+// Function to create a MnistDataset.
+std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
+                                    const std::shared_ptr<SamplerObj> &sampler,
+                                    const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<MnistDataset>(dataset_dir, usage, sampler, cache);
   return ds;
 }

 // Function to overload "+" operator to concat two datasets
-std::shared_ptr<ConcatNode> operator+(const std::shared_ptr<Dataset> &datasets1,
-                                      const std::shared_ptr<Dataset> &datasets2) {
-  std::shared_ptr<ConcatNode> ds = std::make_shared<ConcatNode>(std::vector({datasets2, datasets1}));
-  return ds;
+std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
+                                         const std::shared_ptr<Dataset> &datasets2) {
+  return std::make_shared<ConcatDataset>(std::vector({datasets2, datasets1}));
 }

-// Function to create a TextFileNode.
-std::shared_ptr<TextFileNode> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples,
-                                       ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
-                                       const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
+// Function to create a TextFileDataset.
+std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples,
+                                          ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
+                                          const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
   return ds;
 }

 #ifndef ENABLE_ANDROID
-// Function to create a VOCNode.
-std::shared_ptr<VOCNode> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
-                             const std::map<std::string, int32_t> &class_indexing, bool decode,
-                             const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
-  auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);
+// Function to create a VOCDataset.
+std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
+                                const std::map<std::string, int32_t> &class_indexing, bool decode,
+                                const std::shared_ptr<SamplerObj> &sampler,
+                                const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<VOCDataset>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);
   return ds;
 }
 #endif

 // Function to create a ZipNode.
-std::shared_ptr<ZipNode> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
-  auto ds = std::make_shared<ZipNode>(datasets);
+std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
+  auto ds = std::make_shared<ZipDataset>(datasets);
   return ds;
 }
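// Illustrative use of the refactored factories (file paths are placeholders,
// and reliance on default trailing arguments is an assumption): callers now
// receive client-side *Dataset handles instead of raw IR *Node objects.
auto train = Mnist("/data/mnist", "train", SequentialSampler(0, 0));
auto test = Mnist("/data/mnist", "test", SequentialSampler(0, 0));
// operator+ likewise returns a ConcatDataset facade rather than a ConcatNode.
std::shared_ptr<ConcatDataset> both = train + test;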
@@ -508,170 +476,112 @@ std::shared_ptr<ZipNode> Zip(const std::vector<std::shared_ptr<Dataset>> &datase
 // (In alphabetical order)

 // Function to create a Batch dataset
-std::shared_ptr<BatchNode> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
+BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder) {
   // Default values
   std::vector<std::string> cols_to_map = {};
   std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map;
   bool pad = false;
-  auto ds = std::make_shared<BatchNode>(shared_from_this(), batch_size, drop_remainder, pad, cols_to_map, pad_map);
-  return ds;
+  auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder, pad, cols_to_map, pad_map);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

 #ifndef ENABLE_ANDROID
 // Function to create a BucketBatchByLength dataset
-std::shared_ptr<BucketBatchByLengthNode> Dataset::BucketBatchByLength(
-  const std::vector<std::string> &column_names, const std::vector<int32_t> &bucket_boundaries,
-  const std::vector<int32_t> &bucket_batch_sizes, std::function<TensorRow(TensorRow)> element_length_function,
+BucketBatchByLengthDataset::BucketBatchByLengthDataset(
+  std::shared_ptr<Dataset> input, const std::vector<std::string> &column_names,
+  const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
+  std::function<TensorRow(TensorRow)> element_length_function,
   const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
   bool drop_remainder) {
-  auto ds = std::make_shared<BucketBatchByLengthNode>(shared_from_this(), column_names, bucket_boundaries,
+  auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries,
                                                       bucket_batch_sizes, element_length_function, pad_info,
                                                       pad_to_bucket_boundary, drop_remainder);
-  return ds;
-}
-
-// Function to create a SentencePieceVocab from dataset
-std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
-  const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
-  SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
-  auto vocab = std::make_shared<SentencePieceVocab>();
-  auto ds = std::make_shared<BuildSentenceVocabNode>(shared_from_this(), vocab, col_names, vocab_size,
-                                                     character_coverage, model_type, params);
-  // Run tree here to start building vocab
-  std::shared_ptr<Iterator> iter = ds->CreateIterator();
-  if (iter == nullptr) {
-    MS_LOG(ERROR) << "Fail to run iterator in BuildSentencePieceVocab.";
-    return nullptr;
-  }
-  // Finish building vocab by triggering GetNextRow
-  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
-  if (!iter->GetNextRow(&row)) {
-    return nullptr;
-  }
-  return vocab;
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }
-
-// Function to create a Vocab from dataset
-std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &columns,
-                                           const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
-                                           const std::vector<std::string> &special_tokens, bool special_first) {
-  auto vocab = std::make_shared<Vocab>();
-  auto ds = std::make_shared<BuildVocabNode>(shared_from_this(), vocab, columns, freq_range, top_k, special_tokens,
-                                             special_first);
-  // Run tree here to starting building vocab
-  std::shared_ptr<Iterator> iter = ds->CreateIterator();
-  if (iter == nullptr) {
-    MS_LOG(ERROR) << "Fail to run iterator in BuildVocab.";
-    return nullptr;
-  }
-  // Finish building vocab by triggering GetNextRow
-  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
-  if (!iter->GetNextRow(&row)) {
-    return nullptr;
-  }
-  return vocab;
-}
 #endif

-// Function to create a Concat dataset
-std::shared_ptr<ConcatNode> Dataset::Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
-  auto ds = std::make_shared<ConcatNode>(datasets);
-  ds->children.push_back(shared_from_this());
-  return ds;
+ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
+  std::vector<std::shared_ptr<DatasetNode>> all_datasets;
+  (void)std::transform(
+    datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
+    [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });
+  auto ds = std::make_shared<ConcatNode>(all_datasets);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

-// Function to create a Map dataset.
-std::shared_ptr<MapNode> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
-                                      std::vector<std::string> input_columns, std::vector<std::string> output_columns,
-                                      const std::vector<std::string> &project_columns,
-                                      const std::shared_ptr<DatasetCache> &cache) {
+MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
+                       std::vector<std::string> input_columns, std::vector<std::string> output_columns,
+                       const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache) {
   auto ds =
-    std::make_shared<MapNode>(shared_from_this(), operations, input_columns, output_columns, project_columns, cache);
-  return ds;
+    std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

-// Function to create a ProjectNode.
-std::shared_ptr<ProjectNode> Dataset::Project(const std::vector<std::string> &columns) {
-  auto ds = std::make_shared<ProjectNode>(shared_from_this(), columns);
-  return ds;
+ProjectDataset::ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::string> &columns) {
+  auto ds = std::make_shared<ProjectNode>(input->IRNode(), columns);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

-// Function to create a RenameNode.
-std::shared_ptr<RenameNode> Dataset::Rename(const std::vector<std::string> &input_columns,
-                                            const std::vector<std::string> &output_columns) {
-  auto ds = std::make_shared<RenameNode>(shared_from_this(), input_columns, output_columns);
-  return ds;
+RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::string> &input_columns,
+                             const std::vector<std::string> &output_columns) {
+  auto ds = std::make_shared<RenameNode>(input->IRNode(), input_columns, output_columns);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

-// Function to create Repeat dataset.
-std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
+RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
   // Workaround for repeat == 1, do not inject repeat.
   if (count == 1) {
-    return shared_from_this();
+    ir_node_ = input->IRNode();
+    return;
   }
-  auto ds = std::make_shared<RepeatNode>(shared_from_this(), count);
-  return ds;
+  auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

-// Function to create a ShuffleOp
-std::shared_ptr<ShuffleNode> Dataset::Shuffle(int32_t buffer_size) {
+ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size) {
   // Pass in reshuffle_each_epoch with true
-  auto ds = std::make_shared<ShuffleNode>(shared_from_this(), buffer_size, true);
-  return ds;
+  auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

-// Function to create a SkipNode.
-std::shared_ptr<SkipNode> Dataset::Skip(int32_t count) {
-  auto ds = std::make_shared<SkipNode>(shared_from_this(), count);
-  return ds;
+SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
+  auto ds = std::make_shared<SkipNode>(input->IRNode(), count);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

-// Function to create a TakeNode.
-std::shared_ptr<Dataset> Dataset::Take(int32_t count) {
+TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
   // If count is greater than the number of element in dataset or equal to -1,
   // all the element in dataset will be taken
   if (count == -1) {
-    return shared_from_this();
+    ir_node_ = input->IRNode();
+    return;
   }
-  auto ds = std::make_shared<TakeNode>(shared_from_this(), count);
-  return ds;
+  auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }

-// Function to create a Zip dataset
-std::shared_ptr<ZipNode> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
-  // Default values
-  auto ds = std::make_shared<ZipNode>(datasets);
-  ds->children.push_back(shared_from_this());
-  return ds;
-}
-
-Status Dataset::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
-  if (cache_ != nullptr) {
-    RETURN_IF_NOT_OK(cache_->Build());
-    std::shared_ptr<DatasetOp> cache_op;
-    RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
-    node_ops->push_back(cache_op);
-  }
-  return Status::OK();
+ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
+  std::vector<std::shared_ptr<DatasetNode>> all_datasets;
+  (void)std::transform(
+    datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
+    [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });
+  auto ds = std::make_shared<ZipNode>(all_datasets);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }
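// Hedged sketch (the corresponding header change is not shown in this diff):
// the fluent methods on Dataset presumably now forward to these facade
// constructors, so chained calls compose IR trees through ir_node_, e.g.:
std::shared_ptr<RepeatDataset> Dataset::Repeat(int32_t count) {
  return std::make_shared<RepeatDataset>(shared_from_this(), count);
}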
 int64_t Dataset::GetBatchSize() {
@@ -685,7 +595,7 @@ int64_t Dataset::GetBatchSize() {
     return -1;
   }
   if (!tree_getters_->isInitialized()) {
-    rc = tree_getters_->Init(ds);
+    rc = tree_getters_->Init(ds->IRNode());
    if (rc.IsError()) {
      MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
      return -1;
@@ -706,7 +616,7 @@ int64_t Dataset::GetRepeatCount() {
     return -1;
   }
   if (!tree_getters_->isInitialized()) {
-    rc = tree_getters_->Init(ds);
+    rc = tree_getters_->Init(ds->IRNode());
    if (rc.IsError()) {
      MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
      return -1;
@@ -715,7 +625,77 @@ int64_t Dataset::GetRepeatCount() {
   rc = tree_getters_->GetRepeatCount(&repeat_count);
   return rc.IsError() ? 0 : repeat_count;
 }

+std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
+  if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) {
+    return nullptr;
+  }
+  return shared_from_this();
+}
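// One-line usage sketch for the new setter (the worker count is illustrative);
// it returns the same handle for chaining, or nullptr when no IR node exists:
auto tuned = ds->SetNumWorkers(4);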
+#ifndef ENABLE_ANDROID
+std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
+  const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
+  SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
+  auto vocab = std::make_shared<SentencePieceVocab>();
+  auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
+                                                     model_type, params);
+
+  std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
+  Status rc = runtime_context->Init();
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
+    return nullptr;
+  }
+
+  auto consumer = std::make_unique<BuildVocabConsumer>();
+  BuildVocabConsumer *bv_consumer = consumer.get();
+  rc = consumer->Init(ds);
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
+    return nullptr;
+  }
+  runtime_context->AssignConsumer(std::move(consumer));
+
+  // Run tree here to starting building vocab
+  rc = bv_consumer->Start();
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
+    return nullptr;
+  }
+  return vocab;
+}
+
+std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &columns,
+                                           const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
+                                           const std::vector<std::string> &special_tokens, bool special_first) {
+  auto vocab = std::make_shared<Vocab>();
+  auto ds =
+    std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
+
+  std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
+  Status rc = runtime_context->Init();
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
+    return nullptr;
+  }
+
+  auto consumer = std::make_unique<BuildVocabConsumer>();
+  BuildVocabConsumer *bv_consumer = consumer.get();
+  rc = consumer->Init(ds);
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "BuildVocab: Failed to init. Error status: " << rc;
+    return nullptr;
+  }
+  runtime_context->AssignConsumer(std::move(consumer));
+
+  // Run tree here to starting building vocab
+  rc = bv_consumer->Start();
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "BuildVocab: Failed to start. Error status: " << rc;
+    return nullptr;
+  }
+  return vocab;
+}
+#endif
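// Illustrative call site for the reworked BuildVocab (the handle `text_ds`,
// column name, frequency range, and special tokens are placeholders): the
// vocab is now built by a dedicated BuildVocabConsumer instead of a
// throwaway iterator.
std::shared_ptr<Vocab> vocab = text_ds->BuildVocab({"text"}, {0, 9999}, 100, {"<pad>", "<unk>"}, true);
if (vocab == nullptr) {
  MS_LOG(ERROR) << "BuildVocab pipeline failed.";
}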
 SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}

 // SchemaObj init function
@@ -1046,6 +1026,136 @@ std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t me
 }
 #endif
+AlbumDataset::AlbumDataset(const std::string &dataset_dir, const std::string &data_schema,
+                           const std::vector<std::string> &column_names, bool decode,
+                           const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &usage,
+                             const std::shared_ptr<SamplerObj> &sampler, bool decode,
+                             const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string &usage,
+                               const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::string &usage,
+                                 const std::shared_ptr<SamplerObj> &sampler,
+                                 const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+CLUEDataset::CLUEDataset(const std::vector<std::string> &dataset_files, const std::string &task,
+                         const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
+                         int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CLUENode>(dataset_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
+                         const bool &decode, const std::shared_ptr<SamplerObj> &sampler,
+                         const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+CSVDataset::CSVDataset(const std::vector<std::string> &dataset_files, char field_delim,
+                       const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
+                       const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
+                       int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CSVNode>(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle,
+                                      num_shards, shard_id, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+ImageFolderDataset::ImageFolderDataset(const std::string &dataset_dir, bool decode,
+                                       const std::shared_ptr<SamplerObj> &sampler,
+                                       const std::set<std::string> &extensions,
+                                       const std::map<std::string, int32_t> &class_indexing,
+                                       const std::shared_ptr<DatasetCache> &cache) {
+  // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
+  bool recursive = false;
+  // Create logical representation of ImageFolderDataset.
+  auto ds =
+    std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+#ifndef ENABLE_ANDROID
+ManifestDataset::ManifestDataset(const std::string &dataset_file, const std::string &usage,
+                                 const std::shared_ptr<SamplerObj> &sampler,
+                                 const std::map<std::string, int32_t> &class_indexing, bool decode,
+                                 const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+MindDataDataset::MindDataDataset(const std::string &dataset_file, const std::vector<std::string> &columns_list,
+                                 const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
+                                 int64_t num_padded) {
+  auto ds = std::make_shared<MindDataNode>(dataset_file, columns_list, sampler, padded_sample, num_padded);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+MindDataDataset::MindDataDataset(const std::vector<std::string> &dataset_files,
+                                 const std::vector<std::string> &columns_list,
+                                 const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
+                                 int64_t num_padded) {
+  auto ds = std::make_shared<MindDataNode>(dataset_files, columns_list, sampler, padded_sample, num_padded);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+#endif
+
+MnistDataset::MnistDataset(const std::string &dataset_dir, const std::string &usage,
+                           const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+TextFileDataset::TextFileDataset(const std::vector<std::string> &dataset_files, int64_t num_samples,
+                                 ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
+                                 const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+#ifndef ENABLE_ANDROID
+VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
+                       const std::map<std::string, int32_t> &class_indexing, bool decode,
+                       const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+#endif
+
+RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
+                                     const std::vector<std::string> &columns_list,
+                                     const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) {
+  auto ds =
+    std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler), cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::string schema_path,
+                                     const std::vector<std::string> &columns_list,
+                                     const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) {
+  auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema_path), std::move(columns_list),
+                                         std::move(sampler), cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+#ifndef ENABLE_ANDROID
+TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files, std::string schema,
+                                 const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
+                                 int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
+                                 std::shared_ptr<DatasetCache> cache) {
+  auto ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards,
+                                           shard_id, shard_equal_rows, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
+TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
+                                 const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
+                                 int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
+                                 std::shared_ptr<DatasetCache> cache) {
+  auto ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards,
+                                           shard_id, shard_equal_rows, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+#endif
 std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
   if (shuffle) {
     if (num_shards > 1) {
@@ -1062,7 +1172,6 @@ std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int
   // If shuffle disabled, sharding disabled, use sequential sampler
   return SequentialSampler(0, num_samples);
 }
 }  // namespace api
 }  // namespace dataset
 }  // namespace mindspore
@@ -53,7 +53,7 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
   RETURN_IF_NOT_OK(runtime_context->Init());
   auto consumer = std::make_unique<IteratorConsumer>();
   consumer_ = consumer.get();
-  RETURN_IF_NOT_OK(consumer->Init(ds));
+  RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
   runtime_context->AssignConsumer(std::move(consumer));
   return Status::OK();
 }
@@ -11,7 +11,7 @@ endif ()
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

-add_library(engine OBJECT
+set(SRC_FILES_LIST
     execution_tree.cc
     data_buffer.cc
     data_schema.cc
@@ -20,10 +20,19 @@ add_library(engine OBJECT
     runtime_context.cc
     consumers/tree_consumer.cc
 )
+
+if (ENABLE_PYTHON)
+  set(SRC_FILES_LIST
+      ${SRC_FILES_LIST}
+      python_runtime_context.cc
+      consumers/python_tree_consumer.cc
+  )
+endif ()
+
+add_library(engine OBJECT ${SRC_FILES_LIST})
+
 if (ENABLE_PYTHON)
-  target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
-endif()
+  target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
+endif ()

 add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "minddata/dataset/engine/consumers/python_tree_consumer.h"
+
+namespace mindspore::dataset {
+Status PythonIteratorConsumer::GetNextAsList(py::list *out) {
+  std::vector<TensorPtr> row;
+  {
+    py::gil_scoped_release gil_release;
+    RETURN_IF_NOT_OK(GetNextAsVector(&row));
+  }
+  for (auto el : row) {
+    (*out).append(el);
+  }
+  return Status::OK();
+}
+
+Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) {
+  std::unordered_map<std::string, TensorPtr> row;
+  {
+    py::gil_scoped_release gil_release;
+    RETURN_IF_NOT_OK(GetNextAsMap(&row));
+  }
+  for (auto el : row) {
+    (*out)[common::SafeCStr(el.first)] = el.second;
+  }
+  return Status::OK();
+}
+}  // namespace mindspore::dataset
@@ -26,24 +26,21 @@
 namespace mindspore::dataset {

 /// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict
-class PythonIterator : public IteratorConsumer {
-  /// Constructor
+class PythonIteratorConsumer : public IteratorConsumer {
+ public:
+  /// Constructor which will call the base class default constructor.
   /// \param num_epochs number of epochs. Default to -1 (infinite epochs).
-  explicit PythonIterator(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {}
+  explicit PythonIteratorConsumer(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {}

-  /// Get the next row as a python dict
-  /// \param[out] output python dict
-  /// \return Status error code
-  Status GetNextAsMap(py::dict *output) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  /// Get the next row as a python dict
-  /// \param[out] output python dict
-  /// \return Status error code
-  Status GetNextAsList(py::list *output) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
+  /// Returns the next row in a vector format
+  /// \param[out] out std::vector of Tensors
+  /// \return Status error code
+  Status GetNextAsList(py::list *out);
+
+  /// Returns the next row in as a map
+  /// \param[out] out std::map of string to Tensor
+  /// \return Status error code
+  Status GetNextAsDict(py::dict *out);
 };
 }  // namespace mindspore::dataset
 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_
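// Hedged sketch of driving the renamed consumer from C++ (not in this PR; the
// dataset handle `ds` is assumed to exist and a Python interpreter/GIL must be
// active for the pybind11 types):
auto consumer = std::make_unique<PythonIteratorConsumer>(/*num_epochs=*/1);
Status rc = consumer->Init(ds->IRNode());  // inherited IteratorConsumer::Init
py::list row;
if (!rc.IsError()) {
  rc = consumer->GetNextAsList(&row);  // releases the GIL while fetching
}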
| @@ -34,10 +34,11 @@ namespace mindspore::dataset { | |||||
| // TreeConsumer | // TreeConsumer | ||||
| TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | ||||
| Status TreeConsumer::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } | |||||
| Status TreeConsumer::Init(std::shared_ptr<api::DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } | |||||
| Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); } | |||||
| // IteratorConsumer | // IteratorConsumer | ||||
| Status IteratorConsumer::Init(std::shared_ptr<api::Dataset> d) { | |||||
| Status IteratorConsumer::Init(std::shared_ptr<api::DatasetNode> d) { | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | ||||
| } | } | ||||
| @@ -73,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> | |||||
| } | } | ||||
| // ToDevice | // ToDevice | ||||
| Status ToDevice::Init(std::shared_ptr<api::Dataset> d) { | |||||
| Status ToDevice::Init(std::shared_ptr<api::DatasetNode> d) { | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | ||||
| } | } | ||||
| @@ -384,7 +385,7 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal | |||||
| tree_adapter_ = std::make_unique<TreeAdapter>(); | tree_adapter_ = std::make_unique<TreeAdapter>(); | ||||
| } | } | ||||
| Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { | |||||
| Status TreeGetters::Init(std::shared_ptr<api::DatasetNode> d) { | |||||
| Status s = tree_adapter_->BuildAndPrepare(std::move(d)); | Status s = tree_adapter_->BuildAndPrepare(std::move(d)); | ||||
| if (!s.IsError()) { | if (!s.IsError()) { | ||||
| init_flag_ = true; | init_flag_ = true; | ||||
| @@ -463,4 +464,15 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) { | |||||
| RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); | RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status BuildVocabConsumer::Init(std::shared_ptr<api::DatasetNode> d) { | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), 1); | |||||
| } | |||||
| Status BuildVocabConsumer::Start() { | |||||
| // Getting one row would trigger building the vocab | |||||
| TensorRow row; | |||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | |||||
| // The returned row will be an EOE, which is an empty row | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE."); | |||||
| return Status::OK(); | |||||
| } | |||||
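A hedged usage sketch of the new consumer, written as it would appear inside a Status-returning function; build_vocab_ir is a hypothetical IR tree rooted at a BuildVocabNode.

  BuildVocabConsumer vocab_consumer;
  RETURN_IF_NOT_OK(vocab_consumer.Init(build_vocab_ir));
  RETURN_IF_NOT_OK(vocab_consumer.Start());  // blocks; returns once the vocab is fully built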
| } // namespace mindspore::dataset | } // namespace mindspore::dataset | ||||
| @@ -22,14 +22,16 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/tree_adapter.h" | #include "minddata/dataset/engine/tree_adapter.h" | ||||
| #include "minddata/dataset/text/vocab.h" | |||||
| namespace mindspore::dataset { | namespace mindspore::dataset { | ||||
| // Forward declare | // Forward declare | ||||
| class TreeAdapter; | class TreeAdapter; | ||||
| namespace api { | namespace api { | ||||
| class Dataset; | |||||
| class DatasetNode; | |||||
| } | } | ||||
| /// A base class for tree consumers that fetch rows from the tree pipeline | /// A base class for tree consumers that fetch rows from the tree pipeline | ||||
| @@ -40,7 +42,9 @@ class TreeConsumer { | |||||
| /// Initializes the consumer; this involves constructing and preparing the tree. | /// Initializes the consumer; this involves constructing and preparing the tree. | ||||
| /// \param d The dataset node that represents the root of the IR tree. | /// \param d The dataset node that represents the root of the IR tree. | ||||
| /// \return Status error code. | /// \return Status error code. | ||||
| virtual Status Init(std::shared_ptr<api::Dataset> d); | |||||
| virtual Status Init(std::shared_ptr<api::DatasetNode> d); | |||||
| Status Terminate(); | |||||
| protected: | protected: | ||||
| /// The class owns the tree_adapter that handles execution tree operations. | /// The class owns the tree_adapter that handles execution tree operations. | ||||
| @@ -57,7 +61,7 @@ class IteratorConsumer : public TreeConsumer { | |||||
| /// \param num_epochs number of epochs. Defaults to -1 (infinite epochs). | /// \param num_epochs number of epochs. Defaults to -1 (infinite epochs). | ||||
| explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} | explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} | ||||
| Status Init(std::shared_ptr<api::Dataset> d) override; | |||||
| Status Init(std::shared_ptr<api::DatasetNode> d) override; | |||||
| /// Returns the next row in a vector format | /// Returns the next row in a vector format | ||||
| /// \param[out] out std::vector of Tensors | /// \param[out] out std::vector of Tensors | ||||
| @@ -126,10 +130,10 @@ class SaveToDisk : public TreeConsumer { | |||||
| /// Consumer that iterates over the dataset and sends it to a device | /// Consumer that iterates over the dataset and sends it to a device | ||||
| class ToDevice : public TreeConsumer { | class ToDevice : public TreeConsumer { | ||||
| public: | public: | ||||
| ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1) | |||||
| : TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | |||||
| explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1) | |||||
| : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | |||||
| Status Init(std::shared_ptr<api::Dataset> d) override; | |||||
| Status Init(std::shared_ptr<api::DatasetNode> d) override; | |||||
| /// Send the data to device | /// Send the data to device | ||||
| /// \return Status error code | /// \return Status error code | ||||
| @@ -158,7 +162,7 @@ class ToDevice : public TreeConsumer { | |||||
| class TreeGetters : public TreeConsumer { | class TreeGetters : public TreeConsumer { | ||||
| public: | public: | ||||
| TreeGetters(); | TreeGetters(); | ||||
| Status Init(std::shared_ptr<api::Dataset> d) override; | |||||
| Status Init(std::shared_ptr<api::DatasetNode> d) override; | |||||
| Status GetDatasetSize(int64_t *size); | Status GetDatasetSize(int64_t *size); | ||||
| Status GetOutputTypes(std::vector<DataType> *types); | Status GetOutputTypes(std::vector<DataType> *types); | ||||
| Status GetOutputShapes(std::vector<TensorShape> *shapes); | Status GetOutputShapes(std::vector<TensorShape> *shapes); | ||||
| @@ -176,5 +180,23 @@ class TreeGetters : public TreeConsumer { | |||||
| bool row_flag_; // indicates whether the first row has been stored in row_ | bool row_flag_; // indicates whether the first row has been stored in row_ | ||||
| }; | }; | ||||
| class BuildVocabConsumer : public TreeConsumer { | |||||
| public: | |||||
| /// BuildVocabConsumer constructor, which calls the base class default constructor. | |||||
| BuildVocabConsumer() = default; | |||||
| Status Init(std::shared_ptr<api::DatasetNode> d) override; | |||||
| /// Start consuming the dataset to build the vocab. This is a blocking method (i.e., it returns only after | |||||
| /// the vocab has been fully built) | |||||
| /// \return Status error code | |||||
| Status Start(); | |||||
| protected: | |||||
| /// Method to return the name of the consumer | |||||
| /// \return string | |||||
| std::string Name() override { return "BuildVocab"; } | |||||
| }; | |||||
| } // namespace mindspore::dataset | } // namespace mindspore::dataset | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_ | ||||
| @@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE | |||||
| add_subdirectory(source) | add_subdirectory(source) | ||||
| set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | ||||
| dataset_node.cc | |||||
| batch_node.cc | batch_node.cc | ||||
| bucket_batch_by_length_node.cc | bucket_batch_by_length_node.cc | ||||
| build_sentence_piece_vocab_node.cc | build_sentence_piece_vocab_node.cc | ||||
| @@ -28,7 +28,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| BatchNode::BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool drop_remainder, bool pad, | |||||
| BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad, | |||||
| std::vector<std::string> cols_to_map, | std::vector<std::string> cols_to_map, | ||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map) | std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map) | ||||
| : batch_size_(batch_size), | : batch_size_(batch_size), | ||||
| @@ -23,16 +23,16 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class BatchNode : public Dataset { | |||||
| class BatchNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool drop_remainder, bool pad, | |||||
| BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad, | |||||
| std::vector<std::string> cols_to_map, | std::vector<std::string> cols_to_map, | ||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map); | std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map); | ||||
| @@ -29,7 +29,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| BucketBatchByLengthNode::BucketBatchByLengthNode( | BucketBatchByLengthNode::BucketBatchByLengthNode( | ||||
| std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names, | |||||
| std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | |||||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | ||||
| std::function<TensorRow(TensorRow)> element_length_function, | std::function<TensorRow(TensorRow)> element_length_function, | ||||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, | const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, | ||||
| @@ -23,15 +23,15 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class BucketBatchByLengthNode : public Dataset { | |||||
| class BucketBatchByLengthNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names, | |||||
| BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | |||||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | ||||
| std::function<TensorRow(TensorRow)> element_length_function = nullptr, | std::function<TensorRow(TensorRow)> element_length_function = nullptr, | ||||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, | const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, | ||||
| @@ -28,7 +28,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<Dataset> child, | |||||
| BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child, | |||||
| std::shared_ptr<SentencePieceVocab> vocab, | std::shared_ptr<SentencePieceVocab> vocab, | ||||
| const std::vector<std::string> &col_names, uint32_t vocab_size, | const std::vector<std::string> &col_names, uint32_t vocab_size, | ||||
| float character_coverage, SentencePieceModel model_type, | float character_coverage, SentencePieceModel model_type, | ||||
| @@ -29,10 +29,10 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class BuildSentenceVocabNode : public Dataset { | |||||
| class BuildSentenceVocabNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| BuildSentenceVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<SentencePieceVocab> vocab, | |||||
| BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SentencePieceVocab> vocab, | |||||
| const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage, | const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage, | ||||
| SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms); | SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms); | ||||
| @@ -28,7 +28,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, | |||||
| BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab, | |||||
| const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, | const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, | ||||
| int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first) | int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first) | ||||
| : vocab_(vocab), | : vocab_(vocab), | ||||
| @@ -22,17 +22,17 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class BuildVocabNode : public Dataset { | |||||
| class BuildVocabNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns, | |||||
| const std::pair<int64_t, int64_t> &freq_range, int64_t top_k, | |||||
| BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab, | |||||
| const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, int64_t top_k, | |||||
| const std::vector<std::string> &special_tokens, bool special_first); | const std::vector<std::string> &special_tokens, bool special_first); | ||||
| /// \brief Destructor | /// \brief Destructor | ||||
| @@ -27,18 +27,16 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| // Constructor for ConcatNode | // Constructor for ConcatNode | ||||
| ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) { | |||||
| this->children = datasets_; | |||||
| } | |||||
| ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; } | |||||
| Status ConcatNode::ValidateParams() { | Status ConcatNode::ValidateParams() { | ||||
| if (datasets_.empty()) { | |||||
| if (children.size() < 2) { | |||||
| std::string err_msg = "ConcatNode: concatenated datasets are not specified."; | std::string err_msg = "ConcatNode: concatenated datasets are not specified."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { | |||||
| if (find(children.begin(), children.end(), nullptr) != children.end()) { | |||||
| std::string err_msg = "ConcatNode: concatenated datasets should not be null."; | std::string err_msg = "ConcatNode: concatenated datasets should not be null."; | ||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
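With the private datasets_ member gone, validation now runs directly on children. A sketch of the new construction path (ds1_ir and ds2_ir stand for hypothetical DatasetNode roots):

  std::vector<std::shared_ptr<api::DatasetNode>> inputs = {ds1_ir, ds2_ir};
  auto concat = std::make_shared<api::ConcatNode>(inputs);
  Status rc = concat->ValidateParams();  // rejects fewer than two children, or any null child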
| @@ -21,16 +21,16 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class ConcatNode : public Dataset { | |||||
| class ConcatNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets); | |||||
| explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ConcatNode() = default; | ~ConcatNode() = default; | ||||
| @@ -42,9 +42,6 @@ class ConcatNode : public Dataset { | |||||
| /// \brief Parameters validation | /// \brief Parameters validation | ||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| private: | |||||
| std::vector<std::shared_ptr<Dataset>> datasets_; | |||||
| }; | }; | ||||
| } // namespace api | } // namespace api | ||||
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||||
| if (cache_ != nullptr) { | |||||
| RETURN_IF_NOT_OK(cache_->Build()); | |||||
| std::shared_ptr<DatasetOp> cache_op; | |||||
| RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); | |||||
| node_ops->push_back(cache_op); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
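Illustrative only: how a derived node's Build() might splice the cache in through this helper. MySourceNode and MySourceOp are placeholders, not part of the patch; the ops vector is ordered top-down, so the cache op lands above the source op.

  std::vector<std::shared_ptr<DatasetOp>> MySourceNode::Build() {
    std::vector<std::shared_ptr<DatasetOp>> node_ops;
    RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));  // appends a cache op only when cache_ is set
    node_ops.push_back(std::make_shared<MySourceOp>(num_workers_, connector_que_size_));
    return node_ops;
  }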
| // Constructor to initialize the cache | |||||
| DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; } | |||||
| std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) { | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| #ifndef ENABLE_ANDROID | |||||
| int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF); | |||||
| if (cpu_count < 0 || cpu_count > INT32_MAX) { | |||||
| MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count; | |||||
| return nullptr; | |||||
| } | |||||
| if (num_workers < 1 || num_workers > cpu_count) { | |||||
| MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count; | |||||
| return nullptr; | |||||
| } | |||||
| #endif | |||||
| #endif | |||||
| num_workers_ = num_workers; | |||||
| return shared_from_this(); | |||||
| } | |||||
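A usage sketch, assuming some_node is an already-constructed IR node; the nullptr return is the only out-of-range signal, so callers must check it:

  std::shared_ptr<api::DatasetNode> tuned = some_node->SetNumWorkers(4);
  if (tuned == nullptr) {
    // requested worker count fell outside [1, cpu_count]
  }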
| DatasetNode::DatasetNode() { | |||||
| // Fetch some default values from the config manager | |||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||||
| num_workers_ = cfg->num_parallel_workers(); | |||||
| rows_per_buffer_ = cfg->rows_per_buffer(); | |||||
| connector_que_size_ = cfg->op_connector_size(); | |||||
| worker_connector_size_ = cfg->worker_connector_size(); | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,126 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_set> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| class Dataset; | |||||
| class SamplerObj; | |||||
| #define RETURN_EMPTY_IF_ERROR(_s) \ | |||||
| do { \ | |||||
| Status __rc = (_s); \ | |||||
| if (__rc.IsError()) { \ | |||||
| MS_LOG(ERROR) << __rc; \ | |||||
| return {}; \ | |||||
| } \ | |||||
| } while (false) | |||||
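The macro's contract, shown on a hypothetical value-returning function: a failing Status is logged and the enclosing function returns an empty, value-initialized result. This is what lets Build(), which returns a vector rather than a Status, bail out cleanly.

  std::vector<int> LoadOrEmpty(const Status &s) {  // illustrative only
    RETURN_EMPTY_IF_ERROR(s);  // on error: logs s, then executes `return {};`
    return {1, 2, 3};
  }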
| Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||||
| int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op); | |||||
| // Helper function to validate dataset files parameter | |||||
| Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files); | |||||
| // Helper function to validate dataset num_shards and shard_id parameters | |||||
| Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id); | |||||
| // Helper function to validate dataset sampler parameter | |||||
| Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler); | |||||
| Status ValidateStringValue(const std::string &dataset_name, const std::string &str, | |||||
| const std::unordered_set<std::string> &valid_strings); | |||||
| // Helper function to validate dataset input/output column parameter | |||||
| Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, | |||||
| const std::vector<std::string> &columns); | |||||
| // Helper function to validate dataset directory parameter | |||||
| Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir); | |||||
| /// \brief Function to create a sampler for a non-mappable dataset (to be used by a cache op later). | |||||
| /// \notes A non-mappable dataset does not directly support a sampler; it exposes sampling arguments (shuffle, | |||||
| /// num_samples, num_shards, shard_id) instead. Sampling IS supported when a cache sits somewhere above the | |||||
| /// dataset in the pipeline; without such a cache, the sampler is not used. | |||||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||||
| /// \param[in] shuffle If true, the indices are shuffled. | |||||
| /// \param[in] num_shards Number of shards to divide the dataset into. | |||||
| /// \param[in] shard_id Shard ID of the current shard within num_shards. | |||||
| /// \return Shared pointer to the current Sampler. | |||||
| std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id); | |||||
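A hedged call-site sketch: a non-mappable source node deriving its cache sampler from its own sharding arguments (the member names follow the source nodes touched by this patch):

  std::shared_ptr<SamplerObj> sampler =
      SelectSampler(num_samples_, shuffle_ == ShuffleMode::kGlobal, num_shards_, shard_id_);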
| class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| DatasetNode(); | |||||
| /// \brief Constructor that initializes the cache | |||||
| /// \param dataset_cache DatasetCache | |||||
| explicit DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache); | |||||
| /// \brief Destructor | |||||
| ~DatasetNode() = default; | |||||
| /// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0; | |||||
| /// \brief Pure virtual function for derived class to implement parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| virtual Status ValidateParams() = 0; | |||||
| const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children; } | |||||
| /// \brief Virtual function for derived classes to get the shard id of a specific node; the default | |||||
| /// implementation reports kNotImplementedYet | |||||
| /// \return Status Status::OK() if the shard id is obtained successfully | |||||
| virtual Status GetShardId(int32_t *shard_id) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| /// \brief Setter function for runtime number of workers | |||||
| /// \param[in] num_workers The number of threads in this operator | |||||
| /// \return Shared pointer to the original object | |||||
| std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers); | |||||
| protected: | |||||
| std::vector<std::shared_ptr<DatasetNode>> children; | |||||
| std::shared_ptr<DatasetNode> parent; | |||||
| std::shared_ptr<DatasetCache> cache_; | |||||
| Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); | |||||
| int32_t num_workers_; | |||||
| int32_t rows_per_buffer_; | |||||
| int32_t connector_que_size_; | |||||
| int32_t worker_connector_size_; | |||||
| }; | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ | |||||
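For orientation, a minimal hypothetical subclass showing the contract this header fixes: ValidateParams() guards construction and Build() emits the runtime ops top-down. TakeNode is a placeholder name here, and the TakeOp constructor signature is assumed, not taken from the patch.

  class TakeNode : public DatasetNode {
   public:
    TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : count_(count) {
      this->children.push_back(child);
    }
    ~TakeNode() = default;
    Status ValidateParams() override {
      if (count_ <= 0 && count_ != -1) {
        RETURN_STATUS_SYNTAX_ERROR("TakeNode: count must be -1 (take all rows) or positive.");
      }
      return Status::OK();
    }
    std::vector<std::shared_ptr<DatasetOp>> Build() override {
      std::vector<std::shared_ptr<DatasetOp>> node_ops;
      node_ops.push_back(std::make_shared<TakeOp>(count_, connector_que_size_));  // signature assumed
      return node_ops;
    }
   private:
    int32_t count_;
  };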
| @@ -28,14 +28,14 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations, | |||||
| MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | |||||
| std::vector<std::string> input_columns, std::vector<std::string> output_columns, | std::vector<std::string> input_columns, std::vector<std::string> output_columns, | ||||
| const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache) | const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache) | ||||
| : operations_(operations), | : operations_(operations), | ||||
| input_columns_(input_columns), | input_columns_(input_columns), | ||||
| output_columns_(output_columns), | output_columns_(output_columns), | ||||
| project_columns_(project_columns), | project_columns_(project_columns), | ||||
| Dataset(std::move(cache)) { | |||||
| DatasetNode(std::move(cache)) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -21,15 +21,15 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class MapNode : public Dataset { | |||||
| class MapNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations, | |||||
| MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | |||||
| std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, | std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, | ||||
| const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr); | const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr); | ||||
| @@ -28,7 +28,8 @@ namespace dataset { | |||||
| namespace api { | namespace api { | ||||
| // Constructor for ProjectNode | // Constructor for ProjectNode | ||||
| ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) { | |||||
| ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns) | |||||
| : columns_(columns) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -21,17 +21,17 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class ProjectNode : public Dataset { | |||||
| class ProjectNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns); | |||||
| explicit ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ProjectNode() = default; | ~ProjectNode() = default; | ||||
| @@ -27,7 +27,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| // Constructor for RenameNode | // Constructor for RenameNode | ||||
| RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns, | |||||
| RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, | |||||
| const std::vector<std::string> &output_columns) | const std::vector<std::string> &output_columns) | ||||
| : input_columns_(input_columns), output_columns_(output_columns) { | : input_columns_(input_columns), output_columns_(output_columns) { | ||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| @@ -21,17 +21,17 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class RenameNode : public Dataset { | |||||
| class RenameNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns, | |||||
| explicit RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns, | |||||
| const std::vector<std::string> &output_columns); | const std::vector<std::string> &output_columns); | ||||
| /// \brief Destructor | /// \brief Destructor | ||||
| @@ -27,7 +27,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) { | |||||
| RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -23,17 +23,17 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class RepeatNode : public Dataset { | |||||
| class RepeatNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count); | |||||
| explicit RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~RepeatNode() = default; | ~RepeatNode() = default; | ||||
| @@ -28,7 +28,7 @@ namespace dataset { | |||||
| namespace api { | namespace api { | ||||
| // Constructor for ShuffleNode | // Constructor for ShuffleNode | ||||
| ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch) | |||||
| ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch) | |||||
| : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) { | : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) { | ||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -23,16 +23,16 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class ShuffleNode : public Dataset { | |||||
| class ShuffleNode : public DatasetNode { | |||||
| public: | public: | ||||
| ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch); | |||||
| ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch); | |||||
| ~ShuffleNode() = default; | ~ShuffleNode() = default; | ||||
| @@ -28,7 +28,7 @@ namespace dataset { | |||||
| namespace api { | namespace api { | ||||
| // Constructor for SkipNode | // Constructor for SkipNode | ||||
| SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) { | |||||
| SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -21,16 +21,16 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class SkipNode : public Dataset { | |||||
| class SkipNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit SkipNode(std::shared_ptr<Dataset> child, int32_t count); | |||||
| explicit SkipNode(std::shared_ptr<DatasetNode> child, int32_t count); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~SkipNode() = default; | ~SkipNode() = default; | ||||
| @@ -32,7 +32,7 @@ namespace api { | |||||
| AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | ||||
| const std::vector<std::string> &column_names, bool decode, | const std::vector<std::string> &column_names, bool decode, | ||||
| const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) | const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | dataset_dir_(dataset_dir), | ||||
| schema_path_(data_schema), | schema_path_(data_schema), | ||||
| column_names_(column_names), | column_names_(column_names), | ||||
| @@ -21,13 +21,13 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class AlbumNode : public Dataset { | |||||
| class AlbumNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | ||||
| @@ -31,7 +31,7 @@ namespace api { | |||||
| CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | ||||
| const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | ||||
| const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) | const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | dataset_dir_(dataset_dir), | ||||
| usage_(usage), | usage_(usage), | ||||
| sampler_(sampler), | sampler_(sampler), | ||||
| @@ -23,13 +23,13 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class CelebANode : public Dataset { | |||||
| class CelebANode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ||||
| @@ -31,7 +31,7 @@ namespace api { | |||||
| // Constructor for Cifar100Node | // Constructor for Cifar100Node | ||||
| Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, | Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, | ||||
| std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache) | std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| Status Cifar100Node::ValidateParams() { | Status Cifar100Node::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_)); | ||||
| @@ -21,13 +21,13 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class Cifar100Node : public Dataset { | |||||
| class Cifar100Node : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | ||||
| @@ -31,7 +31,7 @@ namespace api { | |||||
| // Constructor for Cifar10Node | // Constructor for Cifar10Node | ||||
| Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| Status Cifar10Node::ValidateParams() { | Status Cifar10Node::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_)); | ||||
| @@ -21,13 +21,13 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class Cifar10Node : public Dataset { | |||||
| class Cifar10Node : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | ||||
| @@ -33,7 +33,7 @@ namespace api { | |||||
| // Constructor for CLUENode | // Constructor for CLUENode | ||||
| CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples, | CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples, | ||||
| ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_files_(clue_files), | dataset_files_(clue_files), | ||||
| task_(task), | task_(task), | ||||
| usage_(usage), | usage_(usage), | ||||
| @@ -21,14 +21,14 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| /// \class CLUENode | /// \class CLUENode | ||||
| /// \brief A Dataset derived class to represent CLUE dataset | /// \brief A Dataset derived class to represent CLUE dataset | ||||
| class CLUENode : public Dataset { | |||||
| class CLUENode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, | CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, | ||||
| @@ -30,7 +30,7 @@ namespace api { | |||||
| // Constructor for CocoNode | // Constructor for CocoNode | ||||
| CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | ||||
| const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | dataset_dir_(dataset_dir), | ||||
| annotation_file_(annotation_file), | annotation_file_(annotation_file), | ||||
| task_(task), | task_(task), | ||||
| @@ -21,12 +21,12 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class CocoNode : public Dataset { | |||||
| class CocoNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | ||||
| @@ -33,7 +33,7 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim, | |||||
| const std::vector<std::shared_ptr<CsvBase>> &column_defaults, | const std::vector<std::shared_ptr<CsvBase>> &column_defaults, | ||||
| const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle, | const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_files_(csv_files), | dataset_files_(csv_files), | ||||
| field_delim_(field_delim), | field_delim_(field_delim), | ||||
| column_defaults_(column_defaults), | column_defaults_(column_defaults), | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -47,7 +47,7 @@ class CsvRecord : public CsvBase { | |||||
| T value; | T value; | ||||
| }; | }; | ||||
| class CSVNode : public Dataset { | |||||
| class CSVNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CSVNode(const std::vector<std::string> &dataset_files, char field_delim, | CSVNode(const std::vector<std::string> &dataset_files, char field_delim, | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -31,7 +31,7 @@ namespace api { | |||||
| /// \class GeneratorNode | /// \class GeneratorNode | ||||
| /// \brief A Dataset derived class to represent GeneratorNode dataset | /// \brief A Dataset derived class to represent GeneratorNode dataset | ||||
| class GeneratorNode : public Dataset { | |||||
| class GeneratorNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | ||||
| @@ -40,7 +40,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar | |||||
| recursive_(recursive), | recursive_(recursive), | ||||
| class_indexing_(class_indexing), | class_indexing_(class_indexing), | ||||
| exts_(extensions), | exts_(extensions), | ||||
| Dataset(std::move(cache)) {} | |||||
| DatasetNode(std::move(cache)) {} | |||||
| Status ImageFolderNode::ValidateParams() { | Status ImageFolderNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_)); | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" | #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -33,7 +33,7 @@ namespace api { | |||||
| /// \class ImageFolderNode | /// \class ImageFolderNode | ||||
| /// \brief A Dataset derived class to represent ImageFolder dataset | /// \brief A Dataset derived class to represent ImageFolder dataset | ||||
| class ImageFolderNode : public Dataset { | |||||
| class ImageFolderNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive, | ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive, | ||||
| @@ -32,7 +32,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u | |||||
| const std::shared_ptr<SamplerObj> &sampler, | const std::shared_ptr<SamplerObj> &sampler, | ||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, | const std::map<std::string, int32_t> &class_indexing, bool decode, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_file_(dataset_file), | dataset_file_(dataset_file), | ||||
| usage_(usage), | usage_(usage), | ||||
| decode_(decode), | decode_(decode), | ||||
| @@ -22,12 +22,12 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class ManifestNode : public Dataset { | |||||
| class ManifestNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | ||||
| @@ -22,12 +22,12 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class MindDataNode : public Dataset { | |||||
| class MindDataNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | ||||
| @@ -30,7 +30,7 @@ namespace api { | |||||
| MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| : DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
| Status MnistNode::ValidateParams() { | Status MnistNode::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); | RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_)); | ||||
| @@ -21,13 +21,13 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class MnistNode : public Dataset { | |||||
| class MnistNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | ||||
| @@ -22,13 +22,13 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class RandomNode : public Dataset { | |||||
| class RandomNode : public DatasetNode { | |||||
| public: | public: | ||||
| // Some constants to provide limits to random generation. | // Some constants to provide limits to random generation. | ||||
| static constexpr int32_t kMaxNumColumns = 4; | static constexpr int32_t kMaxNumColumns = 4; | ||||
| @@ -38,7 +38,7 @@ class RandomNode : public Dataset { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, | RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, | ||||
| const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| total_rows_(total_rows), | total_rows_(total_rows), | ||||
| schema_path_(""), | schema_path_(""), | ||||
| schema_(std::move(schema)), | schema_(std::move(schema)), | ||||
| @@ -48,7 +48,7 @@ class RandomNode : public Dataset { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, | RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, | ||||
| const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| total_rows_(total_rows), | total_rows_(total_rows), | ||||
| schema_path_(schema_path), | schema_path_(schema_path), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| @@ -31,7 +31,7 @@ namespace api { | |||||
| // Constructor for TextFileNode | // Constructor for TextFileNode | ||||
| TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, | TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_files_(dataset_files), | dataset_files_(dataset_files), | ||||
| num_samples_(num_samples), | num_samples_(num_samples), | ||||
| shuffle_(shuffle), | shuffle_(shuffle), | ||||
| @@ -21,14 +21,14 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| /// \class TextFileNode | /// \class TextFileNode | ||||
| /// \brief A Dataset derived class to represent TextFile dataset | /// \brief A Dataset derived class to represent TextFile dataset | ||||
| class TextFileNode : public Dataset { | |||||
| class TextFileNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | ||||
| @@ -55,6 +55,53 @@ bool ValidateFirstRowCrc(const std::string &filename) { | |||||
| // Validator for TFRecordNode | // Validator for TFRecordNode | ||||
| Status TFRecordNode::ValidateParams() { | Status TFRecordNode::ValidateParams() { | ||||
| if (dataset_files_.empty()) { | |||||
| std::string err_msg = "TFRecordNode: dataset_files is not specified."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||||
| } | |||||
| for (const auto &f : dataset_files_) { | |||||
| Path dataset_file(f); | |||||
| if (!dataset_file.Exists()) { | |||||
| std::string err_msg = "TFRecordNode: dataset file: [" + f + "] is invalid or does not exist."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||||
| } | |||||
| } | |||||
| if (num_samples_ < 0) { | |||||
| std::string err_msg = "TFRecordNode: Invalid number of samples: " + std::to_string(num_samples_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||||
| } | |||||
| if (num_shards_ <= 0) { | |||||
| std::string err_msg = "TFRecordNode: Invalid num_shards: " + std::to_string(num_shards_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||||
| } | |||||
| if (shard_id_ < 0 || shard_id_ >= num_shards_) { | |||||
| std::string err_msg = "TFRecordNode: Invalid input, shard_id: " + std::to_string(shard_id_) + | |||||
| ", num_shards: " + std::to_string(num_shards_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||||
| } | |||||
| if (cache_ == nullptr && !shard_equal_rows_ && dataset_files_.size() < num_shards_) { | |||||
| // This check only makes sense in a non-cache path. We should make sure there is at least one file per | |||||
| // shard in file-based sharding. | |||||
| std::string err_msg = | |||||
| "TFRecordNode: Invalid number of dataset files, should at least be " + std::to_string(num_shards_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||||
| } | |||||
| std::vector<std::string> invalid_files(dataset_files_.size()); | std::vector<std::string> invalid_files(dataset_files_.size()); | ||||
| auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(), | auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(), | ||||
| [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); | [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); | ||||
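The new `ValidateParams()` body checks existence first, then counts, then ranges, and only applies the files-per-shard rule on the non-cache path. A standalone sketch of the same ordering; `ValidateShardParams` and its plain `bool`/`std::string` error reporting are stand-ins for the real `Status`-based validator:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Stand-in validator: existence first, then counts, then ranges.
bool ValidateShardParams(int64_t num_samples, int32_t num_shards, int32_t shard_id,
                         const std::vector<std::string> &files, std::string *err) {
  if (files.empty()) { *err = "dataset_files is not specified"; return false; }
  if (num_samples < 0) { *err = "invalid num_samples"; return false; }
  if (num_shards <= 0) { *err = "invalid num_shards"; return false; }
  if (shard_id < 0 || shard_id >= num_shards) { *err = "shard_id out of range"; return false; }
  // File-based sharding needs at least one file per shard (non-cache path only).
  if (files.size() < static_cast<size_t>(num_shards)) { *err = "too few files for num_shards"; return false; }
  return true;
}

int main() {
  std::string err;
  if (!ValidateShardParams(10, 3, 3, {"a.tfrecord"}, &err)) std::cout << err << '\n';  // shard_id out of range
}
```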
| @@ -22,21 +22,21 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| /// \class TFRecordNode | /// \class TFRecordNode | ||||
| /// \brief A Dataset derived class to represent TFRecord dataset | /// \brief A Dataset derived class to represent TFRecord dataset | ||||
| class TFRecordNode : public Dataset { | |||||
| class TFRecordNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \note Parameter 'schema' is the path to the schema file | /// \note Parameter 'schema' is the path to the schema file | ||||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, | TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, | ||||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_files_(dataset_files), | dataset_files_(dataset_files), | ||||
| schema_path_(schema), | schema_path_(schema), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| @@ -51,7 +51,7 @@ class TFRecordNode : public Dataset { | |||||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, | TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, | ||||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | ||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_files_(dataset_files), | dataset_files_(dataset_files), | ||||
| schema_obj_(schema), | schema_obj_(schema), | ||||
| columns_list_(columns_list), | columns_list_(columns_list), | ||||
| @@ -32,7 +32,7 @@ namespace api { | |||||
| VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | ||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, | const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, | ||||
| std::shared_ptr<DatasetCache> cache) | std::shared_ptr<DatasetCache> cache) | ||||
| : Dataset(std::move(cache)), | |||||
| : DatasetNode(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | dataset_dir_(dataset_dir), | ||||
| task_(task), | task_(task), | ||||
| usage_(usage), | usage_(usage), | ||||
| @@ -22,12 +22,12 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class VOCNode : public Dataset { | |||||
| class VOCNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | ||||
| @@ -27,7 +27,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| // Constructor for SyncWaitNode | // Constructor for SyncWaitNode | ||||
| SyncWaitNode::SyncWaitNode(std::shared_ptr<Dataset> child, const std::string &condition_name, int32_t num_batch, | |||||
| SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch, | |||||
| py::function callback) | py::function callback) | ||||
| : condition_name_(condition_name), num_batch_(num_batch), callback_(callback) { | : condition_name_(condition_name), num_batch_(num_batch), callback_(callback) { | ||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -30,10 +30,10 @@ namespace api { | |||||
| /// \class SyncWaitNode | /// \class SyncWaitNode | ||||
| /// \brief A Dataset derived class to represent SyncWaitNode dataset | /// \brief A Dataset derived class to represent SyncWaitNode dataset | ||||
| class SyncWaitNode : public Dataset { | |||||
| class SyncWaitNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit SyncWaitNode(std::shared_ptr<Dataset> child, const std::string &condition_name, int32_t num_batch, | |||||
| explicit SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch, | |||||
| py::function callback); | py::function callback); | ||||
| /// \brief Destructor | /// \brief Destructor | ||||
| @@ -27,7 +27,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| // Constructor for TakeNode | // Constructor for TakeNode | ||||
| TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) { | |||||
| TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -21,17 +21,17 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class TakeNode : public Dataset { | |||||
| class TakeNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count); | |||||
| explicit TakeNode(std::shared_ptr<DatasetNode> child, int32_t count); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TakeNode() = default; | ~TakeNode() = default; | ||||
| @@ -28,14 +28,8 @@ namespace dataset { | |||||
| namespace api { | namespace api { | ||||
| // Constructor for TransferNode | // Constructor for TransferNode | ||||
| TransferNode::TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id, | |||||
| const std::string &device_type, bool send_epoch_end) | |||||
| : queue_name_(queue_name), | |||||
| device_id_(device_id), | |||||
| device_type_(device_type), | |||||
| prefetch_size_(16), | |||||
| send_epoch_end_(send_epoch_end), | |||||
| total_batch_(0) { | |||||
| TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end) | |||||
| : prefetch_size_(16), send_epoch_end_(send_epoch_end), total_batch_(0) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -48,6 +42,15 @@ Status TransferNode::ValidateParams() { | |||||
| // Function to build TransferNode | // Function to build TransferNode | ||||
| std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | ||||
| // Get a uuid for queue name | |||||
| queue_name_ = Services::GetUniqueID(); | |||||
| // TODO(CRC): | |||||
| // Get device type from ms context | |||||
| device_type_ = "CPU"; | |||||
| // Get device ID from children | |||||
| device_id_ = 0; | |||||
| RETURN_EMPTY_IF_ERROR(TransferNode::get_distribution(shared_from_this(), &device_id_)); | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| @@ -67,13 +70,13 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | |||||
| } | } | ||||
| // Function to get the device_id | // Function to get the device_id | ||||
| Status TransferNode::get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id) { | |||||
| Status TransferNode::get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id) { | |||||
| // Get device id according to the type of dataset | // Get device id according to the type of dataset | ||||
| Status rc = ds->GetShardId(device_id); | Status rc = ds->GetShardId(device_id); | ||||
| if (rc != Status::OK()) { | if (rc != Status::OK()) { | ||||
| // Get device id from the child node | // Get device id from the child node | ||||
| if (ds->children.size()) { | |||||
| ds = ds->children[0]; | |||||
| if (ds->Children().size()) { | |||||
| ds = ds->Children()[0]; | |||||
| return TransferNode::get_distribution(ds, device_id); | return TransferNode::get_distribution(ds, device_id); | ||||
| } else { | } else { | ||||
| std::string err_msg = "Unknown dataset type."; | std::string err_msg = "Unknown dataset type."; | ||||
| @@ -21,18 +21,17 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class TransferNode : public Dataset { | |||||
| class TransferNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id, | |||||
| const std::string &device_type, bool send_epoch_end); | |||||
| TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TransferNode() = default; | ~TransferNode() = default; | ||||
| @@ -45,7 +44,7 @@ class TransferNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| static Status get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id); | |||||
| static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id); | |||||
| private: | private: | ||||
| std::string queue_name_; | std::string queue_name_; | ||||
| @@ -27,7 +27,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) { | |||||
| ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) : datasets_(datasets) { | |||||
| for (auto dataset : datasets_) { | for (auto dataset : datasets_) { | ||||
| this->children.push_back(dataset); | this->children.push_back(dataset); | ||||
| } | } | ||||
| @@ -21,16 +21,16 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class ZipNode : public Dataset { | |||||
| class ZipNode : public DatasetNode { | |||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets); | |||||
| explicit ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ZipNode() = default; | ~ZipNode() = default; | ||||
| @@ -44,7 +44,7 @@ class ZipNode : public Dataset { | |||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| private: | private: | ||||
| std::vector<std::shared_ptr<Dataset>> datasets_; | |||||
| std::vector<std::shared_ptr<DatasetNode>> datasets_; | |||||
| }; | }; | ||||
| } // namespace api | } // namespace api | ||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/python_runtime_context.h" | |||||
| #include "pybind11/pybind11.h" | |||||
| namespace mindspore::dataset { | |||||
| Status PythonRuntimeContext::Terminate() { | |||||
| // Release GIL before joining all threads | |||||
| py::gil_scoped_release gil_release; | |||||
| return tree_consumer_->Terminate(); | |||||
| } | |||||
| } // namespace mindspore::dataset | |||||
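Releasing the GIL before `Terminate()` joins the pipeline threads avoids a classic deadlock: a worker that must acquire the GIL before it can finish (to run a Python callback, or to drop `py::object`s) can never exit while the terminating thread still holds it. A hedged sketch of that hazard's shape; `Shutdown` and the worker are hypothetical:

```cpp
#include <pybind11/pybind11.h>
#include <thread>

namespace py = pybind11;

// Hypothetical shutdown path: the worker needs the GIL before it can exit.
void Shutdown(std::thread &worker) {
  // If we joined while still holding the GIL, a worker blocked in
  // py::gil_scoped_acquire could never run and join() would deadlock.
  py::gil_scoped_release release;  // give the GIL up for the duration of the join
  worker.join();                   // the worker can now take the GIL and finish
}
```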
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PYTHON_RUNTIME_CONTEXT_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PYTHON_RUNTIME_CONTEXT_H_ | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include "minddata/dataset/core/client.h" | |||||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||||
| #include "minddata/dataset/engine/consumers/python_tree_consumer.h" | |||||
| #include "minddata/dataset/engine/runtime_context.h" | |||||
| namespace mindspore::dataset { | |||||
| class RuntimeContext; | |||||
| /// Class that represents a single runtime instance which can consume data from a data pipeline | |||||
| class PythonRuntimeContext : public RuntimeContext { | |||||
| public: | |||||
| /// Method to terminate the runtime; this will not release the resources | |||||
| /// \return Status error code | |||||
| Status Terminate() override; | |||||
| ~PythonRuntimeContext() { | |||||
| Terminate(); | |||||
| { | |||||
| py::gil_scoped_acquire gil_acquire; | |||||
| tree_consumer_.reset(); | |||||
| } | |||||
| } | |||||
| PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); } | |||||
| }; | |||||
| } // namespace mindspore::dataset | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PYTHON_RUNTIME_CONTEXT_H_ | |||||
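The destructor ordering above is deliberate: threads are joined first (via `Terminate()`, which releases the GIL while waiting), and only afterwards is the consumer destroyed under a freshly acquired GIL, since tearing it down may drop Python state. A sketch under the assumption that the consumer holds a `py::object`; `PyConsumer` and `Teardown` are stand-in names:

```cpp
#include <pybind11/pybind11.h>
#include <memory>

namespace py = pybind11;

// Hypothetical consumer whose destruction drops Python state.
struct PyConsumer {
  py::object callback;
};

void Teardown(std::shared_ptr<PyConsumer> *consumer) {
  {
    py::gil_scoped_release release;  // join pipeline threads without holding the GIL
    /* Terminate(): worker threads exit here */
  }
  py::gil_scoped_acquire acquire;    // hold the GIL while the py::object dies
  consumer->reset();
}
```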
| @@ -19,7 +19,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| namespace mindspore::dataset { | namespace mindspore::dataset { | ||||
| void RuntimeContext::AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer) { | |||||
| void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) { | |||||
| tree_consumer_ = std::move(tree_consumer); | tree_consumer_ = std::move(tree_consumer); | ||||
| } | } | ||||
| } // namespace mindspore::dataset | } // namespace mindspore::dataset | ||||
| @@ -40,14 +40,16 @@ class RuntimeContext { | |||||
| /// Set the tree consumer | /// Set the tree consumer | ||||
| /// \param tree_consumer to be assigned | /// \param tree_consumer to be assigned | ||||
| void AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer); | |||||
| void AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer); | |||||
| /// Get the tree consumer | /// Get the tree consumer | ||||
| /// \return Raw pointer to the tree consumer. | /// \return Raw pointer to the tree consumer. | ||||
| TreeConsumer *GetConsumer() { return tree_consumer_.get(); } | TreeConsumer *GetConsumer() { return tree_consumer_.get(); } | ||||
| private: | |||||
| std::unique_ptr<TreeConsumer> tree_consumer_; | |||||
| ~RuntimeContext() { Terminate(); } | |||||
| protected: | |||||
| std::shared_ptr<TreeConsumer> tree_consumer_; | |||||
| }; | }; | ||||
| } // namespace mindspore::dataset | } // namespace mindspore::dataset | ||||
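Switching `tree_consumer_` from `std::unique_ptr` to `std::shared_ptr` (and from `private` to `protected`) lets subclasses such as `PythonRuntimeContext` manage the consumer directly and lets more than one owner keep it alive. A minimal sketch of the resulting ownership shape, with `Consumer`/`Context` as stand-in names:

```cpp
#include <memory>
#include <utility>

struct Consumer { /* stand-in for TreeConsumer */ };

class Context {
 public:
  // Shared ownership: the context and, e.g., an iterator can both keep it alive.
  void AssignConsumer(std::shared_ptr<Consumer> c) { consumer_ = std::move(c); }
  Consumer *GetConsumer() { return consumer_.get(); }

 protected:  // subclasses (cf. PythonRuntimeContext) need to reset it themselves
  std::shared_ptr<Consumer> consumer_;
};
```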
| @@ -22,7 +22,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::Dataset> root_ir, int32_t num_epoch) { | |||||
| Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::DatasetNode> root_ir, int32_t num_epoch) { | |||||
| // Check whether this function has been called before. If so, return failure | // Check whether this function has been called before. If so, return failure | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built."); | CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built."); | ||||
| RETURN_UNEXPECTED_IF_NULL(root_ir); | RETURN_UNEXPECTED_IF_NULL(root_ir); | ||||
| @@ -65,7 +65,7 @@ Status TreeAdapter::GetNext(TensorRow *row) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_ptr<DatasetOp> *op) { | |||||
| Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::DatasetNode> ir, std::shared_ptr<DatasetOp> *op) { | |||||
| // validate the op can be built first before building the DatasetOp | // validate the op can be built first before building the DatasetOp | ||||
| RETURN_IF_NOT_OK(ir->ValidateParams()); | RETURN_IF_NOT_OK(ir->ValidateParams()); | ||||
| std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build(); | std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build(); | ||||
| @@ -80,7 +80,7 @@ Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_p | |||||
| } | } | ||||
| // Build the children of ir, once they return, add the return value to *op | // Build the children of ir, once they return, add the return value to *op | ||||
| for (std::shared_ptr<api::Dataset> child_ir : ir->children) { | |||||
| for (const auto &child_ir : ir->Children()) { | |||||
| std::shared_ptr<DatasetOp> child_op; | std::shared_ptr<DatasetOp> child_op; | ||||
| RETURN_IF_NOT_OK(DFSBuildTree(child_ir, &child_op)); | RETURN_IF_NOT_OK(DFSBuildTree(child_ir, &child_op)); | ||||
| RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops | RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops | ||||
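`DFSBuildTree` lowers each IR node to one or more `DatasetOp`s, returns the first op to its caller, and hangs every child's subtree off the *last* op of the parent's expansion. A self-contained sketch of that recursion; the `Ir`/`Op` types are hypothetical, and the chaining of intermediate ops is an assumption about what `Build()` returning several ops implies:

```cpp
#include <memory>
#include <vector>

struct Op {
  std::vector<std::shared_ptr<Op>> children;
};

struct Ir {
  std::vector<std::shared_ptr<Ir>> children;
  // An IR node may lower to a short chain of ops (e.g. an injected shuffle).
  std::vector<std::shared_ptr<Op>> Build() const { return {std::make_shared<Op>()}; }
};

std::shared_ptr<Op> DfsBuild(const std::shared_ptr<Ir> &ir) {
  std::vector<std::shared_ptr<Op>> ops = ir->Build();
  for (size_t i = 0; i + 1 < ops.size(); ++i) {
    ops[i]->children.push_back(ops[i + 1]);  // assumed: chain the expansion top-down
  }
  for (const auto &child : ir->children) {
    ops.back()->children.push_back(DfsBuild(child));  // children hang off the last op
  }
  return ops.front();  // the first op represents this IR node to its parent
}
```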
| @@ -24,12 +24,12 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | namespace api { | ||||
| class Dataset; | |||||
| class DatasetNode; | |||||
| } | } | ||||
| class TreeAdapter { | class TreeAdapter { | ||||
| public: | public: | ||||
| @@ -40,7 +40,7 @@ class TreeAdapter { | |||||
| // This will construct an ExeTree from a Dataset root and Prepare() the ExeTree | // This will construct an ExeTree from a Dataset root and Prepare() the ExeTree | ||||
| // This function is only meant to be called once and needs to be called before GetNext | // This function is only meant to be called once and needs to be called before GetNext | ||||
| // ExeTree will be launched when the first GetNext is called | // ExeTree will be launched when the first GetNext is called | ||||
| Status BuildAndPrepare(std::shared_ptr<api::Dataset> root, int32_t num_epoch = -1); | |||||
| Status BuildAndPrepare(std::shared_ptr<api::DatasetNode> root, int32_t num_epoch = -1); | |||||
| // This is the main method TreeConsumer uses to interact with TreeAdapter | // This is the main method TreeConsumer uses to interact with TreeAdapter | ||||
| // 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared) | // 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared) | ||||
| @@ -62,7 +62,7 @@ class TreeAdapter { | |||||
| private: | private: | ||||
| // This RECURSIVE function converts IR nodes into DatasetOps in the ExecutionTree. An IR node could build a vector of ops. In | // This RECURSIVE function converts IR nodes into DatasetOps in the ExecutionTree. An IR node could build a vector of ops. In | ||||
| // such a case, the first node is returned. The op is added as a child when the current function returns. | // such a case, the first node is returned. The op is added as a child when the current function returns. | ||||
| Status DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_ptr<DatasetOp> *op); | |||||
| Status DFSBuildTree(std::shared_ptr<api::DatasetNode> ir, std::shared_ptr<DatasetOp> *op); | |||||
| std::unique_ptr<DataBuffer> cur_db_; | std::unique_ptr<DataBuffer> cur_db_; | ||||
| std::unordered_map<std::string, int32_t> column_name_map_; | std::unordered_map<std::string, int32_t> column_name_map_; | ||||
| @@ -49,7 +49,7 @@ class Iterator { | |||||
| Iterator() : consumer_(nullptr) {} | Iterator() : consumer_(nullptr) {} | ||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Iterator() = default; | |||||
| ~Iterator() { Stop(); } | |||||
| /// \brief Method for building and launching the pipeline. | /// \brief Method for building and launching the pipeline. | ||||
| /// \param[in] ops - a vector of DatasetOp in the data pipeline. | /// \param[in] ops - a vector of DatasetOp in the data pipeline. | ||||
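With `~Iterator()` now calling `Stop()`, an iterator tears its pipeline down when it goes out of scope instead of leaking launched threads. A minimal RAII sketch, with `Pipeline` as a hypothetical stand-in for the consumer the real iterator holds:

```cpp
struct Pipeline {
  void Stop() noexcept { /* join workers, drain queues */ }
};

class Iterator {
 public:
  explicit Iterator(Pipeline *p) : pipeline_(p) {}
  ~Iterator() { Stop(); }  // leaving scope now tears the pipeline down

  void Stop() {
    if (pipeline_ != nullptr) {
      pipeline_->Stop();
      pipeline_ = nullptr;  // idempotent: safe to call again, or from the destructor
    }
  }

 private:
  Pipeline *pipeline_;
};
```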
| @@ -82,30 +82,30 @@ AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/image/lite_cv MINDDATA_KERNELS_IMA | |||||
| if (BUILD_MINDDATA STREQUAL "full") | if (BUILD_MINDDATA STREQUAL "full") | ||||
| include_directories("${CMAKE_SOURCE_DIR}/../ccsrc/minddata/dataset/kernels/image") | include_directories("${CMAKE_SOURCE_DIR}/../ccsrc/minddata/dataset/kernels/image") | ||||
| list(REMOVE_ITEM MINDDATA_API_SRC_FILES | |||||
| "${MINDDATA_DIR}/api/text.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_API_SRC_FILES | |||||
| "${MINDDATA_DIR}/api/text.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_CALLBACK_SRC_FILES | |||||
| "${MINDDATA_DIR}/callback/py_ds_callback.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_CALLBACK_SRC_FILES | |||||
| "${MINDDATA_DIR}/callback/py_ds_callback.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_CORE_SRC_FILES | list(REMOVE_ITEM MINDDATA_CORE_SRC_FILES | ||||
| "${MINDDATA_DIR}/core/cv_tensor.cc" | |||||
| ) | |||||
| "${MINDDATA_DIR}/core/cv_tensor.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_KERNELS_SRC_FILES "${MINDDATA_DIR}/kernels/py_func_op.cc") | list(REMOVE_ITEM MINDDATA_KERNELS_SRC_FILES "${MINDDATA_DIR}/kernels/py_func_op.cc") | ||||
| list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SRC_FILES | list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SRC_FILES | ||||
| "${MINDDATA_DIR}/engine/datasetops/build_sentence_piece_vocab_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/filter_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/barrier_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/bucket_batch_by_length_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/build_vocab_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_merge_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_base_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_lookup_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_op.cc" | |||||
| ) | |||||
| "${MINDDATA_DIR}/engine/datasetops/build_sentence_piece_vocab_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/filter_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/barrier_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/bucket_batch_by_length_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/build_vocab_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_merge_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_base_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_lookup_op.cc" | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_op.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SOURCE_SRC_FILES | list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SOURCE_SRC_FILES | ||||
| "${MINDDATA_DIR}/engine/datasetops/source/generator_op.cc" | "${MINDDATA_DIR}/engine/datasetops/source/generator_op.cc" | ||||
| @@ -161,47 +161,55 @@ if (BUILD_MINDDATA STREQUAL "full") | |||||
| "${MINDDATA_DIR}/kernels/image/random_crop_and_resize_with_bbox_op.cc" | "${MINDDATA_DIR}/kernels/image/random_crop_and_resize_with_bbox_op.cc" | ||||
| "${MINDDATA_DIR}/kernels/image/random_crop_decode_resize_op.cc" | "${MINDDATA_DIR}/kernels/image/random_crop_decode_resize_op.cc" | ||||
| "${MINDDATA_DIR}/kernels/image/random_crop_and_resize_op.cc" | "${MINDDATA_DIR}/kernels/image/random_crop_and_resize_op.cc" | ||||
| "${MINDDATA_DIR}/kernels/image/random_crop_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_crop_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_horizontal_flip_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_horizontal_flip_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_posterize_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_resize_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_rotation_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_select_subpolicy_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_solarize_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_vertical_flip_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/rescale_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/sharpness_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/solarize_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/swap_red_blue_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/uniform_aug_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/resize_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_resize_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_color_op.cc" | |||||
| ) | |||||
| "${MINDDATA_DIR}/kernels/image/random_crop_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_crop_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_horizontal_flip_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_horizontal_flip_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_posterize_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_resize_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_rotation_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_select_subpolicy_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_solarize_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_vertical_flip_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/rescale_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/sharpness_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/solarize_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/swap_red_blue_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/uniform_aug_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/resize_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_resize_with_bbox_op.cc" | |||||
| "${MINDDATA_DIR}/kernels/image/random_color_op.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES | list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES | ||||
| "${MINDDATA_DIR}/engine/ir/datasetops/bucket_batch_by_length_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/build_sentence_piece_vocab_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/build_vocab_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/sync_wait_node.cc" | |||||
| ) | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/bucket_batch_by_length_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/build_sentence_piece_vocab_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/build_vocab_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/sync_wait_node.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_CONSUMERS_SRC_FILES | |||||
| "${MINDDATA_DIR}/engine/consumers/python_tree_consumer.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_SRC_FILES | |||||
| "${MINDDATA_DIR}/engine/python_runtime_context.cc" | |||||
| ) | |||||
| list(REMOVE_ITEM MINDDATA_KERNELS_DATA_SRC_FILES | list(REMOVE_ITEM MINDDATA_KERNELS_DATA_SRC_FILES | ||||
| "${MINDDATA_DIR}/kernels/data/unique_op.cc" | |||||
| ) | |||||
| "${MINDDATA_DIR}/kernels/data/unique_op.cc" | |||||
| ) | |||||
| include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") | include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") | ||||
| if (BUILD_MINDDATA_EXAMPLE AND (PLATFORM_ARM32 OR PLATFORM_ARM64)) | if (BUILD_MINDDATA_EXAMPLE AND (PLATFORM_ARM32 OR PLATFORM_ARM64)) | ||||
| set(MINDDATA_EXAMPLE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/example/jni-example.cc) | |||||
| endif() | |||||
| set(MINDDATA_EXAMPLE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/example/jni-example.cc) | |||||
| endif () | |||||
| add_library(minddata-lite SHARED | add_library(minddata-lite SHARED | ||||
| ${MINDDATA_API_SRC_FILES} | |||||
| ${MINDDATA_API_SRC_FILES} | |||||
| ${MINDDATA_CALLBACK_SRC_FILES} | ${MINDDATA_CALLBACK_SRC_FILES} | ||||
| ${MINDDATA_CORE_SRC_FILES} | ${MINDDATA_CORE_SRC_FILES} | ||||
| ${MINDDATA_ENGINE_SRC_FILES} | ${MINDDATA_ENGINE_SRC_FILES} | ||||
| @@ -1093,7 +1093,7 @@ TEST_F(MindDataTestPipeline, TestTakeDatasetDefault) { | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 7)); | std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 7)); | ||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Take operation on ds, dafault count = -1 | |||||
| // Create a Take operation on ds, default count = -1 | |||||
| ds = ds->Take(); | ds = ds->Take(); | ||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| @@ -429,7 +429,7 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetWithNullSampler) { | |||||
| schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1}); | schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1}); | ||||
| std::shared_ptr<Dataset> ds = RandomData(50, schema, {}, nullptr); | std::shared_ptr<Dataset> ds = RandomData(50, schema, {}, nullptr); | ||||
| // Expect failure: sampler can not be nullptr | // Expect failure: sampler can not be nullptr | ||||
| EXPECT_EQ(ds, nullptr); | |||||
| EXPECT_EQ(ds->CreateIterator(), nullptr); | |||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) { | TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) { | ||||
| @@ -441,5 +441,5 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) { | |||||
| schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1}); | schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1}); | ||||
| std::shared_ptr<Dataset> ds = RandomData(50, schema, {"image", "image"}); | std::shared_ptr<Dataset> ds = RandomData(50, schema, {"image", "image"}); | ||||
| // Expect failure: duplicate column names | // Expect failure: duplicate column names | ||||
| EXPECT_EQ(ds, nullptr); | |||||
| EXPECT_EQ(ds->CreateIterator(), nullptr); | |||||
| } | } | ||||
| @@ -443,34 +443,34 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception) { | |||||
| // This case is expected to fail because the list of dir_path cannot be empty. | // This case is expected to fail because the list of dir_path cannot be empty. | ||||
| std::shared_ptr<Dataset> ds1 = TFRecord({}); | std::shared_ptr<Dataset> ds1 = TFRecord({}); | ||||
| EXPECT_EQ(ds1, nullptr); | |||||
| EXPECT_EQ(ds1->CreateIterator(), nullptr); | |||||
| // This case is expected to fail because the file in dir_path does not exist. | // This case is expected to fail because the file in dir_path does not exist. | ||||
| std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; | std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; | ||||
| std::shared_ptr<Dataset> ds2 = TFRecord({file_path, "noexist.data"}); | std::shared_ptr<Dataset> ds2 = TFRecord({file_path, "noexist.data"}); | ||||
| EXPECT_EQ(ds2, nullptr); | |||||
| EXPECT_EQ(ds2->CreateIterator(), nullptr); | |||||
| // This case is expected to fail because the schema file does not exist. | // This case is expected to fail because the schema file does not exist. | ||||
| std::shared_ptr<Dataset> ds4 = TFRecord({file_path, "notexist.json"}); | std::shared_ptr<Dataset> ds4 = TFRecord({file_path, "notexist.json"}); | ||||
| EXPECT_EQ(ds4, nullptr); | |||||
| EXPECT_EQ(ds4->CreateIterator(), nullptr); | |||||
| // This case is expected to fail because num_samples is negative. | // This case is expected to fail because num_samples is negative. | ||||
| std::shared_ptr<Dataset> ds5 = TFRecord({file_path}, "", {}, -1); | std::shared_ptr<Dataset> ds5 = TFRecord({file_path}, "", {}, -1); | ||||
| EXPECT_EQ(ds5, nullptr); | |||||
| EXPECT_EQ(ds5->CreateIterator(), nullptr); | |||||
| // This case is expected to fail because num_shards is not positive. | // This case is expected to fail because num_shards is not positive. | ||||
| std::shared_ptr<Dataset> ds6 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 0); | std::shared_ptr<Dataset> ds6 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 0); | ||||
| EXPECT_EQ(ds6, nullptr); | |||||
| EXPECT_EQ(ds6->CreateIterator(), nullptr); | |||||
| // This case is expected to fail because shard_id is out of bounds. | // This case is expected to fail because shard_id is out of bounds. | ||||
| std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3); | std::shared_ptr<Dataset> ds7 = TFRecord({file_path}, "", {}, 10, ShuffleMode::kFalse, 3, 3); | ||||
| EXPECT_EQ(ds7, nullptr); | |||||
| EXPECT_EQ(ds7->CreateIterator(), nullptr); | |||||
| // This case is expected to fail because the provided number of files < num_shards in file-based sharding. | // This case is expected to fail because the provided number of files < num_shards in file-based sharding. | ||||
| std::string file_path1 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; | std::string file_path1 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data"; | ||||
| std::string file_path2 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data"; | std::string file_path2 = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0002.data"; | ||||
| std::shared_ptr<Dataset> ds8 = TFRecord({file_path1, file_path2}, "", {}, 0, ShuffleMode::kFalse, 3); | std::shared_ptr<Dataset> ds8 = TFRecord({file_path1, file_path2}, "", {}, 0, ShuffleMode::kFalse, 3); | ||||
| EXPECT_EQ(ds8, nullptr); | |||||
| EXPECT_EQ(ds8->CreateIterator(), nullptr); | |||||
| } | } | ||||
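The pattern behind all of these test edits: constructing the `Dataset` handle no longer runs parameter validation, so the handle itself is non-null and the failure only surfaces when the pipeline is built. Illustratively, in the style of this test file (assuming its fixture helpers):

```cpp
// Construction of the handle succeeds eagerly...
std::shared_ptr<Dataset> ds = TFRecord({});
EXPECT_NE(ds, nullptr);
// ...and invalid parameters are only caught when the pipeline is built.
EXPECT_EQ(ds->CreateIterator(), nullptr);
```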
| TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) { | TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) { | ||||
| @@ -56,7 +56,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) { | |||||
| mindspore::dataset::TreeAdapter tree_adapter; | mindspore::dataset::TreeAdapter tree_adapter; | ||||
| Status rc = tree_adapter.BuildAndPrepare(ds, 1); | |||||
| Status rc = tree_adapter.BuildAndPrepare(ds->IRNode(), 1); | |||||
| EXPECT_TRUE(rc.IsOk()); | EXPECT_TRUE(rc.IsOk()); | ||||
| @@ -91,7 +91,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { | |||||
| mindspore::dataset::TreeAdapter tree_adapter; | mindspore::dataset::TreeAdapter tree_adapter; | ||||
| Status rc = tree_adapter.BuildAndPrepare(ds, 2); | |||||
| Status rc = tree_adapter.BuildAndPrepare(ds->IRNode(), 2); | |||||
| EXPECT_TRUE(rc.IsOk()); | EXPECT_TRUE(rc.IsOk()); | ||||
| const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap(); | const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap(); | ||||
| @@ -128,7 +128,7 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { | |||||
| mindspore::dataset::TreeAdapter tree_adapter; | mindspore::dataset::TreeAdapter tree_adapter; | ||||
| Status rc = tree_adapter.BuildAndPrepare(ds, 2); | |||||
| Status rc = tree_adapter.BuildAndPrepare(ds->IRNode(), 2); | |||||
| EXPECT_TRUE(rc.IsOk()); | EXPECT_TRUE(rc.IsOk()); | ||||