You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

profiling_base.py 2.6 kB

5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. from multiprocessing import Process
  16. from akg.utils.kernel_exec import PERFORMANCE_TEST_FILE
  17. from tests.common.base import TestBase, PERFORMANCE_TEST
  18. TIMEOUT = 600
  19. class ProfilingTestBase(TestBase):
  20. def __init__(self, casename, testcases):
  21. """
  22. testcase preparcondition
  23. :return:
  24. """
  25. casepath = os.getcwd()
  26. super(ProfilingTestBase, self).__init__(casename, casepath)
  27. self.testcases = testcases
  28. def setup(self):
  29. self.caseresult = True
  30. self._log.info("============= {0} Setup case============".format(self.casename))
  31. self.result_file = os.path.join(self.caselog_path, self.casename + ".csv")
  32. os.environ[PERFORMANCE_TEST] = "True"
  33. os.environ[PERFORMANCE_TEST_FILE] = self.result_file
  34. return
  35. def _get_test_case_perf(self, test_case):
  36. _, func, args, _ = self.ana_args(test_case)
  37. func_name = func if isinstance(func, str) else func.__name__
  38. operator_name = func_name.split("_run")[0]
  39. p_file = open(self.result_file, 'a+')
  40. p_file.write("%s; %s; " % (operator_name, args))
  41. p_file.close()
  42. is_conv = True if "conv" in operator_name else False
  43. self.common_run([test_case], is_conv=is_conv)
  44. def test_run_perf(self):
  45. """
  46. run case.
  47. :return:
  48. """
  49. for test_case in self.testcases:
  50. # For the profiling tool, each test case must run with a new process
  51. p = Process(target=self._get_test_case_perf, args=(test_case,))
  52. p.start()
  53. p.join(timeout=TIMEOUT)
  54. if p.is_alive():
  55. p.terminate()
  56. raise RuntimeError("process for {0} timeout!".format(test_case))
  57. def teardown(self):
  58. """
  59. clean environment
  60. :return:
  61. """
  62. os.environ.pop(PERFORMANCE_TEST_FILE)
  63. os.environ.pop(PERFORMANCE_TEST)
  64. self._log.info("============= {0} Teardown============".format(self.casename))
  65. return