Browse Source

Update RuntimeContext

tags/v1.1.0
hesham 5 years ago
parent
commit
982efd0e25
7 changed files with 75 additions and 34 deletions
  1. +10
    -10
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +1
    -1
      mindspore/ccsrc/minddata/dataset/api/iterator.cc
  3. +16
    -1
      mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc
  4. +10
    -11
      mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h
  5. +13
    -0
      mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc
  6. +23
    -9
      mindspore/ccsrc/minddata/dataset/engine/runtime_context.h
  7. +2
    -2
      mindspore/ccsrc/minddata/dataset/include/iterator.h

+ 10
- 10
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -111,7 +111,7 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
Status rc;

// Build and launch tree
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
@@ -147,7 +147,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
Status rc;
// Build and launch tree
auto ds = shared_from_this();
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "CreateSaver failed." << rc;
@@ -193,7 +193,7 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize() {
int64_t dataset_size;
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
@@ -213,7 +213,7 @@ int64_t Dataset::GetDatasetSize() {
std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<DataType> types;
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
@@ -240,7 +240,7 @@ std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<TensorShape> Dataset::GetOutputShapes() {
std::vector<TensorShape> shapes;
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
@@ -268,7 +268,7 @@ int64_t Dataset::GetNumClasses() {
int64_t num_classes;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
@@ -562,7 +562,7 @@ int64_t Dataset::GetBatchSize() {
int64_t batch_size;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
@@ -583,7 +583,7 @@ int64_t Dataset::GetRepeatCount() {
int64_t repeat_count;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
@@ -613,7 +613,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params);

std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
@@ -645,7 +645,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);

std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/api/iterator.cc View File

@@ -48,7 +48,7 @@ void Iterator::Stop() { runtime_context_->Terminate(); }

// Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
runtime_context_ = std::make_unique<RuntimeContext>();
runtime_context_ = std::make_unique<NativeRuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get();


+ 16
- 1
mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc View File

@@ -19,9 +19,24 @@

namespace mindspore::dataset {

Status PythonRuntimeContext::Terminate() {
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }

Status PythonRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
// Release GIL before joining all threads
py::gil_scoped_release gil_release;
return tree_consumer_->Terminate();
}

PythonRuntimeContext::~PythonRuntimeContext() {
TerminateImpl();
{
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}

PythonIteratorConsumer *PythonRuntimeContext::GetPythonConsumer() {
return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get());
}
} // namespace mindspore::dataset

+ 10
- 11
mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h View File

@@ -24,25 +24,24 @@
#include "minddata/dataset/engine/runtime_context.h"

namespace mindspore::dataset {
class RuntimeContext;
class NativeRuntimeContext;

/// Class that represents single runtime instance which can consume data from a data pipeline
/// Class that represents Python single runtime instance which can consume data from a data pipeline
class PythonRuntimeContext : public RuntimeContext {
public:
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
Status Terminate() override;

// Safe destructing the tree that includes python objects
~PythonRuntimeContext() {
Terminate();
{
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}
/// Safe destructing the tree that includes python objects
~PythonRuntimeContext() override;

PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); }
PythonIteratorConsumer *GetPythonConsumer();

private:
/// Internal function to perform the termination
/// \return Status error code
Status TerminateImpl();
};

} // namespace mindspore::dataset


+ 13
- 0
mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc View File

@@ -22,4 +22,17 @@ namespace mindspore::dataset {
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
tree_consumer_ = std::move(tree_consumer);
}
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }

Status NativeRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
return tree_consumer_->Terminate();
}

NativeRuntimeContext::~NativeRuntimeContext() { TerminateImpl(); }

TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); }

Status RuntimeContext::Init() { return GlobalInit(); }

} // namespace mindspore::dataset

+ 23
- 9
mindspore/ccsrc/minddata/dataset/engine/runtime_context.h View File

@@ -23,8 +23,7 @@

namespace mindspore::dataset {
class TreeConsumer;

/// Class the represents single runtime instance which can consume data from a data pipeline
/// Class that represents single runtime instance which can consume data from a data pipeline
class RuntimeContext {
public:
/// Default constructor
@@ -32,11 +31,7 @@ class RuntimeContext {

/// Initialize the runtime, for now we just call the global init
/// \return Status error code
Status Init() { return GlobalInit(); }

/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
virtual Status Terminate() { return Status::OK(); }
Status Init();

/// Set the tree consumer
/// \param tree_consumer to be assigned
@@ -44,13 +39,32 @@ class RuntimeContext {

/// Get the tree consumer
/// \return Raw pointer to the tree consumer.
TreeConsumer *GetConsumer() { return tree_consumer_.get(); }
TreeConsumer *GetConsumer();

~RuntimeContext() { Terminate(); }
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
virtual Status Terminate() = 0;

virtual ~RuntimeContext() = default;

protected:
std::shared_ptr<TreeConsumer> tree_consumer_;
};

/// Class that represents C++ single runtime instance which can consume data from a data pipeline
class NativeRuntimeContext : public RuntimeContext {
public:
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
Status Terminate() override;

~NativeRuntimeContext() override;

private:
/// Internal function to perform the termination
/// \return Status error code
Status TerminateImpl();
};

} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_

+ 2
- 2
mindspore/ccsrc/minddata/dataset/include/iterator.h View File

@@ -33,7 +33,7 @@ class DatasetIterator;
class DatasetOp;
class Tensor;

class RuntimeContext;
class NativeRuntimeContext;
class IteratorConsumer;

class Dataset;
@@ -113,7 +113,7 @@ class Iterator {
_Iterator end() { return _Iterator(nullptr); }

private:
std::unique_ptr<RuntimeContext> runtime_context_;
std::unique_ptr<NativeRuntimeContext> runtime_context_;
IteratorConsumer *consumer_;
};
} // namespace dataset


Loading…
Cancel
Save