Merge pull request !7555 from h.farahat/consumerstags/v1.1.0
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "minddata/dataset/include/iterator.h" | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| namespace mindspore { | |||
| @@ -23,7 +24,7 @@ namespace api { | |||
| // Get the next row from the data pipeline. | |||
| bool Iterator::GetNextRow(TensorMap *row) { | |||
| Status rc = iterator_->GetNextAsMap(row); | |||
| Status rc = consumer_->GetNextAsMap(row); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; | |||
| row->clear(); | |||
| @@ -34,100 +35,27 @@ bool Iterator::GetNextRow(TensorMap *row) { | |||
| // Get the next row from the data pipeline. | |||
| bool Iterator::GetNextRow(TensorVec *row) { | |||
| TensorRow tensor_row; | |||
| Status rc = iterator_->FetchNextTensorRow(&tensor_row); | |||
| Status rc = consumer_->GetNextAsVector(row); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc; | |||
| row->clear(); | |||
| return false; | |||
| } | |||
| // Generate a vector as return | |||
| row->clear(); | |||
| std::copy(tensor_row.begin(), tensor_row.end(), std::back_inserter(*row)); | |||
| return true; | |||
| } | |||
| // Shut down the data pipeline. | |||
| void Iterator::Stop() { | |||
| // Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_. | |||
| iterator_.reset(); | |||
| // Release ownership of tree_ shared pointer. This will decrement the ref count. | |||
| tree_.reset(); | |||
| } | |||
| // Function to build and launch the execution tree. | |||
| void Iterator::Stop() { runtime_context->Terminate(); } | |||
| // | |||
| //// Function to build and launch the execution tree. | |||
| Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | |||
| // One time init | |||
| Status rc; | |||
| rc = GlobalInit(); | |||
| RETURN_IF_NOT_OK(rc); | |||
| // Instantiate the execution tree | |||
| tree_ = std::make_shared<ExecutionTree>(); | |||
| // Iterative BFS converting Dataset tree into runtime Execution tree. | |||
| std::queue<std::pair<std::shared_ptr<Dataset>, std::shared_ptr<DatasetOp>>> q; | |||
| if (ds == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Input is null pointer"); | |||
| } else { | |||
| // Convert the current root node. | |||
| auto root_ops = ds->Build(); | |||
| if (root_ops.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Node operation returned nothing"); | |||
| } | |||
| // Iterate through all the DatasetOps returned by Dataset's Build(), associate them | |||
| // with the execution tree and add the child and parent relationship between the nodes | |||
| // Note that some Dataset objects might return more than one DatasetOps | |||
| // e.g. MapDataset will return [ProjectOp, MapOp] if project_columns is set for MapDataset | |||
| std::shared_ptr<DatasetOp> prev_op = nullptr; | |||
| for (auto op : root_ops) { | |||
| RETURN_IF_NOT_OK(tree_->AssociateNode(op)); | |||
| if (prev_op != nullptr) { | |||
| RETURN_IF_NOT_OK(prev_op->AddChild(op)); | |||
| } | |||
| prev_op = op; | |||
| } | |||
| // Add the last DatasetOp to the queue to be BFS. | |||
| q.push(std::make_pair(ds, root_ops.back())); | |||
| // Traverse down to the children and convert them to the corresponding DatasetOps (i.e. execution tree nodes) | |||
| while (!q.empty()) { | |||
| auto node_pair = q.front(); | |||
| q.pop(); | |||
| // Iterate through all the direct children of the first element in our BFS queue | |||
| for (auto child : node_pair.first->children) { | |||
| auto child_ops = child->Build(); | |||
| if (child_ops.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Node operation returned nothing"); | |||
| } | |||
| auto node_op = node_pair.second; | |||
| // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them | |||
| // with the execution tree and add the child and parent relationship between the nodes | |||
| // Note that some Dataset objects might return more than one DatasetOps | |||
| // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset | |||
| for (auto child_op : child_ops) { | |||
| RETURN_IF_NOT_OK(tree_->AssociateNode(child_op)); | |||
| RETURN_IF_NOT_OK(node_op->AddChild(child_op)); | |||
| node_op = child_op; | |||
| } | |||
| // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current | |||
| // execution tree) to the BFS queue | |||
| q.push(std::make_pair(child, child_ops.back())); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(tree_->AssignRoot(root_ops.front())); | |||
| } | |||
| // Launch the execution tree. | |||
| RETURN_IF_NOT_OK(tree_->Prepare()); | |||
| tree_->Launch(); | |||
| iterator_ = std::make_unique<DatasetIterator>(tree_); | |||
| RETURN_UNEXPECTED_IF_NULL(iterator_); | |||
| return rc; | |||
| runtime_context = std::make_unique<RuntimeContext>(); | |||
| RETURN_IF_NOT_OK(runtime_context->Init()); | |||
| auto consumer = std::make_unique<IteratorConsumer>(); | |||
| consumer_ = consumer.get(); | |||
| RETURN_IF_NOT_OK(consumer->Init(ds)); | |||
| runtime_context->AssignConsumer(std::move(consumer)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace api | |||
| @@ -12,12 +12,14 @@ endif () | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(engine OBJECT | |||
| execution_tree.cc | |||
| data_buffer.cc | |||
| data_schema.cc | |||
| dataset_iterator.cc | |||
| tree_adapter.cc | |||
| ) | |||
| execution_tree.cc | |||
| data_buffer.cc | |||
| data_schema.cc | |||
| dataset_iterator.cc | |||
| tree_adapter.cc | |||
| runtime_context.cc | |||
| consumers/tree_consumer.cc | |||
| ) | |||
| if (ENABLE_PYTHON) | |||
| target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||
| namespace mindspore::dataset { | |||
| /// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict | |||
| class PythonIterator : public IteratorConsumer { | |||
| /// Constructor | |||
| /// \param num_epochs number of epochs. Default to -1 (infinite epochs). | |||
| explicit PythonIterator(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {} | |||
| /// Get the next row as a python dict | |||
| /// \param[out] output python dict | |||
| /// \return Status error code | |||
| Status GetNextAsMap(py::dict *output) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| /// Get the next row as a python dict | |||
| /// \param[out] output python dict | |||
| /// \return Status error code | |||
| Status GetNextAsList(py::list *output) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_ | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||
| #include "minddata/dataset/engine/tree_adapter.h" | |||
| namespace mindspore::dataset { | |||
| Status IteratorConsumer::GetNextAsVector(std::vector<TensorPtr> *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| out->clear(); | |||
| TensorRow res; | |||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res)); | |||
| // Return empty vector if there's no data | |||
| RETURN_OK_IF_TRUE(res.empty()); | |||
| std::copy(res.begin(), res.end(), std::back_inserter(*out)); | |||
| return Status::OK(); | |||
| } | |||
| Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out_map) { | |||
| RETURN_UNEXPECTED_IF_NULL(out_map); | |||
| out_map->clear(); | |||
| TensorRow res; | |||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&res)); | |||
| // Return empty map if there's no data | |||
| RETURN_OK_IF_TRUE(res.empty()); | |||
| // Populate the out map from the row and return it | |||
| for (const auto &colMap : tree_adapter_->GetColumnNameMap()) { | |||
| (*out_map)[colMap.first] = std::move(res[colMap.second]); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | |||
| Status IteratorConsumer::Init(std::shared_ptr<api::Dataset> d) { | |||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | |||
| } | |||
| Status TreeConsumer::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); } | |||
| Status ToDevice::Init(std::shared_ptr<api::Dataset> d) { | |||
| // TODO(CRC): | |||
| // Get device ID from children look at get_distribution in python | |||
| // Add DeviceQue IR on top of dataset d | |||
| return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); | |||
| } | |||
| } // namespace mindspore::dataset | |||
| @@ -0,0 +1,154 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/tree_adapter.h" | |||
| namespace mindspore::dataset { | |||
| // Forward declare | |||
| class TreeAdapter; | |||
| namespace api { | |||
| class Dataset; | |||
| } | |||
| /// A base class for tree consumers which would fetch rows from the tree pipeline | |||
| class TreeConsumer { | |||
| public: | |||
| /// Constructor that prepares an empty tree_adapter | |||
| TreeConsumer(); | |||
| /// Initializes the consumer, this involves constructing and preparing the tree. | |||
| /// \param d The dataset node that represent the root of the IR tree. | |||
| /// \return Status error code. | |||
| virtual Status Init(std::shared_ptr<api::Dataset> d); | |||
| protected: | |||
| /// The class owns the tree_adapter that handles execution tree operations. | |||
| std::unique_ptr<TreeAdapter> tree_adapter_; | |||
| /// Method to return the name of the consumer | |||
| /// \return string | |||
| virtual std::string Name() = 0; | |||
| }; | |||
| /// Consumer that iterates over the dataset and returns the rows one by one as a vector or a map | |||
| class IteratorConsumer : public TreeConsumer { | |||
| public: | |||
| /// Constructor which will call the base class default constructor. | |||
| /// \param num_epochs number of epochs. Default to -1 (infinite epochs). | |||
| explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} | |||
| Status Init(std::shared_ptr<api::Dataset> d) override; | |||
| /// Returns the next row in a vector format | |||
| /// \param[out] out std::vector of Tensors | |||
| /// \return Status error code | |||
| Status GetNextAsVector(std::vector<TensorPtr> *out); | |||
| /// Returns the next row in as a map | |||
| /// \param[out] out std::map of string to Tensor | |||
| /// \return Status error code | |||
| Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out); | |||
| protected: | |||
| /// Method to return the name of the consumer | |||
| /// \return string | |||
| std::string Name() override { return "IteratorConsumer"; } | |||
| private: | |||
| int32_t num_epochs_; | |||
| }; | |||
| /// Consumer that iterates over the dataset and writes it to desk | |||
| class SaveToDesk : public TreeConsumer { | |||
| public: | |||
| /// Constructor which will call the base class default constructor. | |||
| /// \param dataset_path path the the dataset | |||
| /// \param num_files number of files. Default to 1 | |||
| /// \param dataset_type The format of the dataset. Default to "mindrecod". | |||
| explicit SaveToDesk(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") | |||
| : TreeConsumer(), dataset_path_(dataset_path), num_files_(num_files), dataset_type_(dataset_type) {} | |||
| /// Save the given dataset to MindRecord format on desk. This is a blocking method (i.e., after returning, all rows | |||
| /// would be written to desk) | |||
| /// \return Status error code | |||
| Status Save() { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); } | |||
| private: | |||
| std::string dataset_path_; | |||
| int32_t num_files_; | |||
| std::string dataset_type_; | |||
| }; | |||
| /// Consumer that iterates over the dataset and send it to a device | |||
| class ToDevice : public TreeConsumer { | |||
| public: | |||
| ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs) | |||
| : TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | |||
| Status Init(std::shared_ptr<api::Dataset> d) override; | |||
| Status Send() { | |||
| // TODO(CRC): launch the tree | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status Stop() { | |||
| // TODO(CRC): Get root + call StopSend | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status Continue() { | |||
| // TODO(CRC): Get root + call StopSend | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| private: | |||
| std::string device_type_; | |||
| bool send_epoch_end_; | |||
| int32_t num_epochs_; | |||
| }; | |||
| /// Consumer that is used to get some pipeline information | |||
| class TreeGetters : public TreeConsumer { | |||
| Status GetDatasetSize(int32_t *size) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetBatchSize(int32_t *batch_size) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetRepeatCount(int32_t *repeat_count) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetNumClasses(int32_t *num_classes) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetOutputShapes(std::vector<TensorShape> *shapes) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetOutputTypes(std::vector<DataType> *types) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetOutputNames(std::vector<std::string> *names) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_ | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/runtime_context.h" | |||
| #include <memory> | |||
| #include <utility> | |||
| namespace mindspore::dataset { | |||
| void RuntimeContext::AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer) { | |||
| tree_consumer_ = std::move(tree_consumer); | |||
| } | |||
| } // namespace mindspore::dataset | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||
| namespace mindspore::dataset { | |||
| class TreeConsumer; | |||
| /// Class the represents single runtime instance which can consume data from a data pipeline | |||
| class RuntimeContext { | |||
| public: | |||
| /// Default constructor | |||
| RuntimeContext() = default; | |||
| /// Initialize the runtime, for now we just call the global init | |||
| /// \return Status error code | |||
| Status Init() { return GlobalInit(); } | |||
| /// Method to terminate the runtime, this will not release the resources | |||
| /// \return Status error code | |||
| virtual Status Terminate() { return Status::OK(); } | |||
| /// Set the tree consumer | |||
| /// \param tree_consumer to be assigned | |||
| void AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer); | |||
| /// Get the tree consumer | |||
| /// \return Raw pointer to the tree consumer. | |||
| TreeConsumer *GetConsumer() { return tree_consumer_.get(); } | |||
| private: | |||
| std::unique_ptr<TreeConsumer> tree_consumer_; | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ | |||
| @@ -26,8 +26,6 @@ Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::Dataset> root_ir, int32 | |||
| // Check whether this function has been called before. If so, return fail | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built."); | |||
| RETURN_UNEXPECTED_IF_NULL(root_ir); | |||
| // GlobalInit, might need to be moved to the proper place once RuntimeConext is complete | |||
| RETURN_IF_NOT_OK(GlobalInit()); | |||
| // this will evolve in the long run | |||
| tree_ = std::make_unique<ExecutionTree>(); | |||
| @@ -28,6 +28,9 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace api { | |||
| class Dataset; | |||
| } | |||
| class TreeAdapter { | |||
| public: | |||
| TreeAdapter() = default; | |||
| @@ -52,6 +52,8 @@ class Vocab; | |||
| #endif | |||
| namespace api { | |||
| class Dataset; | |||
| class Iterator; | |||
| class TensorOperation; | |||
| class SchemaObj; | |||
| @@ -17,10 +17,11 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_ | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/runtime_context.h" | |||
| #include "minddata/dataset/include/status.h" | |||
| namespace mindspore { | |||
| @@ -32,6 +33,8 @@ class DatasetIterator; | |||
| class DatasetOp; | |||
| class Tensor; | |||
| class RuntimeContext; | |||
| class IteratorConsumer; | |||
| namespace api { | |||
| class Dataset; | |||
| @@ -43,7 +46,7 @@ using TensorVec = std::vector<std::shared_ptr<Tensor>>; | |||
| class Iterator { | |||
| public: | |||
| /// \brief Constructor | |||
| Iterator() = default; | |||
| Iterator() : consumer_(nullptr) {} | |||
| /// \brief Destructor | |||
| ~Iterator() = default; | |||
| @@ -111,12 +114,8 @@ class Iterator { | |||
| _Iterator end() { return _Iterator(nullptr); } | |||
| private: | |||
| // Runtime tree. | |||
| // Use shared_ptr instead of unique_ptr because the DatasetIterator constructor takes in a shared_ptr type. | |||
| std::shared_ptr<ExecutionTree> tree_; | |||
| // Runtime iterator | |||
| std::unique_ptr<DatasetIterator> iterator_; | |||
| std::unique_ptr<RuntimeContext> runtime_context; | |||
| IteratorConsumer *consumer_; | |||
| }; | |||
| } // namespace api | |||
| } // namespace dataset | |||
| @@ -89,6 +89,7 @@ enum class StatusCode : char { | |||
| kTimeOut = 14, | |||
| kBuddySpaceFull = 15, | |||
| kNetWorkError = 16, | |||
| kNotImplementedYet = 17, | |||
| // Make this error code the last one. Add new error code above it. | |||
| kUnexpectedError = 127 | |||
| }; | |||