You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parser_primitive_test.cc 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <string>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "utils/log_adapter.h"
  21. #include "pipeline/parse/parse.h"
  22. #include "debug/draw.h"
  23. namespace mindspore {
  24. namespace parse {
  25. class TestParserPrimitive : public UT::Common {
  26. public:
  27. TestParserPrimitive() {}
  28. virtual void SetUp();
  29. virtual void TearDown();
  30. };
  31. void TestParserPrimitive::SetUp() { UT::InitPythonPath(); }
  32. void TestParserPrimitive::TearDown() {}
  33. TEST_F(TestParserPrimitive, TestParserOpsMethod1) {
  34. py::function fn_ = python_adapter::GetPyFn("gtest_input.pipeline.parse.parse_primitive", "test_ops_f1");
  35. FuncGraphPtr func_graph = ParsePythonCode(fn_);
  36. ASSERT_TRUE(nullptr != func_graph);
  37. // save the func_graph to manager
  38. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  39. // call resolve
  40. bool ret_ = ResolveAll(manager);
  41. ASSERT_TRUE(ret_);
  42. // draw graph
  43. int i = 0;
  44. for (auto tmp : manager->func_graphs()) {
  45. std::string name = "ut_parser_ops_1_" + std::to_string(i) + ".dot";
  46. draw::Draw(name, tmp);
  47. i++;
  48. }
  49. }
  50. TEST_F(TestParserPrimitive, TestParserOpsMethod2) {
  51. py::function fn_ = python_adapter::GetPyFn("gtest_input.pipeline.parse.parse_primitive", "test_ops_f2");
  52. FuncGraphPtr func_graph = ParsePythonCode(fn_);
  53. ASSERT_TRUE(nullptr != func_graph);
  54. // save the func_graph to manager
  55. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  56. // call resolve
  57. bool ret_ = ResolveAll(manager);
  58. ASSERT_TRUE(ret_);
  59. // draw graph
  60. int i = 0;
  61. for (auto tmp : manager->func_graphs()) {
  62. std::string name = "ut_parser_ops_2_" + std::to_string(i) + ".dot";
  63. draw::Draw(name, tmp);
  64. i++;
  65. }
  66. }
  67. // Test primitive class obj
  68. TEST_F(TestParserPrimitive, TestParsePrimitive) {
  69. #if 0 // Segmentation fault
  70. py::object obj_ = python_adapter::CallPyFn("gtest_input.pipeline.parse.parse_primitive", "test_primitive_obj");
  71. Parser::InitParserEnvironment(obj_);
  72. FuncGraphPtr func_graph = ParsePythonCode(obj_);
  73. ASSERT_TRUE(nullptr != func_graph);
  74. draw::Draw("ut_parser_primitive_x.dot", func_graph);
  75. // save the func_graph to manager
  76. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  77. // call resolve
  78. bool ret_ = ResolveAll(manager);
  79. ASSERT_TRUE(ret_);
  80. // draw graph
  81. int i = 0;
  82. for (auto tmp : manager->func_graphs()) {
  83. std::string name = "ut_parser_ops_3_" + std::to_string(i) + ".dot";
  84. draw::Draw(name, tmp);
  85. i++;
  86. }
  87. #endif
  88. }
  89. /* Skip the following unit test cases temporarily.
  90. TEST_F(TestParserPrimitive, TestParsePrimitiveParmeter) {
  91. py::object obj_ =
  92. python_adapter::CallPyFn("gtest_input.pipeline.parse.parse_primitive", "test_primitive_obj_parameter");
  93. Parser::InitParserEnvironment(obj_);
  94. FuncGraphPtr func_graph = ParsePythonCode(obj_);
  95. ASSERT_TRUE(nullptr != func_graph);
  96. draw::Draw("ut_parser_primitive_x.dot", func_graph);
  97. // save the func_graph to manager
  98. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  99. // call resolve
  100. bool ret_ = ResolveAll(manager);
  101. ASSERT_TRUE(ret_);
  102. // draw graph
  103. int i = 0;
  104. for (auto tmp : manager->func_graphs()) {
  105. std::string name = "ut_parser_ops_4_" + std::to_string(i) + ".dot";
  106. draw::Draw(name, tmp);
  107. i++;
  108. }
  109. }
  110. TEST_F(TestParserPrimitive, TestParsePrimitiveParmeter2) {
  111. py::object obj_ = python_adapter::CallPyFn("gtest_input.pipeline.parse.parse_primitive", "test_primitive_functional");
  112. Parser::InitParserEnvironment(obj_);
  113. FuncGraphPtr func_graph = ParsePythonCode(obj_);
  114. ASSERT_TRUE(nullptr != func_graph);
  115. draw::Draw("ut_parser_primitive_x.dot", func_graph);
  116. // save the func_graph to manager
  117. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  118. // call resolve
  119. bool ret_ = ResolveAll(manager);
  120. ASSERT_TRUE(ret_);
  121. // draw graph
  122. int i = 0;
  123. for (auto tmp : manager->func_graphs()) {
  124. std::string name = "ut_parser_ops_5_" + std::to_string(i) + ".dot";
  125. draw::Draw(name, tmp);
  126. i++;
  127. }
  128. }
  129. */
  130. } // namespace parse
  131. } // namespace mindspore