You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cconv_test.cc 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <string>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "ir/func_graph_cloner.h"
  21. #include "utils/log_adapter.h"
  22. #include "pipeline/jit/parse/parse.h"
  23. #include "debug/draw.h"
  24. namespace mindspore {
  25. void CheckNoFreeVariables(FuncGraphPtr root) {
  26. auto mng = Manage(root);
  27. for (auto &iter : mng->func_graphs()) {
  28. auto g = iter;
  29. if (g == nullptr) {
  30. continue;
  31. }
  32. ASSERT_TRUE(g->parent() == nullptr);
  33. auto nodes = g->nodes();
  34. for (auto &node : nodes) {
  35. ASSERT_EQ(node->func_graph(), g);
  36. auto cnode = node->cast<CNodePtr>();
  37. if (cnode != nullptr) {
  38. for (auto &inp : cnode->inputs()) {
  39. ASSERT_TRUE(inp->func_graph() == nullptr || inp->func_graph() == g);
  40. }
  41. }
  42. }
  43. }
  44. }
  45. void CheckCconv(FuncGraphPtr g) {
  46. auto mng = Manage(g);
  47. auto new_g = LiftingClone(g);
  48. CheckNoFreeVariables(new_g);
  49. }
  50. class TestCconv : public UT::Common {
  51. public:
  52. TestCconv() : getPyFun("gtest_input.optimizer.cconv_test") {}
  53. virtual void SetUp();
  54. virtual void TearDown();
  55. public:
  56. UT::PyFuncGraphFetcher getPyFun;
  57. };
  58. void TestCconv::SetUp() {}
  59. void TestCconv::TearDown() {}
  60. TEST_F(TestCconv, TestStraight) {
  61. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_straight");
  62. ASSERT_TRUE(nullptr != func_graph);
  63. CheckCconv(func_graph);
  64. }
  65. TEST_F(TestCconv, TestSimpleClosure) {
  66. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_simple_closure");
  67. ASSERT_TRUE(nullptr != func_graph);
  68. CheckCconv(func_graph);
  69. }
  70. TEST_F(TestCconv, TestMax) {
  71. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_max");
  72. ASSERT_TRUE(nullptr != func_graph);
  73. CheckCconv(func_graph);
  74. }
  75. TEST_F(TestCconv, TestDeepNesting) {
  76. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_deep_nesting");
  77. ASSERT_TRUE(nullptr != func_graph);
  78. CheckCconv(func_graph);
  79. }
  80. TEST_F(TestCconv, TestReturnInDoubleWhile) {
  81. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_return_in_double_while");
  82. ASSERT_TRUE(nullptr != func_graph);
  83. CheckCconv(func_graph);
  84. }
  85. TEST_F(TestCconv, TestPow10) {
  86. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_pow10");
  87. ASSERT_TRUE(nullptr != func_graph);
  88. CheckCconv(func_graph);
  89. }
  90. TEST_F(TestCconv, TestClosureAsSimpleFv) {
  91. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_simple_fv");
  92. ASSERT_TRUE(nullptr != func_graph);
  93. CheckCconv(func_graph);
  94. }
  95. TEST_F(TestCconv, TestClosureAsFv) {
  96. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_fv");
  97. ASSERT_TRUE(nullptr != func_graph);
  98. CheckCconv(func_graph);
  99. }
  100. TEST_F(TestCconv, TestClosureAsDoubleFv) {
  101. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_double_fv");
  102. ASSERT_TRUE(nullptr != func_graph);
  103. CheckCconv(func_graph);
  104. }
  105. TEST_F(TestCconv, TestClosureLiftSameParam) {
  106. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_lift_same_param");
  107. ASSERT_TRUE(nullptr != func_graph);
  108. CheckCconv(func_graph);
  109. }
  110. TEST_F(TestCconv, TestClosureAsLoop) {
  111. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_as_loop");
  112. ASSERT_TRUE(nullptr != func_graph);
  113. CheckCconv(func_graph);
  114. }
  115. TEST_F(TestCconv, TestClosureLiftCNode) {
  116. FuncGraphPtr func_graph = getPyFun.CallAndParseRet("get_test_cconv_fn", "test_closure_lift_cnode");
  117. ASSERT_TRUE(nullptr != func_graph);
  118. CheckCconv(func_graph);
  119. }
  120. } // namespace mindspore