You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

iterator.cc 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/iterator.h"
  17. #include "minddata/dataset/core/client.h"
  18. #include "minddata/dataset/engine/consumers/tree_consumer.h"
  19. #include "minddata/dataset/engine/runtime_context.h"
  20. #include "minddata/dataset/include/datasets.h"
  21. namespace mindspore {
  22. namespace dataset {
  23. Iterator::Iterator() : consumer_(nullptr) {}
  24. Iterator::~Iterator() { Stop(); }
  25. // Get the next row from the data pipeline.
  26. bool Iterator::GetNextRow(TensorMap *row) {
  27. Status rc = consumer_->GetNextAsMap(row);
  28. if (rc.IsError()) {
  29. MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
  30. row->clear();
  31. return false;
  32. }
  33. return true;
  34. }
  35. // Get the next row from the data pipeline.
  36. bool Iterator::GetNextRow(TensorVec *row) {
  37. Status rc = consumer_->GetNextAsVector(row);
  38. if (rc.IsError()) {
  39. MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
  40. row->clear();
  41. return false;
  42. }
  43. return true;
  44. }
  // Shut down the data pipeline.
  // The destructor always calls this, but runtime_context_ is only created by
  // BuildAndLaunchTree(); when no context exists there is nothing to terminate,
  // so return instead of dereferencing a null pointer.
  // Termination errors are logged and not propagated.
  void Iterator::Stop() {
    if (runtime_context_ == nullptr) {
      return;
    }
    Status rc = runtime_context_->Terminate();
    if (rc.IsError()) {
      MS_LOG(ERROR) << rc.ToString();
    }
  }
  52. // Function to build and launch the execution tree.
  53. Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
  54. runtime_context_ = std::make_unique<NativeRuntimeContext>();
  55. RETURN_IF_NOT_OK(runtime_context_->Init());
  56. auto consumer = std::make_unique<IteratorConsumer>(num_epochs);
  57. consumer_ = consumer.get();
  58. RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
  59. runtime_context_->AssignConsumer(std::move(consumer));
  60. return Status::OK();
  61. }
  62. } // namespace dataset
  63. } // namespace mindspore