You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

py_func_graph_fetcher.h 3.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_
  17. #define TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_
  18. #include <string>
  19. #include <memory>
  20. #include "ir/anf.h"
  21. #include "ir/manager.h"
  22. #include "pipeline/parse/parse_base.h"
  23. #include "pipeline/parse/parse.h"
  24. #include "./common.h"
  25. namespace UT {
  26. void InitPythonPath();
  27. class PyFuncGraphFetcher {
  28. public:
  29. explicit PyFuncGraphFetcher(std::string model_path, bool doResolve = false)
  30. : model_path_(model_path), doResolve_(doResolve) {
  31. InitPythonPath();
  32. }
  33. void SetDoResolve(bool doResolve = true) { doResolve_ = doResolve; }
  34. // The return of python function of "func_name" should be py::function.
  35. // step 1. Call the function user input
  36. // step 2. Parse the return "fn"
  37. template <class... T>
  38. mindspore::FuncGraphPtr CallAndParseRet(std::string func_name, T... args) {
  39. try {
  40. py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...);
  41. mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
  42. if (doResolve_) {
  43. std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
  44. mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
  45. mindspore::parse::ResolveAll(manager);
  46. }
  47. return func_graph;
  48. } catch (py::error_already_set& e) {
  49. MS_LOG(ERROR) << "Call and parse fn failed!!! error:" << e.what();
  50. return nullptr;
  51. } catch (...) {
  52. MS_LOG(ERROR) << "Call fn failed!!!";
  53. return nullptr;
  54. }
  55. }
  56. // Fetch python function then parse to graph
  57. mindspore::FuncGraphPtr operator()(std::string func_name, std::string model_path = "") {
  58. try {
  59. std::string path = model_path_;
  60. if ("" != model_path) {
  61. path = model_path;
  62. }
  63. py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str());
  64. mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
  65. if (doResolve_) {
  66. std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
  67. mindspore::parse::ResolveAll(manager);
  68. }
  69. return func_graph;
  70. } catch (py::error_already_set& e) {
  71. MS_LOG(ERROR) << "get fn failed!!! error:" << e.what();
  72. return nullptr;
  73. } catch (...) {
  74. MS_LOG(ERROR) << "get fn failed!!!";
  75. return nullptr;
  76. }
  77. }
  78. private:
  79. std::string model_path_;
  80. bool doResolve_;
  81. };
  82. } // namespace UT
  83. #endif // TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_