Merge pull request !7555 from h.farahat/consumerstags/v1.1.0
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "minddata/dataset/include/iterator.h" | #include "minddata/dataset/include/iterator.h" | ||||
| #include "minddata/dataset/core/client.h" | #include "minddata/dataset/core/client.h" | ||||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -23,7 +24,7 @@ namespace api { | |||||
| // Get the next row from the data pipeline. | // Get the next row from the data pipeline. | ||||
| bool Iterator::GetNextRow(TensorMap *row) { | bool Iterator::GetNextRow(TensorMap *row) { | ||||
| Status rc = iterator_->GetNextAsMap(row); | |||||
| Status rc = consumer_->GetNextAsMap(row); | |||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; | MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; | ||||
| row->clear(); | row->clear(); | ||||
| @@ -34,100 +35,27 @@ bool Iterator::GetNextRow(TensorMap *row) { | |||||
| // Get the next row from the data pipeline. | // Get the next row from the data pipeline. | ||||
| bool Iterator::GetNextRow(TensorVec *row) { | bool Iterator::GetNextRow(TensorVec *row) { | ||||
| TensorRow tensor_row; | |||||
| Status rc = iterator_->FetchNextTensorRow(&tensor_row); | |||||
| Status rc = consumer_->GetNextAsVector(row); | |||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; | MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; | ||||
| row->clear(); | row->clear(); | ||||
| return false; | return false; | ||||
| } | } | ||||
| // Generate a vector as return | |||||
| row->clear(); | |||||
| std::copy(tensor_row.begin(), tensor_row.end(), std::back_inserter(*row)); | |||||
| return true; | return true; | ||||
| } | } | ||||
| // Shut down the data pipeline. | // Shut down the data pipeline. | ||||
| void Iterator::Stop() { | |||||
| // Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_. | |||||
| iterator_.reset(); | |||||
| // Release ownership of tree_ shared pointer. This will decrement the ref count. | |||||
| tree_.reset(); | |||||
| } | |||||
| // Function to build and launch the execution tree. | |||||
| void Iterator::Stop() { runtime_context->Terminate(); } | |||||
| // | |||||
| //// Function to build and launch the execution tree. | |||||
| Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | ||||
| // One time init | |||||
| Status rc; | |||||
| rc = GlobalInit(); | |||||
| RETURN_IF_NOT_OK(rc); | |||||
| // Instantiate the execution tree | |||||
| tree_ = std::make_shared<ExecutionTree>(); | |||||
| // Iterative BFS converting Dataset tree into runtime Execution tree. | |||||
| std::queue<std::pair<std::shared_ptr<Dataset>, std::shared_ptr<DatasetOp>>> q; | |||||
| if (ds == nullptr) { | |||||
| RETURN_STATUS_UNEXPECTED("Input is null pointer"); | |||||
| } else { | |||||
| // Convert the current root node. | |||||
| auto root_ops = ds->Build(); | |||||
| if (root_ops.empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("Node operation returned nothing"); | |||||
| } | |||||
| // Iterate through all the DatasetOps returned by Dataset's Build(), associate them | |||||
| // with the execution tree and add the child and parent relationship between the nodes | |||||
| // Note that some Dataset objects might return more than one DatasetOps | |||||
| // e.g. MapDataset will return [ProjectOp, MapOp] if project_columns is set for MapDataset | |||||
| std::shared_ptr<DatasetOp> prev_op = nullptr; | |||||
| for (auto op : root_ops) { | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(op)); | |||||
| if (prev_op != nullptr) { | |||||
| RETURN_IF_NOT_OK(prev_op->AddChild(op)); | |||||
| } | |||||
| prev_op = op; | |||||
| } | |||||
| // Add the last DatasetOp to the queue to be BFS. | |||||
| q.push(std::make_pair(ds, root_ops.back())); | |||||
| // Traverse down to the children and convert them to the corresponding DatasetOps (i.e. execution tree nodes) | |||||
| while (!q.empty()) { | |||||
| auto node_pair = q.front(); | |||||
| q.pop(); | |||||
| // Iterate through all the direct children of the first element in our BFS queue | |||||
| for (auto child : node_pair.first->children) { | |||||
| auto child_ops = child->Build(); | |||||
| if (child_ops.empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("Node operation returned nothing"); | |||||
| } | |||||
| auto node_op = node_pair.second; | |||||
| // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them | |||||
| // with the execution tree and add the child and parent relationship between the nodes | |||||
| // Note that some Dataset objects might return more than one DatasetOps | |||||
| // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset | |||||
| for (auto child_op : child_ops) { | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(child_op)); | |||||
| RETURN_IF_NOT_OK(node_op->AddChild(child_op)); | |||||
| node_op = child_op; | |||||
| } | |||||
| // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current | |||||
| // execution tree) to the BFS queue | |||||
| q.push(std::make_pair(child, child_ops.back())); | |||||
| } | |||||
| } | |||||
| RETURN_IF_NOT_OK(tree_->AssignRoot(root_ops.front())); | |||||
| } | |||||
| // Launch the execution tree. | |||||
| RETURN_IF_NOT_OK(tree_->Prepare()); | |||||
| tree_->Launch(); | |||||
| iterator_ = std::make_unique<DatasetIterator>(tree_); | |||||
| RETURN_UNEXPECTED_IF_NULL(iterator_); | |||||
| return rc; | |||||
| runtime_context = std::make_unique<RuntimeContext>(); | |||||
| RETURN_IF_NOT_OK(runtime_context->Init()); | |||||
| auto consumer = std::make_unique<IteratorConsumer>(); | |||||
| consumer_ = consumer.get(); | |||||
| RETURN_IF_NOT_OK(consumer->Init(ds)); | |||||
| runtime_context->AssignConsumer(std::move(consumer)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace api | } // namespace api | ||||
| @@ -12,12 +12,14 @@ endif () | |||||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | ||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | ||||
| add_library(engine OBJECT | add_library(engine OBJECT | ||||
| execution_tree.cc | |||||
| data_buffer.cc | |||||
| data_schema.cc | |||||
| dataset_iterator.cc | |||||
| tree_adapter.cc | |||||
| ) | |||||
| execution_tree.cc | |||||
| data_buffer.cc | |||||
| data_schema.cc | |||||
| dataset_iterator.cc | |||||
| tree_adapter.cc | |||||
| runtime_context.cc | |||||
| consumers/tree_consumer.cc | |||||
| ) | |||||
| if (ENABLE_PYTHON) | if (ENABLE_PYTHON) | ||||
| target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) | target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) | ||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||||
| namespace mindspore::dataset { | |||||
| /// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict | |||||
| class PythonIterator : public IteratorConsumer { | |||||
| /// Constructor | |||||
| /// \param num_epochs number of epochs. Default to -1 (infinite epochs). | |||||
| explicit PythonIterator(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {} | |||||
| /// Get the next row as a python dict | |||||
| /// \param[out] output python dict | |||||
| /// \return Status error code | |||||
| Status GetNextAsMap(py::dict *output) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| /// Get the next row as a python dict | |||||
| /// \param[out] output python dict | |||||
| /// \return Status error code | |||||
| Status GetNextAsList(py::list *output) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| }; | |||||
| } // namespace mindspore::dataset | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ | |||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||||
| #include "minddata/dataset/engine/tree_adapter.h" | |||||
| namespace mindspore::dataset { | |||||
| Status IteratorConsumer::GetNextAsVector(std::vector<TensorPtr> *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| out->clear(); | |||||
| TensorRow res; | |||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res)); | |||||
| // Return empty vector if there's no data | |||||
| RETURN_OK_IF_TRUE(res.empty()); | |||||
| std::copy(res.begin(), res.end(), std::back_inserter(*out)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out_map) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out_map); | |||||
| out_map->clear(); | |||||
| TensorRow res; | |||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res)); | |||||
| // Return empty map if there's no data | |||||
| RETURN_OK_IF_TRUE(res.empty()); | |||||
| // Populate the out map from the row and return it | |||||
| for (const auto &colMap : tree_adapter_->GetColumnNameMap()) { | |||||
| (*out_map)[colMap.first] = std::move(res[colMap.second]); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | |||||
| Status IteratorConsumer::Init(std::shared_ptr<api::Dataset> d) { | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | |||||
| } | |||||
| Status TreeConsumer::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } | |||||
| Status ToDevice::Init(std::shared_ptr<api::Dataset> d) { | |||||
| // TODO(CRC): | |||||
| // Get device ID from children look at get_distribution in python | |||||
| // Add DeviceQue IR on top of dataset d | |||||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | |||||
| } | |||||
| } // namespace mindspore::dataset | |||||
| @@ -0,0 +1,154 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/tree_adapter.h" | |||||
| namespace mindspore::dataset { | |||||
| // Forward declare | |||||
| class TreeAdapter; | |||||
| namespace api { | |||||
| class Dataset; | |||||
| } | |||||
| /// A base class for tree consumers which would fetch rows from the tree pipeline | |||||
| class TreeConsumer { | |||||
| public: | |||||
| /// Constructor that prepares an empty tree_adapter | |||||
| TreeConsumer(); | |||||
| /// Initializes the consumer, this involves constructing and preparing the tree. | |||||
| /// \param d The dataset node that represent the root of the IR tree. | |||||
| /// \return Status error code. | |||||
| virtual Status Init(std::shared_ptr<api::Dataset> d); | |||||
| protected: | |||||
| /// The class owns the tree_adapter that handles execution tree operations. | |||||
| std::unique_ptr<TreeAdapter> tree_adapter_; | |||||
| /// Method to return the name of the consumer | |||||
| /// \return string | |||||
| virtual std::string Name() = 0; | |||||
| }; | |||||
| /// Consumer that iterates over the dataset and returns the rows one by one as a vector or a map | |||||
| class IteratorConsumer : public TreeConsumer { | |||||
| public: | |||||
| /// Constructor which will call the base class default constructor. | |||||
| /// \param num_epochs number of epochs. Default to -1 (infinite epochs). | |||||
| explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} | |||||
| Status Init(std::shared_ptr<api::Dataset> d) override; | |||||
| /// Returns the next row in a vector format | |||||
| /// \param[out] out std::vector of Tensors | |||||
| /// \return Status error code | |||||
| Status GetNextAsVector(std::vector<TensorPtr> *out); | |||||
| /// Returns the next row in as a map | |||||
| /// \param[out] out std::map of string to Tensor | |||||
| /// \return Status error code | |||||
| Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out); | |||||
| protected: | |||||
| /// Method to return the name of the consumer | |||||
| /// \return string | |||||
| std::string Name() override { return "IteratorConsumer"; } | |||||
| private: | |||||
| int32_t num_epochs_; | |||||
| }; | |||||
| /// Consumer that iterates over the dataset and writes it to desk | |||||
| class SaveToDesk : public TreeConsumer { | |||||
| public: | |||||
| /// Constructor which will call the base class default constructor. | |||||
| /// \param dataset_path path the the dataset | |||||
| /// \param num_files number of files. Default to 1 | |||||
| /// \param dataset_type The format of the dataset. Default to "mindrecod". | |||||
| explicit SaveToDesk(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") | |||||
| : TreeConsumer(), dataset_path_(dataset_path), num_files_(num_files), dataset_type_(dataset_type) {} | |||||
| /// Save the given dataset to MindRecord format on desk. This is a blocking method (i.e., after returning, all rows | |||||
| /// would be written to desk) | |||||
| /// \return Status error code | |||||
| Status Save() { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); } | |||||
| private: | |||||
| std::string dataset_path_; | |||||
| int32_t num_files_; | |||||
| std::string dataset_type_; | |||||
| }; | |||||
| /// Consumer that iterates over the dataset and send it to a device | |||||
| class ToDevice : public TreeConsumer { | |||||
| public: | |||||
| ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs) | |||||
| : TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | |||||
| Status Init(std::shared_ptr<api::Dataset> d) override; | |||||
| Status Send() { | |||||
| // TODO(CRC): launch the tree | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status Stop() { | |||||
| // TODO(CRC): Get root + call StopSend | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status Continue() { | |||||
| // TODO(CRC): Get root + call StopSend | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| private: | |||||
| std::string device_type_; | |||||
| bool send_epoch_end_; | |||||
| int32_t num_epochs_; | |||||
| }; | |||||
| /// Consumer that is used to get some pipeline information | |||||
| class TreeGetters : public TreeConsumer { | |||||
| Status GetDatasetSize(int32_t *size) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetBatchSize(int32_t *batch_size) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetRepeatCount(int32_t *repeat_count) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetNumClasses(int32_t *num_classes) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetOutputShapes(std::vector<TensorShape> *shapes) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetOutputTypes(std::vector<DataType> *types) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetOutputNames(std::vector<std::string> *names) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| }; | |||||
| } // namespace mindspore::dataset | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ | |||||
| @@ -0,0 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/runtime_context.h" | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| namespace mindspore::dataset { | |||||
| void RuntimeContext::AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer) { | |||||
| tree_consumer_ = std::move(tree_consumer); | |||||
| } | |||||
| } // namespace mindspore::dataset | |||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include "minddata/dataset/core/client.h" | |||||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||||
| namespace mindspore::dataset { | |||||
| class TreeConsumer; | |||||
| /// Class the represents single runtime instance which can consume data from a data pipeline | |||||
| class RuntimeContext { | |||||
| public: | |||||
| /// Default constructor | |||||
| RuntimeContext() = default; | |||||
| /// Initialize the runtime, for now we just call the global init | |||||
| /// \return Status error code | |||||
| Status Init() { return GlobalInit(); } | |||||
| /// Method to terminate the runtime, this will not release the resources | |||||
| /// \return Status error code | |||||
| virtual Status Terminate() { return Status::OK(); } | |||||
| /// Set the tree consumer | |||||
| /// \param tree_consumer to be assigned | |||||
| void AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer); | |||||
| /// Get the tree consumer | |||||
| /// \return Raw pointer to the tree consumer. | |||||
| TreeConsumer *GetConsumer() { return tree_consumer_.get(); } | |||||
| private: | |||||
| std::unique_ptr<TreeConsumer> tree_consumer_; | |||||
| }; | |||||
| } // namespace mindspore::dataset | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ | |||||
| @@ -26,8 +26,6 @@ Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::Dataset> root_ir, int32 | |||||
| // Check whether this function has been called before. If so, return fail | // Check whether this function has been called before. If so, return fail | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built."); | CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built."); | ||||
| RETURN_UNEXPECTED_IF_NULL(root_ir); | RETURN_UNEXPECTED_IF_NULL(root_ir); | ||||
| // GlobalInit, might need to be moved to the proper place once RuntimeConext is complete | |||||
| RETURN_IF_NOT_OK(GlobalInit()); | |||||
| // this will evolve in the long run | // this will evolve in the long run | ||||
| tree_ = std::make_unique<ExecutionTree>(); | tree_ = std::make_unique<ExecutionTree>(); | ||||
| @@ -28,6 +28,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace api { | |||||
| class Dataset; | |||||
| } | |||||
| class TreeAdapter { | class TreeAdapter { | ||||
| public: | public: | ||||
| TreeAdapter() = default; | TreeAdapter() = default; | ||||
| @@ -52,6 +52,8 @@ class Vocab; | |||||
| #endif | #endif | ||||
| namespace api { | namespace api { | ||||
| class Dataset; | |||||
| class Iterator; | |||||
| class TensorOperation; | class TensorOperation; | ||||
| class SchemaObj; | class SchemaObj; | ||||
| @@ -17,10 +17,11 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ | ||||
| #include <unordered_map> | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/runtime_context.h" | |||||
| #include "minddata/dataset/include/status.h" | #include "minddata/dataset/include/status.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -32,6 +33,8 @@ class DatasetIterator; | |||||
| class DatasetOp; | class DatasetOp; | ||||
| class Tensor; | class Tensor; | ||||
| class RuntimeContext; | |||||
| class IteratorConsumer; | |||||
| namespace api { | namespace api { | ||||
| class Dataset; | class Dataset; | ||||
| @@ -43,7 +46,7 @@ using TensorVec = std::vector<std::shared_ptr<Tensor>>; | |||||
| class Iterator { | class Iterator { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Iterator() = default; | |||||
| Iterator() : consumer_(nullptr) {} | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Iterator() = default; | ~Iterator() = default; | ||||
| @@ -111,12 +114,8 @@ class Iterator { | |||||
| _Iterator end() { return _Iterator(nullptr); } | _Iterator end() { return _Iterator(nullptr); } | ||||
| private: | private: | ||||
| // Runtime tree. | |||||
| // Use shared_ptr instead of unique_ptr because the DatasetIterator constructor takes in a shared_ptr type. | |||||
| std::shared_ptr<ExecutionTree> tree_; | |||||
| // Runtime iterator | |||||
| std::unique_ptr<DatasetIterator> iterator_; | |||||
| std::unique_ptr<RuntimeContext> runtime_context; | |||||
| IteratorConsumer *consumer_; | |||||
| }; | }; | ||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -89,6 +89,7 @@ enum class StatusCode : char { | |||||
| kTimeOut = 14, | kTimeOut = 14, | ||||
| kBuddySpaceFull = 15, | kBuddySpaceFull = 15, | ||||
| kNetWorkError = 16, | kNetWorkError = 16, | ||||
| kNotImplementedYet = 17, | |||||
| // Make this error code the last one. Add new error code above it. | // Make this error code the last one. Add new error code above it. | ||||
| kUnexpectedError = 127 | kUnexpectedError = 127 | ||||
| }; | }; | ||||