You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

iterator.cc 4.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/iterator.h"
  17. #include "minddata/dataset/core/client.h"
  18. #include "minddata/dataset/include/datasets.h"
  19. namespace mindspore {
  20. namespace dataset {
  21. namespace api {
  22. // Get the next row from the data pipeline.
  23. bool Iterator::GetNextRow(TensorMap *row) {
  24. Status rc = iterator_->GetNextAsMap(row);
  25. if (rc.IsError()) {
  26. MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
  27. row->clear();
  28. return false;
  29. }
  30. return true;
  31. }
  32. // Get the next row from the data pipeline.
  33. bool Iterator::GetNextRow(TensorVec *row) {
  34. TensorRow tensor_row;
  35. Status rc = iterator_->FetchNextTensorRow(&tensor_row);
  36. if (rc.IsError()) {
  37. MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
  38. row->clear();
  39. return false;
  40. }
  41. // Generate a vector as return
  42. row->clear();
  43. std::copy(tensor_row.begin(), tensor_row.end(), std::back_inserter(*row));
  44. return true;
  45. }
  46. // Shut down the data pipeline.
  47. void Iterator::Stop() {
  48. // Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_.
  49. iterator_.reset();
  50. // Release ownership of tree_ shared pointer. This will decrement the ref count.
  51. tree_.reset();
  52. }
  53. // Function to build and launch the execution tree.
  54. Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
  55. // One time init
  56. Status rc;
  57. rc = GlobalInit();
  58. RETURN_IF_NOT_OK(rc);
  59. // Instantiate the execution tree
  60. tree_ = std::make_shared<ExecutionTree>();
  61. // Iterative BFS converting Dataset tree into runtime Execution tree.
  62. std::queue<std::pair<std::shared_ptr<Dataset>, std::shared_ptr<DatasetOp>>> q;
  63. if (ds == nullptr) {
  64. RETURN_STATUS_UNEXPECTED("Input is null pointer");
  65. } else {
  66. // Convert the current root node.
  67. auto root_ops = ds->Build();
  68. if (root_ops.empty()) {
  69. RETURN_STATUS_UNEXPECTED("Node operation returned nothing");
  70. }
  71. // Iterate through all the DatasetOps returned by Dataset's Build(), associate them
  72. // with the execution tree and add the child and parent relationship between the nodes
  73. // Note that some Dataset objects might return more than one DatasetOps
  74. // e.g. MapDataset will return [ProjectOp, MapOp] if project_columns is set for MapDataset
  75. std::shared_ptr<DatasetOp> prev_op = nullptr;
  76. for (auto op : root_ops) {
  77. RETURN_IF_NOT_OK(tree_->AssociateNode(op));
  78. if (prev_op != nullptr) {
  79. RETURN_IF_NOT_OK(prev_op->AddChild(op));
  80. }
  81. prev_op = op;
  82. }
  83. // Add the last DatasetOp to the queue to be BFS.
  84. q.push(std::make_pair(ds, root_ops.back()));
  85. // Traverse down to the children and convert them to the corresponding DatasetOps (i.e. execution tree nodes)
  86. while (!q.empty()) {
  87. auto node_pair = q.front();
  88. q.pop();
  89. // Iterate through all the direct children of the first element in our BFS queue
  90. for (auto child : node_pair.first->children) {
  91. auto child_ops = child->Build();
  92. if (child_ops.empty()) {
  93. RETURN_STATUS_UNEXPECTED("Node operation returned nothing");
  94. }
  95. auto node_op = node_pair.second;
  96. // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them
  97. // with the execution tree and add the child and parent relationship between the nodes
  98. // Note that some Dataset objects might return more than one DatasetOps
  99. // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset
  100. for (auto child_op : child_ops) {
  101. RETURN_IF_NOT_OK(tree_->AssociateNode(child_op));
  102. RETURN_IF_NOT_OK(node_op->AddChild(child_op));
  103. node_op = child_op;
  104. }
  105. // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current
  106. // execution tree) to the BFS queue
  107. q.push(std::make_pair(child, child_ops.back()));
  108. }
  109. }
  110. RETURN_IF_NOT_OK(tree_->AssignRoot(root_ops.front()));
  111. }
  112. // Launch the execution tree.
  113. RETURN_IF_NOT_OK(tree_->Prepare());
  114. tree_->Launch();
  115. iterator_ = std::make_unique<DatasetIterator>(tree_);
  116. RETURN_UNEXPECTED_IF_NULL(iterator_);
  117. return rc;
  118. }
  119. } // namespace api
  120. } // namespace dataset
  121. } // namespace mindspore