Browse Source

!7593 C++ api add DeviceQueueOp

Merge pull request !7593 from xiaotianci/device_op
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
95fe324798
39 changed files with 443 additions and 18 deletions
  1. +53
    -0
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +22
    -4
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  3. +17
    -13
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h
  4. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt
  5. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc
  6. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h
  7. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc
  8. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h
  9. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc
  10. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h
  11. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc
  12. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h
  13. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc
  14. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h
  15. +8
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc
  16. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h
  17. +8
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc
  18. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h
  19. +8
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  20. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h
  21. +8
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc
  22. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h
  23. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc
  24. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h
  25. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc
  26. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h
  27. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc
  28. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h
  29. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc
  30. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h
  31. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
  32. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h
  33. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc
  34. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h
  35. +90
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc
  36. +62
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h
  37. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
  38. +15
    -0
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  39. +10
    -1
      mindspore/ccsrc/minddata/dataset/include/samplers.h

+ 53
- 0
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -62,6 +62,7 @@
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"

#ifndef ENABLE_ANDROID
@@ -72,6 +73,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/services.h"

// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
@@ -125,6 +127,56 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
return iter;
}

// Function to return a transferred Node that transfers data through a device.
bool Dataset::DeviceQueue(bool send_epoch_end) {
Status rc;

// Build and launch tree
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
return false;
}

// Get a uuid for queue name
std::string queue_name = Services::GetUniqueID();

// TODO(CRC):
// Get device type from ms context
std::string device_type = "CPU";

// Get device ID from children
int32_t device_id = 0;
rc = TransferNode::get_distribution(shared_from_this(), &device_id);
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to get shard id. Error status: " << rc;
return false;
}

// Add TransferNode IR on top of dataset d
auto ds = std::make_shared<TransferNode>(shared_from_this(), queue_name, device_id, device_type, send_epoch_end);

// Get ToDevice consumer
auto consumer = std::make_unique<ToDevice>(device_type, send_epoch_end, -1);
ToDevice *consumer_ = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
return false;
}
runtime_context->AssignConsumer(std::move(consumer));

// Send data to device
rc = consumer_->Send();
if (rc.IsError()) {
MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc;
return false;
}

return true;
}

#ifndef ENABLE_ANDROID
// Function to create the saver, which will build and launch the execution tree and save data
bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) {
@@ -931,6 +983,7 @@ std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t me
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
return cache->ValidateParams() ? cache : nullptr;
}

#endif

} // namespace api


+ 22
- 4
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc View File

@@ -74,13 +74,31 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>

// ToDevice
// Build and prepare the execution tree for dataset d, using the number of epochs stored in num_epochs_.
// \param[in] d Dataset IR to consume; it is moved into the tree adapter
// \return Status returned by tree_adapter_->BuildAndPrepare
Status ToDevice::Init(std::shared_ptr<api::Dataset> d) {
// TODO(CRC):
// Get device ID from children look at get_distribution in python
// Add DeviceQue IR on top of dataset d
// NOTE(review): Dataset::DeviceQueue appears to resolve the device id and add the TransferNode before
// calling Init, so this TODO may be obsolete - confirm before removing.

return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}

// Launch the execution tree and pull one buffer from its root operator.
// \return Status::OK() on success; the error from Launch()/GetNextBuffer(), or an unexpected-error status
//     when the tree has no root operator
Status ToDevice::Send() {
  RETURN_IF_NOT_OK(tree_adapter_->Launch());

  // Guard the root before dereferencing it, in line with the null checks done by Continue() and Stop()
  std::shared_ptr<DatasetOp> root = tree_adapter_->root();
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Send failed: root of the execution tree is a nullptr");

  std::unique_ptr<DataBuffer> db;
  RETURN_IF_NOT_OK(root->GetNextBuffer(&db));
  return Status::OK();
}

// Ask the root operator to continue sending; fails unless the root is a DeviceQueueOp.
Status ToDevice::Continue() {
  auto tree_root = tree_adapter_->root();
  auto *device_queue_op = dynamic_cast<DeviceQueueOp *>(tree_root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(device_queue_op != nullptr, "ContinueSend only supported by DeviceQueueOp");
  device_queue_op->ContinueSend();
  return Status::OK();
}

// Ask the root operator to stop sending; fails unless the root is a DeviceQueueOp.
Status ToDevice::Stop() {
  auto tree_root = tree_adapter_->root();
  auto *device_queue_op = dynamic_cast<DeviceQueueOp *>(tree_root.get());
  CHECK_FAIL_RETURN_UNEXPECTED(device_queue_op != nullptr, "StopSend only supported by DeviceQueueOp");
  device_queue_op->StopSend();
  return Status::OK();
}

#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {


+ 17
- 13
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h View File

@@ -126,23 +126,27 @@ class SaveToDisk : public TreeConsumer {
/// Consumer that iterates over the dataset and send it to a device
class ToDevice : public TreeConsumer {
public:
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs)
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}

Status Init(std::shared_ptr<api::Dataset> d) override;

Status Send() {
// TODO(CRC): launch the tree
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status Stop() {
// TODO(CRC): Get root + call StopSend
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status Continue() {
// TODO(CRC): Get root + call StopSend
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
/// Send the data to device
/// \return Status error code
Status Send();

/// Stop to send data to device
/// \return Status error code
Status Stop();

/// Continue to send data to device
/// \return Status error code
Status Continue();

protected:
/// Method to return the name of the consumer
/// \return string
std::string Name() override { return "ToDevice"; }

private:
std::string device_type_;


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt View File

@@ -15,6 +15,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
skip_node.cc
sync_wait_node.cc
take_node.cc
transfer_node.cc
zip_node.cc
)



+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc View File

@@ -68,6 +68,13 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status AlbumNode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h View File

@@ -44,6 +44,10 @@ class AlbumNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::string dataset_dir_;
std::string schema_path_;


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc View File

@@ -67,6 +67,13 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status CelebANode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h View File

@@ -46,6 +46,10 @@ class CelebANode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::string dataset_dir_;
std::string usage_;


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc View File

@@ -66,6 +66,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status Cifar100Node::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h View File

@@ -44,6 +44,10 @@ class Cifar100Node : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::string dataset_dir_;
std::string usage_;


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc View File

@@ -64,6 +64,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status Cifar10Node::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h View File

@@ -44,6 +44,10 @@ class Cifar10Node : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::string dataset_dir_;
std::string usage_;


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc View File

@@ -213,6 +213,13 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
return node_ops;
}

// Get the shard id of node
// Copies this node's shard_id_ member into *shard_id and always returns Status::OK().
// shard_id is dereferenced without a null check, so callers must pass valid storage.
Status CLUENode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;

return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h View File

@@ -45,6 +45,10 @@ class CLUENode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
/// \brief Split string based on a character delimiter
/// \return A string vector


+ 8
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc View File

@@ -117,6 +117,14 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {

return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status CocoNode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h View File

@@ -43,6 +43,10 @@ class CocoNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::string dataset_dir_;
std::string annotation_file_;


+ 8
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc View File

@@ -122,6 +122,14 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {

return node_ops;
}

// Get the shard id of node
// Copies this node's shard_id_ member into *shard_id and always returns Status::OK().
// shard_id is dereferenced without a null check, so callers must pass valid storage.
Status CSVNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;

return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h View File

@@ -66,6 +66,10 @@ class CSVNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::vector<std::string> dataset_files_;
char field_delim_;


+ 8
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc View File

@@ -70,6 +70,14 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
std::move(sampler_->Build())));
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status ImageFolderNode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h View File

@@ -51,6 +51,10 @@ class ImageFolderNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::string dataset_dir_;
bool decode_;


+ 8
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc View File

@@ -85,6 +85,14 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {

return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status ManifestNode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h View File

@@ -44,6 +44,10 @@ class ManifestNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::string dataset_file_;
std::string usage_;


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc View File

@@ -160,6 +160,13 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status MindDataNode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h View File

@@ -48,6 +48,10 @@ class MindDataNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

/// \brief Build sampler chain for minddata dataset
/// \return Status Status::OK() if input sampler is valid
Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc View File

@@ -60,6 +60,13 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status MnistNode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h View File

@@ -44,6 +44,10 @@ class MnistNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::string dataset_dir_;
std::string usage_;


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc View File

@@ -99,6 +99,13 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status RandomNode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h View File

@@ -65,6 +65,10 @@ class RandomNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
/// \brief A quick inline for producing a random number between (and including) min/max
/// \param[in] min minimum number that can be generated.


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc View File

@@ -95,6 +95,13 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
return node_ops;
}

// Get the shard id of node
// Copies this node's shard_id_ member into *shard_id and always returns Status::OK().
// shard_id is dereferenced without a null check, so callers must pass valid storage.
Status TextFileNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;

return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h View File

@@ -45,6 +45,10 @@ class TextFileNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::vector<std::string> dataset_files_;
int32_t num_samples_;


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc View File

@@ -80,6 +80,13 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
return node_ops;
}

// Get the shard id of node
// Copies this node's shard_id_ member into *shard_id and always returns Status::OK().
// shard_id is dereferenced without a null check, so callers must pass valid storage.
Status TFRecordNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;

return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h View File

@@ -72,6 +72,10 @@ class TFRecordNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
std::vector<std::string> dataset_files_;
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc View File

@@ -112,6 +112,13 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
return node_ops;
}

// Report the shard id held by this node's sampler through the output pointer.
Status VOCNode::GetShardId(int32_t *shard_id) {
  // Narrow the sampler's shard id explicitly to the int32_t output type
  const int64_t sampler_shard_id = sampler_->ShardId();
  *shard_id = static_cast<int32_t>(sampler_shard_id);
  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h View File

@@ -45,6 +45,10 @@ class VOCNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the shard id of node
/// \param[out] shard_id Location that receives the shard id of this node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;

private:
const std::string kColumnImage = "image";
const std::string kColumnTarget = "target";


+ 90
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc View File

@@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

// Constructor for TransferNode
TransferNode::TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
const std::string &device_type, bool send_epoch_end)
: queue_name_(queue_name),
device_id_(device_id),
device_type_(device_type),
prefetch_size_(16),
send_epoch_end_(send_epoch_end),
total_batch_(0) {
this->children.push_back(child);
}

// Validator for TransferNode
// The only parameter checked is device_type_, which has to be one of "CPU", "GPU" or "Ascend".
Status TransferNode::ValidateParams() {
  return ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"});
}

// Function to build TransferNode
std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

// Convert device_type_ from string to DeviceType
DeviceQueueOp::DeviceType type;
if (device_type_ == "CPU") {
type = DeviceQueueOp::DeviceType::CPU;
} else if (device_type_ == "GPU") {
type = DeviceQueueOp::DeviceType::GPU;
} else if (device_type_ == "Ascend") {
type = DeviceQueueOp::DeviceType::Ascend;
}

node_ops.push_back(
std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, total_batch_));
return node_ops;
}

// Function to get the device_id
// Starting at ds, follow the first child of each node until a node's GetShardId succeeds; that shard id
// is written to *device_id. If the chain ends without any node reporting one, a syntax error is returned.
Status TransferNode::get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id) {
  while (ds->GetShardId(device_id).IsError()) {
    if (ds->children.empty()) {
      // Reached a node without children and still no shard id
      std::string err_msg = "Unknown dataset type.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    // Only the first child is inspected
    ds = ds->children[0];
  }

  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+ 62
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h View File

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

namespace api {

/// \brief IR node placed on top of a dataset; its Build() creates the DeviceQueueOp for the configured device
class TransferNode : public Dataset {
public:
/// \brief Constructor
/// \param[in] child Dataset node that becomes the child of this node
/// \param[in] queue_name Name passed to the DeviceQueueOp
/// \param[in] device_id Device id passed to the DeviceQueueOp
/// \param[in] device_type Device type string; ValidateParams accepts "CPU", "GPU" and "Ascend"
/// \param[in] send_epoch_end Whether to send end of sequence to device or not
TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
const std::string &device_type, bool send_epoch_end);

/// \brief Destructor
~TransferNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

/// \brief Get the device id from the shard id of ds, or of the first node that reports one when descending
///     through the first child of each node
/// \param[in] ds Dataset node to start the search from
/// \param[out] device_id Location that receives the shard id that was found
/// \return Status Status::OK() if a shard id was found, otherwise a syntax error status
static Status get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id);

private:
std::string queue_name_;
int32_t device_id_;
std::string device_type_;
int32_t prefetch_size_;  // set to 16 by the constructor
bool send_epoch_end_;
int32_t total_batch_;  // set to 0 by the constructor
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_

+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h View File

@@ -57,6 +57,10 @@ class TreeAdapter {
// to be able to launch a thread. BuildAndPrepare needs to be called before this function
TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; }

// Return the root operator of the execution tree, or nullptr when no tree exists (tree_ is null),
// mirroring the null guard used by AllTasks().
std::shared_ptr<DatasetOp> root() { return tree_ != nullptr ? tree_->root() : nullptr; }

// Launch the execution tree.
// \return Status of ExecutionTree::Launch, or an unexpected-error status when no tree exists (tree_ is null)
Status Launch() const {
  CHECK_FAIL_RETURN_UNEXPECTED(tree_ != nullptr, "Launch failed: execution tree is a nullptr");
  return tree_->Launch();
}

private:
// This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In
// such case, the first node is returned. Op is added as child when the current function returns.


+ 15
- 0
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -96,6 +96,7 @@ class RepeatNode;
class ShuffleNode;
class SkipNode;
class TakeNode;
class TransferNode;
class ZipNode;

#define RETURN_EMPTY_IF_ERROR(_s) \
@@ -559,6 +560,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
public:
// need friend class so they can access the children_ field
friend class Iterator;
friend class TransferNode;
friend class mindspore::dataset::TreeAdapter;

/// \brief Constructor
@@ -579,6 +581,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Status Status::OK() if all the parameters are valid
virtual Status ValidateParams() = 0;

/// \brief Virtual function for derived class to get the shard id of specific node
/// \note This is not pure virtual: the default implementation below returns kNotImplementedYet and
///     does not write to shard_id, so only nodes that override it report a shard id
/// \param[out] shard_id Location meant to receive the shard id; untouched by this default implementation
/// \return Status Status::OK() if get shard id successfully
virtual Status GetShardId(int32_t *shard_id) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}

/// \brief Gets the dataset size
/// \return status code
int64_t GetDatasetSize();
@@ -617,6 +625,13 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the Iterator
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});

/// \brief Function to transfer data through a device.
/// \note If device is Ascend, features of data will be transferred one by one. The limitation
///     of data transmission per time is 256M.
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
/// \return Returns true if no error encountered else false.
bool DeviceQueue(bool send_epoch_end = true);

#ifndef ENABLE_ANDROID
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
/// \note Usage restrictions:


+ 10
- 1
mindspore/ccsrc/minddata/dataset/include/samplers.h View File

@@ -17,8 +17,9 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_

#include <vector>
#include <memory>
#include <string>
#include <vector>

#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
@@ -48,6 +49,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<Sampler> Build() = 0;

/// \brief Function for derived class to get the shard id of sampler
/// \note This base implementation always returns 0, so samplers that do not override it report shard 0
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }

#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
@@ -134,6 +139,10 @@ class DistributedSamplerObj : public SamplerObj {

bool ValidateParams() override;

/// \brief Function to get the shard id of sampler
/// \return The shard id of sampler, i.e. the value of the shard_id_ member
int64_t ShardId() override { return shard_id_; }

private:
int64_t num_shards_;
int64_t shard_id_;


Loading…
Cancel
Save