You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resolve_test.cc 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <string>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "utils/log_adapter.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "debug/draw.h"
  23. namespace mindspore {
  24. namespace parse {
  25. class TestResolve : public UT::Common {
  26. public:
  27. TestResolve() {}
  28. virtual void SetUp();
  29. virtual void TearDown();
  30. };
  31. void TestResolve::SetUp() { UT::InitPythonPath(); }
  32. void TestResolve::TearDown() {}
  33. TEST_F(TestResolve, TestResolveApi) {
  34. py::function fn_ = python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "get_resolve_fn");
  35. // parse graph
  36. FuncGraphPtr func_graph = ParsePythonCode(fn_);
  37. ASSERT_FALSE(nullptr == func_graph);
  38. // save the func_graph to manager
  39. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  40. // call resolve
  41. bool ret_ = ResolveAll(manager);
  42. ASSERT_TRUE(ret_);
  43. ASSERT_EQ(manager->func_graphs().size(), (size_t)2);
  44. }
  45. TEST_F(TestResolve, TestParseGraphTestClosureResolve) {
  46. py::function test_fn =
  47. python_adapter::CallPyFn("gtest_input.pipeline.parse.parser_test", "test_reslove_closure", 123);
  48. FuncGraphPtr func_graph = ParsePythonCode(test_fn);
  49. ASSERT_TRUE(func_graph != nullptr);
  50. // save the func_graph to manager
  51. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  52. // call resolve
  53. bool ret_ = ResolveAll(manager);
  54. ASSERT_TRUE(ret_);
  55. ASSERT_EQ(manager->func_graphs().size(), (size_t)2);
  56. }
  57. } // namespace parse
  58. } // namespace mindspore