You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graph_handler.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
  17. Test query debugger graph handler.
  18. Usage:
  19. pytest tests/st/func/debugger/test_graph_handler.py
  20. """
  21. import os
  22. import pytest
  23. from ....utils.tools import compare_result_with_file
  24. from .conftest import init_graph_handler
  class TestGraphHandler:
      """Test GraphHandler.

      Every test queries the class-level ``graph_handler`` created by
      ``init_graph_handler`` (imported from the local conftest). The result is
      compared either with an expected-result JSON file under
      ``expect_results`` (via ``compare_result_with_file``) or with an inline
      expected value.
      """

      # Directory containing the expected-result JSON files named in the
      # parametrized cases below. Those names are joined onto this directory
      # verbatim. Spellings such as "condintion", "tenor" and
      # "search_node_1" must stay as they are unless the files on disk are
      # renamed as well.
      graph_results_dir = os.path.join(os.path.dirname(__file__), 'expect_results')
      # NOTE(review): evaluated once when the class body executes (module
      # import / test collection) and shared by every test in this class.
      graph_handler = init_graph_handler()

      @pytest.mark.level0
      @pytest.mark.env_single
      @pytest.mark.platform_x86_cpu
      @pytest.mark.platform_arm_ascend_training
      @pytest.mark.platform_x86_gpu_training
      @pytest.mark.platform_x86_ascend_training
      @pytest.mark.parametrize("filter_condition, result_file", [
          (None, "graph_handler_get_1_no_filter_condintion.json"),
          ({'name': 'Default'}, "graph_handler_get_2_list_nodes.json"),
          ({'name': 'Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190',
            'single_node': True},
           "graph_handler_get_3_single_node.json"),
      ])
      def test_get(self, filter_condition, result_file):
          """Test get.

          Pass ``filter_condition`` (``None`` or a dict with ``name`` and an
          optional ``single_node`` flag) to ``graph_handler.get`` and compare
          the result with the contents of ``result_file``.
          """
          result = self.graph_handler.get(filter_condition)
          file_path = os.path.join(self.graph_results_dir, result_file)
          compare_result_with_file(result, file_path)

      @pytest.mark.level0
      @pytest.mark.env_single
      @pytest.mark.platform_x86_cpu
      @pytest.mark.platform_arm_ascend_training
      @pytest.mark.platform_x86_gpu_training
      @pytest.mark.platform_x86_ascend_training
      @pytest.mark.parametrize("node_name, result_file", [
          ("Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190",
           "tenor_hist_0.json"),
          ("Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op22",
           "tensor_hist_1.json"),
      ])
      def test_get_tensor_history(self, node_name, result_file):
          """Test get tensor history.

          Compare ``graph_handler.get_tensor_history(node_name)`` with the
          contents of ``result_file``.
          """
          result = self.graph_handler.get_tensor_history(node_name)
          file_path = os.path.join(self.graph_results_dir, result_file)
          compare_result_with_file(result, file_path)

      @pytest.mark.level0
      @pytest.mark.env_single
      @pytest.mark.platform_x86_cpu
      @pytest.mark.platform_arm_ascend_training
      @pytest.mark.platform_x86_gpu_training
      @pytest.mark.platform_x86_ascend_training
      @pytest.mark.parametrize("pattern, result_file", [
          ("withlogits", "search_nodes_0.json"),
          ("cst", "search_node_1.json"),
      ])
      def test_search_nodes(self, pattern, result_file):
          """Test search nodes.

          Compare ``graph_handler.search_nodes(pattern)`` with the contents of
          ``result_file``.
          """
          result = self.graph_handler.search_nodes(pattern)
          file_path = os.path.join(self.graph_results_dir, result_file)
          compare_result_with_file(result, file_path)

      @pytest.mark.level0
      @pytest.mark.env_single
      @pytest.mark.platform_x86_cpu
      @pytest.mark.platform_arm_ascend_training
      @pytest.mark.platform_x86_gpu_training
      @pytest.mark.platform_x86_ascend_training
      @pytest.mark.parametrize("node_name, expect_type", [
          ("Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/cst1", 'Const'),
          ("Default/TransData-op99", "TransData"),
      ])
      def test_get_node_type(self, node_name, expect_type):
          """Test get node type.

          Assert that ``graph_handler.get_node_type(node_name)`` equals
          ``expect_type``.
          """
          node_type = self.graph_handler.get_node_type(node_name)
          assert node_type == expect_type

      @pytest.mark.level0
      @pytest.mark.env_single
      @pytest.mark.platform_x86_cpu
      @pytest.mark.platform_arm_ascend_training
      @pytest.mark.platform_x86_gpu_training
      @pytest.mark.platform_x86_ascend_training
      @pytest.mark.parametrize("node_name, expect_full_name", [
          # A missing node name is expected to map to an empty string.
          (None, ""),
          ("Default/make_tuple[9]_3/make_tuple-op284", "Default/make_tuple-op284"),
          ("Default/args0", "Default/args0"),
      ])
      def test_get_full_name(self, node_name, expect_full_name):
          """Test get full name.

          Assert that ``graph_handler.get_full_name(node_name)`` equals
          ``expect_full_name``. This is the inverse mapping of
          ``test_get_node_name_by_full_name``'s cases.
          """
          full_name = self.graph_handler.get_full_name(node_name)
          assert full_name == expect_full_name

      @pytest.mark.level0
      @pytest.mark.env_single
      @pytest.mark.platform_x86_cpu
      @pytest.mark.platform_arm_ascend_training
      @pytest.mark.platform_x86_gpu_training
      @pytest.mark.platform_x86_ascend_training
      @pytest.mark.parametrize("full_name, expect_node_name", [
          # A missing full name is expected to map to an empty string.
          (None, ""),
          ("Default/make_tuple-op284", "Default/make_tuple[9]_3/make_tuple-op284"),
          ("Default/args0", "Default/args0"),
      ])
      def test_get_node_name_by_full_name(self, full_name, expect_node_name):
          """Test get node name by full name.

          Assert that ``graph_handler.get_node_name_by_full_name(full_name)``
          equals ``expect_node_name``.
          """
          node_name = self.graph_handler.get_node_name_by_full_name(full_name)
          assert node_name == expect_node_name

      @pytest.mark.level0
      @pytest.mark.env_single
      @pytest.mark.platform_x86_cpu
      @pytest.mark.platform_arm_ascend_training
      @pytest.mark.platform_x86_gpu_training
      @pytest.mark.platform_x86_ascend_training
      @pytest.mark.parametrize("node_name, ascend, expect_next", [
          # With no starting node, ascend=True is expected to yield a node
          # while ascend=False is expected to yield None.
          (None, True, "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0"),
          (None, False, None),
          ("Default/tuple_getitem[10]_0/tuple_getitem-op206", True,
           "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89"),
          ("Default/tuple_getitem[10]_0/tuple_getitem-op206", False,
           "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205"),
      ])
      def test_get_node_by_bfs_order(self, node_name, ascend, expect_next):
          """Test get node by BFS order.

          Assert that ``graph_handler.get_node_by_bfs_order(node_name, ascend)``
          returns ``expect_next``. ``ascend`` presumably selects the traversal
          direction; confirm against the GraphHandler implementation.
          """
          next_node = self.graph_handler.get_node_by_bfs_order(node_name, ascend)
          assert next_node == expect_next