You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

py_func_op.cc 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/kernels/py_func_op.h"
  17. #include <memory>
  18. #include <vector>
  19. #include "dataset/core/tensor.h"
  20. #include "dataset/kernels/tensor_op.h"
  21. #include "dataset/util/status.h"
  22. namespace mindspore {
  23. namespace dataset {
  24. Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
  25. IO_CHECK_VECTOR(input, output);
  26. Status ret = Status(StatusCode::kOK, "PyFunc Call Succeed");
  27. {
  28. // Acquire Python GIL
  29. py::gil_scoped_acquire gil_acquire;
  30. if (Py_IsInitialized() == 0) {
  31. ret = Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  32. goto ComputeReturn;
  33. }
  34. try {
  35. // Transform input tensor vector into numpy array vector
  36. py::tuple input_args(input.size());
  37. for (size_t i = 0; i < input.size(); i++) {
  38. py::array new_data;
  39. RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
  40. // possible memcpy here
  41. input_args[i] = new_data;
  42. }
  43. // Invoke python function
  44. py::object ret_py_obj = this->py_func_ptr_(*input_args);
  45. // Process the return value
  46. if (py::isinstance<py::array>(ret_py_obj)) {
  47. // In case of a n-1 mapping, the return value will be a numpy array
  48. std::shared_ptr<Tensor> out;
  49. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast<py::array>()));
  50. output->push_back(out);
  51. } else if (py::isinstance<py::tuple>(ret_py_obj)) {
  52. // In case of a n-m mapping, the return value will be a tuple of numpy arrays
  53. py::tuple ret_py_tuple = ret_py_obj.cast<py::tuple>();
  54. // Iterate over two containers simultaneously for memory copy
  55. for (size_t i = 0; i < ret_py_tuple.size(); i++) {
  56. py::object ret_py_ele = ret_py_tuple[i];
  57. if (!py::isinstance<py::array>(ret_py_ele)) {
  58. goto ShapeMisMatch;
  59. }
  60. std::shared_ptr<Tensor> out;
  61. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast<py::array>()));
  62. output->push_back(out);
  63. }
  64. } else {
  65. goto ShapeMisMatch;
  66. }
  67. } catch (const py::error_already_set &e) {
  68. ret = Status(StatusCode::kPyFuncException, e.what());
  69. }
  70. }
  71. ComputeReturn:
  72. return ret;
  73. ShapeMisMatch:
  74. ret = Status(StatusCode::kShapeMisMatch, "PyFunc should return a numpy array or a numpy array tuple");
  75. goto ComputeReturn;
  76. }
  77. } // namespace dataset
  78. } // namespace mindspore