|
- /**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "dataset/kernels/py_func_op.h"
-
- #include <memory>
- #include <vector>
-
- #include "dataset/core/tensor.h"
- #include "dataset/kernels/tensor_op.h"
- #include "dataset/util/status.h"
-
- namespace mindspore {
- namespace dataset {
- Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
- IO_CHECK_VECTOR(input, output);
- Status ret = Status(StatusCode::kOK, "PyFunc Call Succeed");
- {
- // Acquire Python GIL
- py::gil_scoped_acquire gil_acquire;
- if (Py_IsInitialized() == 0) {
- ret = Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
- goto ComputeReturn;
- }
- try {
- // Transform input tensor vector into numpy array vector
- py::tuple input_args(input.size());
- for (size_t i = 0; i < input.size(); i++) {
- py::array new_data;
- RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
- // possible memcpy here
- input_args[i] = new_data;
- }
- // Invoke python function
- py::object ret_py_obj = this->py_func_ptr_(*input_args);
- // Process the return value
- if (py::isinstance<py::array>(ret_py_obj)) {
- // In case of a n-1 mapping, the return value will be a numpy array
- std::shared_ptr<Tensor> out;
- RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast<py::array>()));
- output->push_back(out);
- } else if (py::isinstance<py::tuple>(ret_py_obj)) {
- // In case of a n-m mapping, the return value will be a tuple of numpy arrays
- py::tuple ret_py_tuple = ret_py_obj.cast<py::tuple>();
- // Iterate over two containers simultaneously for memory copy
- for (size_t i = 0; i < ret_py_tuple.size(); i++) {
- py::object ret_py_ele = ret_py_tuple[i];
- if (!py::isinstance<py::array>(ret_py_ele)) {
- goto ShapeMisMatch;
- }
- std::shared_ptr<Tensor> out;
- RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast<py::array>()));
- output->push_back(out);
- }
- } else {
- goto ShapeMisMatch;
- }
- } catch (const py::error_already_set &e) {
- ret = Status(StatusCode::kPyFuncException, e.what());
- }
- }
-
- ComputeReturn:
- return ret;
-
- ShapeMisMatch:
- ret = Status(StatusCode::kShapeMisMatch, "PyFunc should return a numpy array or a numpy array tuple");
- goto ComputeReturn;
- }
- } // namespace dataset
- } // namespace mindspore
|