You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tree_adapter.h 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_ADAPTER_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_ADAPTER_H_
  18. #include <memory>
  19. #include <string>
  20. #include <unordered_map>
  21. #include <utility>
  22. #include <vector>
  23. #include "minddata/dataset/engine/execution_tree.h"
  24. #include "minddata/dataset/include/datasets.h"
  25. namespace mindspore {
  26. namespace dataset {
  // Forward declaration only: this header refers to api::Dataset solely through std::shared_ptr.
  namespace api { class Dataset; }  // namespace api
  30. class TreeAdapter {
  31. public:
  32. TreeAdapter() = default;
  33. ~TreeAdapter() = default;
  34. // This will construct an ExeTree from a Dataset root and Prepare() the ExeTree
  35. // This function is only meant to be called once and needs to be called before GetNext
  36. // ExeTree will be launched when the first GetNext is called
  37. Status BuildAndPrepare(std::shared_ptr<api::Dataset> root, int32_t num_epoch = -1);
  38. // This is the main method TreeConsumer uses to interact with TreeAdapter
  39. // 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared)
  40. // 2. GetNext will return empty row when eoe/eof is obtained
  41. Status GetNext(TensorRow *);
  42. // This function will return the root of the execution tree.
  43. std::weak_ptr<DatasetOp> GetRoot() { return tree_ != nullptr ? tree_->root() : nullptr; }
  44. // This function will return the column_name_map once BuildAndPrepare() is called
  45. std::unordered_map<std::string, int32_t> GetColumnNameMap() const { return column_name_map_; }
  46. // This function returns the TaskGroup associated with ExeTree. This is needed by DeviceQueueConsumer
  47. // to be able to launch a thread. BuildAndPrepare needs to be called before this function
  48. TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; }
  49. std::shared_ptr<DatasetOp> root() { return tree_->root(); }
  50. Status Launch() const { return tree_->Launch(); }
  51. private:
  52. // This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In
  53. // such case, the first node is returned. Op is added as child when the current function returns.
  54. Status DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_ptr<DatasetOp> *op);
  55. std::unique_ptr<DataBuffer> cur_db_;
  56. std::unordered_map<std::string, int32_t> column_name_map_;
  57. std::unique_ptr<ExecutionTree> tree_;
  58. };
  59. } // namespace dataset
  60. } // namespace mindspore
  61. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TREE_ADAPTER_H_