diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h
index 1d7f2b97f3..ff451b0929 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h
@@ -134,6 +134,8 @@ class GeneratorOp : public PipelineOp {
   // @return Name of the current Op
   std::string Name() const override { return "GeneratorOp"; }
 
+  Status Init();
+
  private:
   py::function generator_function_;
   std::vector<std::string> column_names_;
@@ -146,8 +148,6 @@ class GeneratorOp : public PipelineOp {
 
   WaitPost wp_;
 
-  Status Init();
-
   void Dealloc() noexcept;
 
   Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row);
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt
index 43f35638f2..244d6586d0 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt
@@ -1,4 +1,15 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(engine-ir-datasetops-source OBJECT
-        image_folder_node.cc)
\ No newline at end of file
+
+set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
+        image_folder_node.cc
+        )
+
+if (ENABLE_PYTHON)
+    set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
+            ${DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES}
+            generator_node.cc
+            )
+endif ()
+
+add_library(engine-ir-datasetops-source OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES})
\ No newline at end of file
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
new file mode 100644
index 0000000000..c60eff1486
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/engine/datasetops/source/generator_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+namespace api {
+GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
+                             const std::vector<DataType> &column_types)
+    : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {}
+
+std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+  // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
+  // GeneratorOp internally. Here it is given a zero which is the default in generator builder
+  std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0,
+                                                                  rows_per_buffer_, connector_que_size_);
+
+  // Init() is called by the builder when the generator is built. Here, since we are moving away from the builder
+  // class, Init() needs to be called when the op is built. The caveat is that Init() had to be made public (it was
+  // previously private). It can be made private again once we move Init() into Generator's functor. However, that is
+  // a bigger change which is best delivered once the test cases for this API are ready.
+  Status rc = op->Init();
+
+  if (rc.IsOk()) {
+    node_ops.push_back(op);
+  } else {
+    MS_LOG(ERROR) << "Fail to Init GeneratorOp : " << rc.ToString();
+  }
+
+  return node_ops;
+}
+
+// No validation is needed for the generator op.
+Status GeneratorNode::ValidateParams() { return Status::OK(); }
+}  // namespace api
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
new file mode 100644
index 0000000000..ae52020289
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GENERATOR_NODE_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GENERATOR_NODE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/include/datasets.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+namespace api {
+
+/// \class GeneratorNode
+/// \brief A Dataset derived class to represent GeneratorNode dataset
+class GeneratorNode : public Dataset {
+ public:
+  /// \brief Constructor
+  GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
+                const std::vector<DataType> &column_types);
+
+  /// \brief Destructor
+  ~GeneratorNode() = default;
+
+  /// \brief A base class override function to create the required runtime dataset op objects for this class
+  /// \return The list of shared pointers to the newly created DatasetOps
+  std::vector<std::shared_ptr<DatasetOp>> Build() override;
+
+  /// \brief Parameters validation
+  /// \return Status Status::OK() if all the parameters are valid
+  Status ValidateParams() override;
+
+ private:
+  py::function generator_function_;
+  std::vector<std::string> column_names_;
+  std::vector<DataType> column_types_;
+};
+}  // namespace api
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GENERATOR_NODE_H_