| @@ -111,7 +111,7 @@ bool Dataset::DeviceQueue(bool send_epoch_end) { | |||
| Status rc; | |||
| // Build and launch tree | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; | |||
| @@ -147,7 +147,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data | |||
| Status rc; | |||
| // Build and launch tree | |||
| auto ds = shared_from_this(); | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "CreateSaver failed." << rc; | |||
| @@ -193,7 +193,7 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); } | |||
| int64_t Dataset::GetDatasetSize() { | |||
| int64_t dataset_size; | |||
| Status rc; | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | |||
| @@ -213,7 +213,7 @@ int64_t Dataset::GetDatasetSize() { | |||
| std::vector<DataType> Dataset::GetOutputTypes() { | |||
| std::vector<DataType> types; | |||
| Status rc; | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed."; | |||
| @@ -240,7 +240,7 @@ std::vector<DataType> Dataset::GetOutputTypes() { | |||
| std::vector<TensorShape> Dataset::GetOutputShapes() { | |||
| std::vector<TensorShape> shapes; | |||
| Status rc; | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed."; | |||
| @@ -268,7 +268,7 @@ int64_t Dataset::GetNumClasses() { | |||
| int64_t num_classes; | |||
| auto ds = shared_from_this(); | |||
| Status rc; | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed."; | |||
| @@ -562,7 +562,7 @@ int64_t Dataset::GetBatchSize() { | |||
| int64_t batch_size; | |||
| auto ds = shared_from_this(); | |||
| Status rc; | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed."; | |||
| @@ -583,7 +583,7 @@ int64_t Dataset::GetRepeatCount() { | |||
| int64_t repeat_count; | |||
| auto ds = shared_from_this(); | |||
| Status rc; | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed."; | |||
| @@ -613,7 +613,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab( | |||
| auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage, | |||
| model_type, params); | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| Status rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc; | |||
| @@ -645,7 +645,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum | |||
| auto ds = | |||
| std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first); | |||
| std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| Status rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc; | |||
| @@ -48,7 +48,7 @@ void Iterator::Stop() { runtime_context_->Terminate(); } | |||
| // Function to build and launch the execution tree. | |||
| Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | |||
| runtime_context_ = std::make_unique<RuntimeContext>(); | |||
| runtime_context_ = std::make_unique<NativeRuntimeContext>(); | |||
| RETURN_IF_NOT_OK(runtime_context_->Init()); | |||
| auto consumer = std::make_unique<IteratorConsumer>(); | |||
| consumer_ = consumer.get(); | |||
| @@ -19,9 +19,24 @@ | |||
| namespace mindspore::dataset { | |||
| Status PythonRuntimeContext::Terminate() { | |||
| Status PythonRuntimeContext::Terminate() { return TerminateImpl(); } | |||
| Status PythonRuntimeContext::TerminateImpl() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); | |||
| // Release GIL before joining all threads | |||
| py::gil_scoped_release gil_release; | |||
| return tree_consumer_->Terminate(); | |||
| } | |||
| PythonRuntimeContext::~PythonRuntimeContext() { | |||
| TerminateImpl(); | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| tree_consumer_.reset(); | |||
| } | |||
| } | |||
| PythonIteratorConsumer *PythonRuntimeContext::GetPythonConsumer() { | |||
| return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); | |||
| } | |||
| } // namespace mindspore::dataset | |||
| @@ -24,25 +24,24 @@ | |||
| #include "minddata/dataset/engine/runtime_context.h" | |||
| namespace mindspore::dataset { | |||
| class RuntimeContext; | |||
| class NativeRuntimeContext; | |||
| /// Class that represents single runtime instance which can consume data from a data pipeline | |||
| /// Class that represents Python single runtime instance which can consume data from a data pipeline | |||
| class PythonRuntimeContext : public RuntimeContext { | |||
| public: | |||
| /// Method to terminate the runtime, this will not release the resources | |||
| /// \return Status error code | |||
| Status Terminate() override; | |||
| // Safe destructing the tree that includes python objects | |||
| ~PythonRuntimeContext() { | |||
| Terminate(); | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| tree_consumer_.reset(); | |||
| } | |||
| } | |||
| /// Safe destructing the tree that includes python objects | |||
| ~PythonRuntimeContext() override; | |||
| PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); } | |||
| PythonIteratorConsumer *GetPythonConsumer(); | |||
| private: | |||
| /// Internal function to perform the termination | |||
| /// \return Status error code | |||
| Status TerminateImpl(); | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| @@ -22,4 +22,17 @@ namespace mindspore::dataset { | |||
| void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) { | |||
| tree_consumer_ = std::move(tree_consumer); | |||
| } | |||
| Status NativeRuntimeContext::Terminate() { return TerminateImpl(); } | |||
| Status NativeRuntimeContext::TerminateImpl() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); | |||
| return tree_consumer_->Terminate(); | |||
| } | |||
| NativeRuntimeContext::~NativeRuntimeContext() { TerminateImpl(); } | |||
| TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); } | |||
| Status RuntimeContext::Init() { return GlobalInit(); } | |||
| } // namespace mindspore::dataset | |||
| @@ -23,8 +23,7 @@ | |||
| namespace mindspore::dataset { | |||
| class TreeConsumer; | |||
| /// Class the represents single runtime instance which can consume data from a data pipeline | |||
| /// Class that represents single runtime instance which can consume data from a data pipeline | |||
| class RuntimeContext { | |||
| public: | |||
| /// Default constructor | |||
| @@ -32,11 +31,7 @@ class RuntimeContext { | |||
| /// Initialize the runtime, for now we just call the global init | |||
| /// \return Status error code | |||
| Status Init() { return GlobalInit(); } | |||
| /// Method to terminate the runtime, this will not release the resources | |||
| /// \return Status error code | |||
| virtual Status Terminate() { return Status::OK(); } | |||
| Status Init(); | |||
| /// Set the tree consumer | |||
| /// \param tree_consumer to be assigned | |||
| @@ -44,13 +39,32 @@ class RuntimeContext { | |||
| /// Get the tree consumer | |||
| /// \return Raw pointer to the tree consumer. | |||
| TreeConsumer *GetConsumer() { return tree_consumer_.get(); } | |||
| TreeConsumer *GetConsumer(); | |||
| ~RuntimeContext() { Terminate(); } | |||
| /// Method to terminate the runtime, this will not release the resources | |||
| /// \return Status error code | |||
| virtual Status Terminate() = 0; | |||
| virtual ~RuntimeContext() = default; | |||
| protected: | |||
| std::shared_ptr<TreeConsumer> tree_consumer_; | |||
| }; | |||
| /// Class that represents C++ single runtime instance which can consume data from a data pipeline | |||
| class NativeRuntimeContext : public RuntimeContext { | |||
| public: | |||
| /// Method to terminate the runtime, this will not release the resources | |||
| /// \return Status error code | |||
| Status Terminate() override; | |||
| ~NativeRuntimeContext() override; | |||
| private: | |||
| /// Internal function to perform the termination | |||
| /// \return Status error code | |||
| Status TerminateImpl(); | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ | |||
| @@ -33,7 +33,7 @@ class DatasetIterator; | |||
| class DatasetOp; | |||
| class Tensor; | |||
| class RuntimeContext; | |||
| class NativeRuntimeContext; | |||
| class IteratorConsumer; | |||
| class Dataset; | |||
| @@ -113,7 +113,7 @@ class Iterator { | |||
| _Iterator end() { return _Iterator(nullptr); } | |||
| private: | |||
| std::unique_ptr<RuntimeContext> runtime_context_; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context_; | |||
| IteratorConsumer *consumer_; | |||
| }; | |||
| } // namespace dataset | |||