You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bindings.cc 5.6 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pybind11/pybind11.h"
  17. #include "pybind11/stl.h"
  18. #include "pybind11/stl_bind.h"
  19. #include "minddata/dataset/api/python/pybind_register.h"
  20. #include "minddata/dataset/api/python/de_pipeline.h"
  21. namespace mindspore {
  22. namespace dataset {
  23. PYBIND_REGISTER(
  24. DEPipeline, 0, ([](const py::module *m) {
  25. (void)py::class_<DEPipeline>(*m, "DEPipeline")
  26. .def(py::init<>())
  27. .def(
  28. "AddNodeToTree",
  29. [](DEPipeline &de, const OpName &op_name, const py::dict &args) {
  30. py::dict out;
  31. THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
  32. return out;
  33. },
  34. py::return_value_policy::reference)
  35. .def_static("AddChildToParentNode",
  36. [](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
  37. THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
  38. })
  39. .def("AssignRootNode",
  40. [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
  41. .def("SetBatchParameters",
  42. [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
  43. .def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
  44. .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
  45. .def("GetColumnNames",
  46. [](DEPipeline &de) {
  47. py::list out;
  48. THROW_IF_ERROR(de.GetColumnNames(&out));
  49. return out;
  50. })
  51. .def("GetNextAsMap",
  52. [](DEPipeline &de) {
  53. py::dict out;
  54. THROW_IF_ERROR(de.GetNextAsMap(&out));
  55. return out;
  56. })
  57. .def("GetNextAsList",
  58. [](DEPipeline &de) {
  59. py::list out;
  60. THROW_IF_ERROR(de.GetNextAsList(&out));
  61. return out;
  62. })
  63. .def("GetOutputShapes",
  64. [](DEPipeline &de) {
  65. py::list out;
  66. THROW_IF_ERROR(de.GetOutputShapes(&out));
  67. return out;
  68. })
  69. .def("GetOutputTypes",
  70. [](DEPipeline &de) {
  71. py::list out;
  72. THROW_IF_ERROR(de.GetOutputTypes(&out));
  73. return out;
  74. })
  75. .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
  76. .def("GetBatchSize", &DEPipeline::GetBatchSize)
  77. .def("GetNumClasses", &DEPipeline::GetNumClasses)
  78. .def("GetRepeatCount", &DEPipeline::GetRepeatCount)
  79. .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
  80. .def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
  81. .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
  82. THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
  83. return true;
  84. });
  85. }));
  86. PYBIND_REGISTER(OpName, 0, ([](const py::module *m) {
  87. (void)py::enum_<OpName>(*m, "OpName", py::arithmetic())
  88. .value("SHUFFLE", OpName::kShuffle)
  89. .value("BATCH", OpName::kBatch)
  90. .value("BUCKETBATCH", OpName::kBucketBatch)
  91. .value("BARRIER", OpName::kBarrier)
  92. .value("MINDRECORD", OpName::kMindrecord)
  93. .value("CACHE", OpName::kCache)
  94. .value("REPEAT", OpName::kRepeat)
  95. .value("SKIP", OpName::kSkip)
  96. .value("TAKE", OpName::kTake)
  97. .value("ZIP", OpName::kZip)
  98. .value("CONCAT", OpName::kConcat)
  99. .value("MAP", OpName::kMap)
  100. .value("FILTER", OpName::kFilter)
  101. .value("DEVICEQUEUE", OpName::kDeviceQueue)
  102. .value("GENERATOR", OpName::kGenerator)
  103. .export_values()
  104. .value("RENAME", OpName::kRename)
  105. .value("TFREADER", OpName::kTfReader)
  106. .value("PROJECT", OpName::kProject)
  107. .value("IMAGEFOLDER", OpName::kImageFolder)
  108. .value("MNIST", OpName::kMnist)
  109. .value("MANIFEST", OpName::kManifest)
  110. .value("VOC", OpName::kVoc)
  111. .value("COCO", OpName::kCoco)
  112. .value("CIFAR10", OpName::kCifar10)
  113. .value("CIFAR100", OpName::kCifar100)
  114. .value("RANDOMDATA", OpName::kRandomData)
  115. .value("BUILDVOCAB", OpName::kBuildVocab)
  116. .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
  117. .value("CELEBA", OpName::kCelebA)
  118. .value("TEXTFILE", OpName::kTextFile)
  119. .value("EPOCHCTRL", OpName::kEpochCtrl)
  120. .value("CSV", OpName::kCsv)
  121. .value("CLUE", OpName::kClue);
  122. }));
  123. } // namespace dataset
  124. } // namespace mindspore