Merge pull request !7593 from xiaotianci/device_optags/v1.1.0
| @@ -62,6 +62,7 @@ | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | #include "minddata/dataset/engine/ir/datasetops/take_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -72,6 +73,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/util/path.h" | #include "minddata/dataset/util/path.h" | ||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| #include "minddata/dataset/util/services.h" | |||||
| // IR leaf nodes | // IR leaf nodes | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | ||||
| @@ -125,6 +127,56 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum | |||||
| return iter; | return iter; | ||||
| } | } | ||||
| // Function to return a transferred Node that transfers data through a device. | |||||
| bool Dataset::DeviceQueue(bool send_epoch_end) { | |||||
| Status rc; | |||||
| // Build and launch tree | |||||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||||
| rc = runtime_context->Init(); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; | |||||
| return false; | |||||
| } | |||||
| // Get a uuid for queue name | |||||
| std::string queue_name = Services::GetUniqueID(); | |||||
| // TODO(CRC): | |||||
| // Get device type from ms context | |||||
| std::string device_type = "CPU"; | |||||
| // Get device ID from children | |||||
| int32_t device_id = 0; | |||||
| rc = TransferNode::get_distribution(shared_from_this(), &device_id); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Failed to get shard id. Error status: " << rc; | |||||
| return false; | |||||
| } | |||||
| // Add TransferNode IR on top of dataset d | |||||
| auto ds = std::make_shared<TransferNode>(shared_from_this(), queue_name, device_id, device_type, send_epoch_end); | |||||
| // Get ToDevice consumer | |||||
| auto consumer = std::make_unique<ToDevice>(device_type, send_epoch_end, -1); | |||||
| ToDevice *consumer_ = consumer.get(); | |||||
| rc = consumer->Init(ds); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc; | |||||
| return false; | |||||
| } | |||||
| runtime_context->AssignConsumer(std::move(consumer)); | |||||
| // Send data to device | |||||
| rc = consumer_->Send(); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // Function to create the saver, which will build and launch the execution tree and save data | // Function to create the saver, which will build and launch the execution tree and save data | ||||
| bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) { | bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) { | ||||
| @@ -931,6 +983,7 @@ std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t me | |||||
| auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz); | auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz); | ||||
| return cache->ValidateParams() ? cache : nullptr; | return cache->ValidateParams() ? cache : nullptr; | ||||
| } | } | ||||
| #endif | #endif | ||||
| } // namespace api | } // namespace api | ||||
| @@ -74,13 +74,31 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> | |||||
| // ToDevice | // ToDevice | ||||
| Status ToDevice::Init(std::shared_ptr<api::Dataset> d) { | Status ToDevice::Init(std::shared_ptr<api::Dataset> d) { | ||||
| // TODO(CRC): | |||||
| // Get device ID from children look at get_distribution in python | |||||
| // Add DeviceQue IR on top of dataset d | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | ||||
| } | } | ||||
| Status ToDevice::Send() { | |||||
| std::unique_ptr<DataBuffer> db; | |||||
| RETURN_IF_NOT_OK(tree_adapter_->Launch()); | |||||
| RETURN_IF_NOT_OK(tree_adapter_->root()->GetNextBuffer(&db)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ToDevice::Continue() { | |||||
| // tree_.root() must be DeviceQueueOp | |||||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DeviceQueueOp"); | |||||
| op->ContinueSend(); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ToDevice::Stop() { | |||||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp"); | |||||
| op->StopSend(); | |||||
| return Status::OK(); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // SaveToDisk | // SaveToDisk | ||||
| Status SaveToDisk::ValidateParams() { | Status SaveToDisk::ValidateParams() { | ||||
| @@ -126,23 +126,27 @@ class SaveToDisk : public TreeConsumer { | |||||
| /// Consumer that iterates over the dataset and send it to a device | /// Consumer that iterates over the dataset and send it to a device | ||||
| class ToDevice : public TreeConsumer { | class ToDevice : public TreeConsumer { | ||||
| public: | public: | ||||
| ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs) | |||||
| ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1) | |||||
| : TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | : TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | ||||
| Status Init(std::shared_ptr<api::Dataset> d) override; | Status Init(std::shared_ptr<api::Dataset> d) override; | ||||
| Status Send() { | |||||
| // TODO(CRC): launch the tree | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status Stop() { | |||||
| // TODO(CRC): Get root + call StopSend | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status Continue() { | |||||
| // TODO(CRC): Get root + call StopSend | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| /// Send the data to device | |||||
| /// \return Status error code | |||||
| Status Send(); | |||||
| /// Stop to send data to device | |||||
| /// \return Status error code | |||||
| Status Stop(); | |||||
| /// Continue to send data to device | |||||
| /// \return Status error code | |||||
| Status Continue(); | |||||
| protected: | |||||
| /// Method to return the name of the consumer | |||||
| /// \return string | |||||
| std::string Name() override { return "ToDevice"; } | |||||
| private: | private: | ||||
| std::string device_type_; | std::string device_type_; | ||||
| @@ -15,6 +15,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | |||||
| skip_node.cc | skip_node.cc | ||||
| sync_wait_node.cc | sync_wait_node.cc | ||||
| take_node.cc | take_node.cc | ||||
| transfer_node.cc | |||||
| zip_node.cc | zip_node.cc | ||||
| ) | ) | ||||
| @@ -68,6 +68,13 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status AlbumNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,10 @@ class AlbumNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string schema_path_; | std::string schema_path_; | ||||
| @@ -67,6 +67,13 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status CelebANode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -46,6 +46,10 @@ class CelebANode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -66,6 +66,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status Cifar100Node::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,10 @@ class Cifar100Node : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -64,6 +64,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status Cifar10Node::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,10 @@ class Cifar10Node : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -213,6 +213,13 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status CLUENode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = shard_id_; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -45,6 +45,10 @@ class CLUENode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| /// \brief Split string based on a character delimiter | /// \brief Split string based on a character delimiter | ||||
| /// \return A string vector | /// \return A string vector | ||||
| @@ -117,6 +117,14 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status CocoNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -43,6 +43,10 @@ class CocoNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string annotation_file_; | std::string annotation_file_; | ||||
| @@ -122,6 +122,14 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status CSVNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = shard_id_; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -66,6 +66,10 @@ class CSVNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::vector<std::string> dataset_files_; | std::vector<std::string> dataset_files_; | ||||
| char field_delim_; | char field_delim_; | ||||
| @@ -70,6 +70,14 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() { | |||||
| std::move(sampler_->Build()))); | std::move(sampler_->Build()))); | ||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status ImageFolderNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -51,6 +51,10 @@ class ImageFolderNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| bool decode_; | bool decode_; | ||||
| @@ -85,6 +85,14 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status ManifestNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,10 @@ class ManifestNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::string dataset_file_; | std::string dataset_file_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -160,6 +160,13 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status MindDataNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -48,6 +48,10 @@ class MindDataNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| /// \brief Build sampler chain for minddata dataset | /// \brief Build sampler chain for minddata dataset | ||||
| /// \return Status Status::OK() if input sampler is valid | /// \return Status Status::OK() if input sampler is valid | ||||
| Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler, | Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler, | ||||
| @@ -60,6 +60,13 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status MnistNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,10 @@ class MnistNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | std::string usage_; | ||||
| @@ -99,6 +99,13 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status RandomNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -65,6 +65,10 @@ class RandomNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| /// \brief A quick inline for producing a random number between (and including) min/max | /// \brief A quick inline for producing a random number between (and including) min/max | ||||
| /// \param[in] min minimum number that can be generated. | /// \param[in] min minimum number that can be generated. | ||||
| @@ -95,6 +95,13 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status TextFileNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = shard_id_; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -45,6 +45,10 @@ class TextFileNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::vector<std::string> dataset_files_; | std::vector<std::string> dataset_files_; | ||||
| int32_t num_samples_; | int32_t num_samples_; | ||||
| @@ -80,6 +80,13 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status TFRecordNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = shard_id_; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -72,6 +72,10 @@ class TFRecordNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| std::vector<std::string> dataset_files_; | std::vector<std::string> dataset_files_; | ||||
| std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string | std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string | ||||
| @@ -112,6 +112,13 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() { | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Get the shard id of node | |||||
| Status VOCNode::GetShardId(int32_t *shard_id) { | |||||
| *shard_id = sampler_->ShardId(); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -45,6 +45,10 @@ class VOCNode : public Dataset { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| const std::string kColumnImage = "image"; | const std::string kColumnImage = "image"; | ||||
| const std::string kColumnTarget = "target"; | const std::string kColumnTarget = "target"; | ||||
| @@ -0,0 +1,90 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for TransferNode | |||||
| TransferNode::TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id, | |||||
| const std::string &device_type, bool send_epoch_end) | |||||
| : queue_name_(queue_name), | |||||
| device_id_(device_id), | |||||
| device_type_(device_type), | |||||
| prefetch_size_(16), | |||||
| send_epoch_end_(send_epoch_end), | |||||
| total_batch_(0) { | |||||
| this->children.push_back(child); | |||||
| } | |||||
| // Validator for TransferNode | |||||
| Status TransferNode::ValidateParams() { | |||||
| // Check if device_type_ is in {"CPU", "GPU", "Ascend"} | |||||
| RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"})); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Function to build TransferNode | |||||
| std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| // Convert device_type_ from string to DeviceType | |||||
| DeviceQueueOp::DeviceType type; | |||||
| if (device_type_ == "CPU") { | |||||
| type = DeviceQueueOp::DeviceType::CPU; | |||||
| } else if (device_type_ == "GPU") { | |||||
| type = DeviceQueueOp::DeviceType::GPU; | |||||
| } else if (device_type_ == "Ascend") { | |||||
| type = DeviceQueueOp::DeviceType::Ascend; | |||||
| } | |||||
| node_ops.push_back( | |||||
| std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, total_batch_)); | |||||
| return node_ops; | |||||
| } | |||||
| // Function to get the device_id | |||||
| Status TransferNode::get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id) { | |||||
| // Get device id according to the type of dataset | |||||
| Status rc = ds->GetShardId(device_id); | |||||
| if (rc != Status::OK()) { | |||||
| // Get device id from the child node | |||||
| if (ds->children.size()) { | |||||
| ds = ds->children[0]; | |||||
| return TransferNode::get_distribution(ds, device_id); | |||||
| } else { | |||||
| std::string err_msg = "Unknown dataset type."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| class TransferNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id, | |||||
| const std::string &device_type, bool send_epoch_end); | |||||
| /// \brief Destructor | |||||
| ~TransferNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| static Status get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id); | |||||
| private: | |||||
| std::string queue_name_; | |||||
| int32_t device_id_; | |||||
| std::string device_type_; | |||||
| int32_t prefetch_size_; | |||||
| bool send_epoch_end_; | |||||
| int32_t total_batch_; | |||||
| }; | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_ | |||||
| @@ -57,6 +57,10 @@ class TreeAdapter { | |||||
| // to be able to launch a thread. BuildAndPrepare needs to be called before this function | // to be able to launch a thread. BuildAndPrepare needs to be called before this function | ||||
| TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; } | TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; } | ||||
| std::shared_ptr<DatasetOp> root() { return tree_->root(); } | |||||
| Status Launch() const { return tree_->Launch(); } | |||||
| private: | private: | ||||
| // This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In | // This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In | ||||
| // such case, the first node is returned. Op is added as child when the current function returns. | // such case, the first node is returned. Op is added as child when the current function returns. | ||||
| @@ -96,6 +96,7 @@ class RepeatNode; | |||||
| class ShuffleNode; | class ShuffleNode; | ||||
| class SkipNode; | class SkipNode; | ||||
| class TakeNode; | class TakeNode; | ||||
| class TransferNode; | |||||
| class ZipNode; | class ZipNode; | ||||
| #define RETURN_EMPTY_IF_ERROR(_s) \ | #define RETURN_EMPTY_IF_ERROR(_s) \ | ||||
| @@ -559,6 +560,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| public: | public: | ||||
| // need friend class so they can access the children_ field | // need friend class so they can access the children_ field | ||||
| friend class Iterator; | friend class Iterator; | ||||
| friend class TransferNode; | |||||
| friend class mindspore::dataset::TreeAdapter; | friend class mindspore::dataset::TreeAdapter; | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -579,6 +581,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| virtual Status ValidateParams() = 0; | virtual Status ValidateParams() = 0; | ||||
| /// \brief Pure virtual function for derived class to get the shard id of specific node | |||||
| /// \return Status Status::OK() if get shard id successfully | |||||
| virtual Status GetShardId(int32_t *shard_id) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| /// \brief Gets the dataset size | /// \brief Gets the dataset size | ||||
| /// \return status code | /// \return status code | ||||
| int64_t GetDatasetSize(); | int64_t GetDatasetSize(); | ||||
| @@ -617,6 +625,13 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| /// \return Shared pointer to the Iterator | /// \return Shared pointer to the Iterator | ||||
| std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}); | std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {}); | ||||
| /// \brief Function to transfer data through a device. | |||||
| /// \notes If device is Ascend, features of data will be transferred one by one. The limitation | |||||
| /// of data transmission per time is 256M. | |||||
| /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=True). | |||||
| /// \return Returns true if no error encountered else false. | |||||
| bool DeviceQueue(bool send_epoch_end = true); | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline | /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline | ||||
| /// \note Usage restrictions: | /// \note Usage restrictions: | ||||
| @@ -17,8 +17,9 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_ | ||||
| #include <vector> | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include <vector> | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | ||||
| @@ -48,6 +49,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||||
| /// \return Shared pointers to the newly created Sampler | /// \return Shared pointers to the newly created Sampler | ||||
| virtual std::shared_ptr<Sampler> Build() = 0; | virtual std::shared_ptr<Sampler> Build() = 0; | ||||
| /// \brief Function for derived class to get the shard id of sampler | |||||
| /// \return The shard id of the derived sampler | |||||
| virtual int64_t ShardId() { return 0; } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, | /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, | ||||
| /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler | /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler | ||||
| @@ -134,6 +139,10 @@ class DistributedSamplerObj : public SamplerObj { | |||||
| bool ValidateParams() override; | bool ValidateParams() override; | ||||
| /// \brief Function to get the shard id of sampler | |||||
| /// \return The shard id of sampler | |||||
| int64_t ShardId() override { return shard_id_; } | |||||
| private: | private: | ||||
| int64_t num_shards_; | int64_t num_shards_; | ||||
| int64_t shard_id_; | int64_t shard_id_; | ||||