You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

py_func_graph_fetcher.h 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_
  17. #define TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_
  18. #include <string>
  19. #include <memory>
  20. #include "ir/anf.h"
  21. #include "ir/primitive.h"
  22. #include "ir/manager.h"
  23. #include "ir/func_graph.h"
  24. #include "pipeline/parse/parse_base.h"
  25. #include "pipeline/parse/parse.h"
  26. #include "./common.h"
  27. namespace UT {
  28. void InitPythonPath();
  29. class PyFuncGraphFetcher {
  30. public:
  31. explicit PyFuncGraphFetcher(std::string model_path, bool doResolve = false)
  32. : model_path_(model_path), doResolve_(doResolve) {
  33. InitPythonPath();
  34. }
  35. void SetDoResolve(bool doResolve = true) { doResolve_ = doResolve; }
  36. // The return of python function of "func_name" should be py::function.
  37. // step 1. Call the function user input
  38. // step 2. Parse the return "fn"
  39. template <class... T>
  40. mindspore::FuncGraphPtr CallAndParseRet(std::string func_name, T... args) {
  41. try {
  42. py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...);
  43. mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
  44. if (doResolve_) {
  45. std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
  46. mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
  47. mindspore::parse::ResolveAll(manager);
  48. }
  49. return func_graph;
  50. } catch (py::error_already_set& e) {
  51. MS_LOG(ERROR) << "Call and parse fn failed!!! error:" << e.what();
  52. return nullptr;
  53. } catch (...) {
  54. MS_LOG(ERROR) << "Call fn failed!!!";
  55. return nullptr;
  56. }
  57. }
  58. // Fetch python function then parse to graph
  59. mindspore::FuncGraphPtr operator()(std::string func_name, std::string model_path = "") {
  60. try {
  61. std::string path = model_path_;
  62. if ("" != model_path) {
  63. path = model_path;
  64. }
  65. py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str());
  66. mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
  67. if (doResolve_) {
  68. std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
  69. mindspore::parse::ResolveAll(manager);
  70. }
  71. return func_graph;
  72. } catch (py::error_already_set& e) {
  73. MS_LOG(ERROR) << "get fn failed!!! error:" << e.what();
  74. return nullptr;
  75. } catch (...) {
  76. MS_LOG(ERROR) << "get fn failed!!!";
  77. return nullptr;
  78. }
  79. }
  80. private:
  81. std::string model_path_;
  82. bool doResolve_;
  83. };
  84. } // namespace UT
  85. #endif // TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_