You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

py_func_op.cc 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/kernels/py_func_op.h"
  17. #include <memory>
  18. #include <vector>
  19. #include "dataset/core/tensor.h"
  20. #include "dataset/kernels/tensor_op.h"
  21. #include "dataset/util/status.h"
  22. namespace mindspore {
  23. namespace dataset {
  24. Status PyFuncOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
  25. std::vector<std::shared_ptr<Tensor>> *output) {
  26. IO_CHECK_VECTOR(input, output);
  27. Status ret = Status(StatusCode::kOK, "PyFunc Call Succeed");
  28. {
  29. // Acquire Python GIL
  30. py::gil_scoped_acquire gil_acquire;
  31. if (Py_IsInitialized() == 0) {
  32. ret = Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  33. goto ComputeReturn;
  34. }
  35. try {
  36. // Transform input tensor vector into numpy array vector
  37. py::tuple input_args(input.size());
  38. for (size_t i = 0; i < input.size(); i++) {
  39. py::array new_data;
  40. RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
  41. // possible memcpy here
  42. input_args[i] = new_data;
  43. }
  44. // Invoke python function
  45. py::object ret_py_obj = this->py_func_ptr_(*input_args);
  46. // Process the return value
  47. if (py::isinstance<py::array>(ret_py_obj)) {
  48. // In case of a n-1 mapping, the return value will be a numpy array
  49. std::shared_ptr<Tensor> out;
  50. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast<py::array>()));
  51. output->push_back(out);
  52. } else if (py::isinstance<py::tuple>(ret_py_obj)) {
  53. // In case of a n-m mapping, the return value will be a tuple of numpy arrays
  54. py::tuple ret_py_tuple = ret_py_obj.cast<py::tuple>();
  55. // Iterate over two containers simultaneously for memory copy
  56. for (size_t i = 0; i < ret_py_tuple.size(); i++) {
  57. py::object ret_py_ele = ret_py_tuple[i];
  58. if (!py::isinstance<py::array>(ret_py_ele)) {
  59. goto ShapeMisMatch;
  60. }
  61. std::shared_ptr<Tensor> out;
  62. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast<py::array>()));
  63. output->push_back(out);
  64. }
  65. } else {
  66. goto ShapeMisMatch;
  67. }
  68. } catch (const py::error_already_set &e) {
  69. ret = Status(StatusCode::kPyFuncException, e.what());
  70. }
  71. }
  72. ComputeReturn:
  73. return ret;
  74. ShapeMisMatch:
  75. ret = Status(StatusCode::kShapeMisMatch, "PyFunc should return a numpy array or a numpy array tuple");
  76. goto ComputeReturn;
  77. }
  78. } // namespace dataset
  79. } // namespace mindspore