/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"

#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/api/python/de_pipeline.h"

namespace mindspore {
namespace dataset {

// Python bindings for DEPipeline.
// Most methods are wrapped in a lambda that forwards to the C++ method and passes
// the returned status to THROW_IF_ERROR, so failures surface on the Python side
// instead of being silently dropped. Methods that fill an out-parameter create a
// fresh Python container, let the C++ side populate it, and return it by value.
PYBIND_REGISTER(
  DEPipeline, 0, ([](const py::module *m) {
    (void)py::class_<DEPipeline>(*m, "DEPipeline")
      .def(py::init<>())
      // The returned py::dict is a Python object returned by value; pybind11
      // ignores return_value_policy for Python object types, so none is given
      // (a `reference` policy here only suggested a reference to a dying local).
      .def("AddNodeToTree",
           [](DEPipeline &de, const OpName &op_name, const py::dict &args) {
             py::dict out;
             THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
             return out;
           })
      .def_static("AddChildToParentNode",
                  [](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
                    THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
                  })
      .def("AssignRootNode",
           [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
      .def("SetBatchParameters",
           [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
      .def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
      .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
      .def("GetColumnNames",
           [](DEPipeline &de) {
             py::list out;
             THROW_IF_ERROR(de.GetColumnNames(&out));
             return out;
           })
      .def("GetNextAsMap",
           [](DEPipeline &de) {
             py::dict out;
             THROW_IF_ERROR(de.GetNextAsMap(&out));
             return out;
           })
      .def("GetNextAsList",
           [](DEPipeline &de) {
             py::list out;
             THROW_IF_ERROR(de.GetNextAsList(&out));
             return out;
           })
      .def("GetOutputShapes",
           [](DEPipeline &de) {
             py::list out;
             THROW_IF_ERROR(de.GetOutputShapes(&out));
             return out;
           })
      .def("GetOutputTypes",
           [](DEPipeline &de) {
             py::list out;
             THROW_IF_ERROR(de.GetOutputTypes(&out));
             return out;
           })
      // Bound directly to the member functions; their results are not passed
      // through THROW_IF_ERROR.
      .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
      .def("GetBatchSize", &DEPipeline::GetBatchSize)
      .def("GetNumClasses", &DEPipeline::GetNumClasses)
      .def("GetRepeatCount", &DEPipeline::GetRepeatCount)
      .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
      .def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
      // Returns true when saving succeeded; a failure raises instead.
      .def("SaveDataset",
           [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
             THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
             return true;
           });
  }));

// Python binding for the OpName enum.
// py::arithmetic() enables arithmetic/bitwise operators on the enum values.
// export_values() copies only the entries registered *before* it into the
// enclosing module scope, so it must come last for every value to be exported
// consistently (it previously sat after GENERATOR, leaving RENAME..CLUE
// reachable only as OpName.<NAME>).
PYBIND_REGISTER(OpName, 0, ([](const py::module *m) {
                  (void)py::enum_<OpName>(*m, "OpName", py::arithmetic())
                    .value("SHUFFLE", OpName::kShuffle)
                    .value("BATCH", OpName::kBatch)
                    .value("BUCKETBATCH", OpName::kBucketBatch)
                    .value("BARRIER", OpName::kBarrier)
                    .value("MINDRECORD", OpName::kMindrecord)
                    .value("CACHE", OpName::kCache)
                    .value("REPEAT", OpName::kRepeat)
                    .value("SKIP", OpName::kSkip)
                    .value("TAKE", OpName::kTake)
                    .value("ZIP", OpName::kZip)
                    .value("CONCAT", OpName::kConcat)
                    .value("MAP", OpName::kMap)
                    .value("FILTER", OpName::kFilter)
                    .value("DEVICEQUEUE", OpName::kDeviceQueue)
                    .value("GENERATOR", OpName::kGenerator)
                    .value("RENAME", OpName::kRename)
                    .value("TFREADER", OpName::kTfReader)
                    .value("PROJECT", OpName::kProject)
                    .value("IMAGEFOLDER", OpName::kImageFolder)
                    .value("MNIST", OpName::kMnist)
                    .value("MANIFEST", OpName::kManifest)
                    .value("VOC", OpName::kVoc)
                    .value("COCO", OpName::kCoco)
                    .value("CIFAR10", OpName::kCifar10)
                    .value("CIFAR100", OpName::kCifar100)
                    .value("RANDOMDATA", OpName::kRandomData)
                    .value("BUILDVOCAB", OpName::kBuildVocab)
                    .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
                    .value("CELEBA", OpName::kCelebA)
                    .value("TEXTFILE", OpName::kTextFile)
                    .value("EPOCHCTRL", OpName::kEpochCtrl)
                    .value("CSV", OpName::kCsv)
                    .value("CLUE", OpName::kClue)
                    .export_values();
                }));

}  // namespace dataset
}  // namespace mindspore