Browse Source

Update RuntimeContext

tags/v1.1.0
hesham 5 years ago
parent
commit
982efd0e25
7 changed files with 75 additions and 34 deletions
  1. +10
    -10
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +1
    -1
      mindspore/ccsrc/minddata/dataset/api/iterator.cc
  3. +16
    -1
      mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc
  4. +10
    -11
      mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h
  5. +13
    -0
      mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc
  6. +23
    -9
      mindspore/ccsrc/minddata/dataset/engine/runtime_context.h
  7. +2
    -2
      mindspore/ccsrc/minddata/dataset/include/iterator.h

+ 10
- 10
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -111,7 +111,7 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
Status rc; Status rc;


// Build and launch tree // Build and launch tree
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
@@ -147,7 +147,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
Status rc; Status rc;
// Build and launch tree // Build and launch tree
auto ds = shared_from_this(); auto ds = shared_from_this();
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "CreateSaver failed." << rc; MS_LOG(ERROR) << "CreateSaver failed." << rc;
@@ -193,7 +193,7 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize() { int64_t Dataset::GetDatasetSize() {
int64_t dataset_size; int64_t dataset_size;
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
@@ -213,7 +213,7 @@ int64_t Dataset::GetDatasetSize() {
std::vector<DataType> Dataset::GetOutputTypes() { std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<DataType> types; std::vector<DataType> types;
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
@@ -240,7 +240,7 @@ std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<TensorShape> Dataset::GetOutputShapes() { std::vector<TensorShape> Dataset::GetOutputShapes() {
std::vector<TensorShape> shapes; std::vector<TensorShape> shapes;
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
@@ -268,7 +268,7 @@ int64_t Dataset::GetNumClasses() {
int64_t num_classes; int64_t num_classes;
auto ds = shared_from_this(); auto ds = shared_from_this();
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
@@ -562,7 +562,7 @@ int64_t Dataset::GetBatchSize() {
int64_t batch_size; int64_t batch_size;
auto ds = shared_from_this(); auto ds = shared_from_this();
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
@@ -583,7 +583,7 @@ int64_t Dataset::GetRepeatCount() {
int64_t repeat_count; int64_t repeat_count;
auto ds = shared_from_this(); auto ds = shared_from_this();
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
@@ -613,7 +613,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage, auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params); model_type, params);


std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init(); Status rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc; MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
@@ -645,7 +645,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
auto ds = auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first); std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);


std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init(); Status rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc; MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/api/iterator.cc View File

@@ -48,7 +48,7 @@ void Iterator::Stop() { runtime_context_->Terminate(); }


// Function to build and launch the execution tree. // Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
runtime_context_ = std::make_unique<RuntimeContext>();
runtime_context_ = std::make_unique<NativeRuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init()); RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>(); auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get(); consumer_ = consumer.get();


+ 16
- 1
mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.cc View File

@@ -19,9 +19,24 @@


namespace mindspore::dataset { namespace mindspore::dataset {


Status PythonRuntimeContext::Terminate() {
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }

Status PythonRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
// Release GIL before joining all threads // Release GIL before joining all threads
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
return tree_consumer_->Terminate(); return tree_consumer_->Terminate();
} }

PythonRuntimeContext::~PythonRuntimeContext() {
TerminateImpl();
{
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}

PythonIteratorConsumer *PythonRuntimeContext::GetPythonConsumer() {
return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get());
}
} // namespace mindspore::dataset } // namespace mindspore::dataset

+ 10
- 11
mindspore/ccsrc/minddata/dataset/engine/python_runtime_context.h View File

@@ -24,25 +24,24 @@
#include "minddata/dataset/engine/runtime_context.h" #include "minddata/dataset/engine/runtime_context.h"


namespace mindspore::dataset { namespace mindspore::dataset {
class RuntimeContext;
class NativeRuntimeContext;


/// Class that represents single runtime instance which can consume data from a data pipeline
/// Class that represents Python single runtime instance which can consume data from a data pipeline
class PythonRuntimeContext : public RuntimeContext { class PythonRuntimeContext : public RuntimeContext {
public: public:
/// Method to terminate the runtime, this will not release the resources /// Method to terminate the runtime, this will not release the resources
/// \return Status error code /// \return Status error code
Status Terminate() override; Status Terminate() override;


// Safe destructing the tree that includes python objects
~PythonRuntimeContext() {
Terminate();
{
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}
/// Safe destructing the tree that includes python objects
~PythonRuntimeContext() override;


PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); }
PythonIteratorConsumer *GetPythonConsumer();

private:
/// Internal function to perform the termination
/// \return Status error code
Status TerminateImpl();
}; };


} // namespace mindspore::dataset } // namespace mindspore::dataset


+ 13
- 0
mindspore/ccsrc/minddata/dataset/engine/runtime_context.cc View File

@@ -22,4 +22,17 @@ namespace mindspore::dataset {
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) { void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
tree_consumer_ = std::move(tree_consumer); tree_consumer_ = std::move(tree_consumer);
} }
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }

Status NativeRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
return tree_consumer_->Terminate();
}

NativeRuntimeContext::~NativeRuntimeContext() { TerminateImpl(); }

TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); }

Status RuntimeContext::Init() { return GlobalInit(); }

} // namespace mindspore::dataset } // namespace mindspore::dataset

+ 23
- 9
mindspore/ccsrc/minddata/dataset/engine/runtime_context.h View File

@@ -23,8 +23,7 @@


namespace mindspore::dataset { namespace mindspore::dataset {
class TreeConsumer; class TreeConsumer;

/// Class the represents single runtime instance which can consume data from a data pipeline
/// Class that represents single runtime instance which can consume data from a data pipeline
class RuntimeContext { class RuntimeContext {
public: public:
/// Default constructor /// Default constructor
@@ -32,11 +31,7 @@ class RuntimeContext {


/// Initialize the runtime, for now we just call the global init /// Initialize the runtime, for now we just call the global init
/// \return Status error code /// \return Status error code
Status Init() { return GlobalInit(); }

/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
virtual Status Terminate() { return Status::OK(); }
Status Init();


/// Set the tree consumer /// Set the tree consumer
/// \param tree_consumer to be assigned /// \param tree_consumer to be assigned
@@ -44,13 +39,32 @@ class RuntimeContext {


/// Get the tree consumer /// Get the tree consumer
/// \return Raw pointer to the tree consumer. /// \return Raw pointer to the tree consumer.
TreeConsumer *GetConsumer() { return tree_consumer_.get(); }
TreeConsumer *GetConsumer();


~RuntimeContext() { Terminate(); }
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
virtual Status Terminate() = 0;

virtual ~RuntimeContext() = default;


protected: protected:
std::shared_ptr<TreeConsumer> tree_consumer_; std::shared_ptr<TreeConsumer> tree_consumer_;
}; };


/// Class that represents C++ single runtime instance which can consume data from a data pipeline
class NativeRuntimeContext : public RuntimeContext {
public:
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
Status Terminate() override;

~NativeRuntimeContext() override;

private:
/// Internal function to perform the termination
/// \return Status error code
Status TerminateImpl();
};

} // namespace mindspore::dataset } // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_

+ 2
- 2
mindspore/ccsrc/minddata/dataset/include/iterator.h View File

@@ -33,7 +33,7 @@ class DatasetIterator;
class DatasetOp; class DatasetOp;
class Tensor; class Tensor;


class RuntimeContext;
class NativeRuntimeContext;
class IteratorConsumer; class IteratorConsumer;


class Dataset; class Dataset;
@@ -113,7 +113,7 @@ class Iterator {
_Iterator end() { return _Iterator(nullptr); } _Iterator end() { return _Iterator(nullptr); }


private: private:
std::unique_ptr<RuntimeContext> runtime_context_;
std::unique_ptr<NativeRuntimeContext> runtime_context_;
IteratorConsumer *consumer_; IteratorConsumer *consumer_;
}; };
} // namespace dataset } // namespace dataset


Loading…
Cancel
Save