You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 22 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Test Base class"""
  15. import os
  16. import sys
  17. import time
  18. import tarfile
  19. import datetime
  20. import collections
  21. import numpy as np
  22. from akg import dim
  23. from akg.utils.result_analysis import count_unequal_element
  24. from tests.common import tensorio
  25. from tests.common.ftp_handel import ftpHandle
  26. from tests.common.log import Log
  27. PERFORMANCE_TEST = "PERFORMANCE_TEST"
  28. class TestBase(object):
  29. pandora_logger_ = None
  30. def params_init(self, case_name, case_path, max_retry=3):
  31. self.casename = case_name
  32. self.caselog_path = case_path
  33. self.max_retry = max_retry
  34. # Define the log storage location, which is stored in case_log by default.
  35. self.case_result = True
  36. if TestBase.pandora_logger_ is None:
  37. TestBase.pandora_logger_ = Log(case_name, case_path)
  38. self._log = TestBase.pandora_logger_.log
  39. self.test_args = []
  40. self.caseresult = True
  41. self._exception = None
  42. def setup(self):
  43. self._log.info("TestBase:{0} Setup case".format(self.casename))
  44. return True
  45. def teardown(self):
  46. self._log.info("TestBase:{0} Teardown".format(self.casename))
  47. return
  48. def run_test_arg_func(self, test_args=[], attr=None):
  49. if not attr:
  50. self._log.info("attr is None")
  51. return False
  52. run_mode = self.get_env_var("RUNTIME_MODE")
  53. if run_mode in ["compile_cloud", "compile_mini"]:
  54. mode = "compile"
  55. else:
  56. mode = "execute"
  57. for arg in test_args:
  58. self._log.info(arg)
  59. if attr in arg[-1]:
  60. case_result, exception = self.common_run([arg[0:-1]], mode=mode)
  61. if not case_result:
  62. self._log.info("{0} run failed".format(arg))
  63. return False
  64. return True
  65. def print_args(self):
  66. for index, arg in enumerate(self.test_args):
  67. print("{0} {1}".format(index, arg[0]))
  68. def ana_args(self, arg, is_conv=False):
  69. caseflag, func, args = arg[0:3]
  70. kwargs = {}
  71. attrs = self.get_dim_info(arg, is_conv)
  72. if self.get_env_var(PERFORMANCE_TEST):
  73. attrs["record_core"] = True
  74. if attrs is not None:
  75. if len(arg) == 5 and not arg[-1]:
  76. args = list(args)
  77. args.append(attrs)
  78. args.append(arg[-1])
  79. kwargs = {}
  80. else:
  81. args = list(args)
  82. kwargs = {"attrs": attrs}
  83. return caseflag, func, args, kwargs
  84. def get_dim_info(self, arg, is_conv=False):
  85. info = dim.Dim()
  86. tile_dims = []
  87. dims = None
  88. enable_multicore = None
  89. dynamic = False
  90. partial_dynamic = False
  91. bypass_l1 = False
  92. if "dynamic" in arg:
  93. dynamic = True
  94. if isinstance(arg, tuple):
  95. arg = list(arg)
  96. arg.remove("dynamic")
  97. arg = tuple(arg)
  98. else:
  99. arg.remove("dynamic")
  100. if "partial_dynamic" in arg:
  101. partial_dynamic = True
  102. arg.remove("partial_dynamic")
  103. if "bypassL1" in arg:
  104. bypass_l1 = True
  105. arg.remove("bypassL1")
  106. if is_conv:
  107. dy = dynamic or partial_dynamic
  108. if len(arg) == 4:
  109. conv_tile = arg[3]
  110. if len(conv_tile) > 0:
  111. if not dy:
  112. return {
  113. "dim": str(info),
  114. "conv_tile": conv_tile,
  115. "enable_multicore": True,
  116. "bypass": 1 if bypass_l1 else 0,
  117. }
  118. else:
  119. return {
  120. "dim": str(info),
  121. "conv_tile": conv_tile,
  122. "dynamic": dynamic,
  123. "partial_dynamic": partial_dynamic,
  124. "bypass": 1 if bypass_l1 else 0,
  125. }
  126. elif dy and len(arg) == 3:
  127. return {
  128. "dynamic": dynamic,
  129. "partial_dynamic": partial_dynamic,
  130. "bypass": 1 if bypass_l1 else 0,
  131. }
  132. if len(arg) == 5 and not arg[-1]:
  133. dims = arg[3]
  134. for d in range(len(dims)):
  135. tile_dims.append(dims[d][0])
  136. elif (len(arg) == 5 and arg[-1]) or len(arg) == 4:
  137. if isinstance(arg[3], (bool, int)): # only multicore info
  138. enable_multicore = arg[3]
  139. elif isinstance(arg[3][-1], (bool, int)): # dim info and multicore info
  140. enable_multicore = arg[3][-1]
  141. dims = arg[3][0]
  142. else: # only dim info
  143. dims = arg[3]
  144. if dims is not None:
  145. for i in range(len(dims)):
  146. if (isinstance(dims[i][0], int)):
  147. # only one index, ((l1,l0),(l1,l0),...)
  148. i_dims = dims
  149. else:
  150. # multiple indices, (((l1,l0),(l1,l0),...), ((l1,l0),(l1,l0),...))
  151. i_dims = dims[i]
  152. for d in range(len(i_dims)):
  153. info.setdim(index=i,
  154. axis=d,
  155. tilel1=i_dims[d][0],
  156. tilel0=i_dims[d][1])
  157. if len(arg) == 5 and not arg[-1]:
  158. return {"tile": tile_dims}
  159. else:
  160. res = {"dim": str(info), "dynamic": dynamic}
  161. if enable_multicore:
  162. res["enable_multicore"] = enable_multicore
  163. return res
  164. def get_env_var(self, env_key=None):
  165. env_dic = os.environ
  166. env_var = env_dic.get(env_key)
  167. if env_var:
  168. return env_var
  169. return None
  170. def translate_func_name(self, arg):
  171. args_list = []
  172. args_list.append(arg[0])
  173. func = arg[1]
  174. if isinstance(func, str):
  175. args_list.append(func)
  176. else:
  177. args_list.append(func.__name__)
  178. for i in range(2, len(arg)):
  179. args_list.append(arg[i])
  180. return tuple(args_list)
  181. def import_get_func(self, func, mode):
  182. """
  183. from test_run.tile_run import tile_compile
  184. :param func: function name
  185. :param mode: case mode
  186. :return:
  187. """
  188. func_fromlist = "tests.common.test_run." + func
  189. try:
  190. new_func = func
  191. func_py = __import__(func_fromlist, fromlist=func)
  192. run_func = getattr(func_py, new_func)
  193. except (ImportError, AttributeError) as e:
  194. new_func = func.split("_run")[0] + "_" + mode
  195. func_py = __import__(func_fromlist, fromlist=new_func)
  196. run_func = getattr(func_py, new_func)
  197. return run_func
  198. def common_run(self, args, dtype_list=None, mode="execute", is_conv=False, raise_exception=True):
  199. """
  200. :param dtype_list:operator program data type
  201. :param mode: operator run mode: such as rpc_cloud/aicmodel
  202. :param raise_exception: By default, when an exception occurs in the compilation,
  203. the assert is used to interrupt the program.
  204. :return:
  205. """
  206. for arg in args:
  207. starttime = datetime.datetime.now()
  208. caseflag, func, args, kwargs = self.ana_args(arg, is_conv)
  209. if dtype_list:
  210. if not self.set_args_dtype(args, func, dtype_list):
  211. self._log.error("common_run failed for set_args_dtype")
  212. return False
  213. if isinstance(func, str):
  214. self._log.info("common_run :: run {funcname} with args:{args}".format(funcname=func, args=args))
  215. func = self.import_get_func(func, mode)
  216. else:
  217. self._log.info("common_run :: run {funcname} with args:{args}".format(funcname=func.__name__, args=args))
  218. mod = None
  219. if mode == "compile":
  220. try:
  221. mod = func(*args, **kwargs)
  222. except Exception as e:
  223. TestBase.pandora_logger_.traceback()
  224. self._exception = e
  225. finally:
  226. if (not mod) or self._exception:
  227. self._log.error("common_run :: circle {0} fail !".format(self.translate_func_name(arg)))
  228. self._log.error("common_run :: compile failed !")
  229. self.case_result = False
  230. elif mode == "execute":
  231. input, output, expect, runres = func(*args, **kwargs)
  232. rtol = atol = 0
  233. compare_res = []
  234. if isinstance(runres, list):
  235. if isinstance(runres[-1], (list, tuple)):
  236. rtol = runres[-1][0]
  237. atol = runres[-1][1]
  238. runres = list(runres[:-1])
  239. compare_res = runres
  240. runres = all(runres)
  241. elif isinstance(runres, collections.Iterable):
  242. compare_res = list(runres)
  243. else:
  244. compare_res = [runres]
  245. kernel_name = self.get_kernel_name(args, func)
  246. cce_file_name = self.collect_cce(kernel_name)
  247. ir_file_name = self.collect_ir(kernel_name)
  248. if not runres:
  249. runtime_mode = os.environ.get("RUNTIME_MODE")
  250. if runtime_mode in ["rpc", "rpc_cloud", "air", "air_cloud"]:
  251. for retry in range(self.max_retry):
  252. self._log.error("Case result is incorrect, but RPC server occasionally produce incorrect "
  253. "output. Retry it before reporting failure. Retry count: " + str(retry + 1))
  254. input, output, expect, runres = func(*args, **kwargs)
  255. if isinstance(runres, list):
  256. if isinstance(runres[-1], (list, tuple)):
  257. rtol = runres[-1][0]
  258. atol = runres[-1][1]
  259. runres = list(runres[:-1])
  260. compare_res = runres
  261. runres = all(runres)
  262. elif isinstance(runres, collections.Iterable):
  263. compare_res = list(runres)
  264. else:
  265. compare_res = [runres]
  266. if runres:
  267. break
  268. if not runres:
  269. self._log.error("common_run :: circle {0} fail !".format(self.translate_func_name(arg)))
  270. self._log.error("common_run :: CompareResult: %s", str(compare_res))
  271. if rtol == 0:
  272. self._log.error("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
  273. self._log.error("Caution: the 'rtol' and 'atol' is default $$$$$1e-4$$$$$")
  274. self._log.error("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
  275. rtol = atol = 1e-4
  276. if isinstance(expect, (tuple, list)):
  277. for i, tmp in enumerate(expect):
  278. count_unequal_element(tmp, output[i], rtol, atol)
  279. else:
  280. if not isinstance(expect, np.ndarray):
  281. expect = np.atleast_1d(expect)
  282. count_unequal_element(expect, output, rtol, atol)
  283. if not self.collect_data(input, output, cce_file_name, ir_file_name, arg, kernel_name):
  284. self._log.error("common_run :: collect data failed")
  285. self.case_result = False
  286. else:
  287. self._log.info("common_run :: circle {0} pass !".format(self.translate_func_name(arg)))
  288. if cce_file_name and os.path.exists(cce_file_name):
  289. os.remove(cce_file_name)
  290. if ir_file_name and os.path.exists(ir_file_name):
  291. os.remove(ir_file_name)
  292. self.case_result &= True
  293. endtime = datetime.datetime.now()
  294. self._log.info("{0} testcase use ***Running Time*** is: {1}s. "
  295. .format(caseflag, (endtime - starttime).seconds))
  296. self._log.info(self.case_result)
  297. '''
  298. use assert in the common_run function:
  299. Because the common_run function in the use cases does not verify the return value, the result cannot be
  300. printed normally after the program ends, so the execution result needs to be judged in the common_run function.
  301. '''
  302. if (not self.case_result) and raise_exception:
  303. assert self.case_result
  304. return self.case_result, self._exception
  305. def get_args_dtype(self, input_args_names):
  306. """
  307. Get the dtype of the function input parameter, return its index
  308. :param input_args_names: Test operator method
  309. :return: kernel_name
  310. """
  311. return tuple([index for index, name in enumerate(input_args_names) if str(name).__contains__("dtype")])
  312. def get_kernel_name(self, args, func):
  313. func_input_args_names = func.__code__.co_varnames
  314. kernel_name = func.__name__.split('_run')[0].split('_execute')[0]
  315. for index, name in enumerate(func_input_args_names):
  316. if str(name).__contains__("kernel_name"):
  317. kernel_name = func.__name__.split('_run')[0].split('_execute')[0]
  318. break
  319. return kernel_name
  320. def replace_args_dtype(self, args, input_args_names, dtype_list):
  321. """
  322. replace the dtype field of args
  323. """
  324. dtype_index_list = self.get_args_dtype(input_args_names)
  325. if not dtype_index_list or len(dtype_index_list) > len(dtype_list):
  326. self._log.error("replace_args_dtype :: dtype_index_list failed, dtype_index_list:{0},dtype_list:{1}".format(
  327. dtype_index_list, dtype_list))
  328. return False
  329. input_dtype_index = 0
  330. for index in dtype_index_list:
  331. args[index] = dtype_list[input_dtype_index]
  332. input_dtype_index += 1
  333. return True
  334. def set_args_dtype(self, args, func, dtype_list):
  335. """
  336. Set the dtype field of the use case parameter list
  337. """
  338. if not args or not dtype_list:
  339. self._log.error("set_args_dtype failed for test_arg_list:{0},dtype_list:{1}".format(args, dtype_list))
  340. return True
  341. func_input_args_names = func.__code__.co_varnames
  342. if not func_input_args_names:
  343. self._log.error("function : {0} args list is None".format(func))
  344. return True
  345. return self.replace_args_dtype(args, func_input_args_names, dtype_list)
  346. def upload_file_ftp(self, upload_type, local_file_path):
  347. if upload_type not in ("csvs", "cce", "ir", "dump_shape", "logs",):
  348. self._log.error("upload_file_ftp failed :: not support for upload_type:{0}".format(upload_type))
  349. return None
  350. today = str(datetime.date.today())
  351. ftp = ftpHandle(self._log)
  352. if not ftp.ftp_login():
  353. self._log.error("upload_file_ftp failed for ftp_login")
  354. return None
  355. remote_path = os.path.join("/auto_tensor", upload_type)
  356. if not ftp.ftp_mkdir(remote_path, today):
  357. self._log.error("upload_file_ftp failed for ftp_mkdir,remote_path:{0},today:{1}".format(remote_path, today))
  358. ftp.ftp_close()
  359. return None
  360. remote_path = os.path.join(remote_path, today)
  361. remote_file_name = str(local_file_path).split("/")[-1]
  362. if not ftp.ftp_upload_file(remote_path, remote_file_name, local_file_path):
  363. self._log.error(
  364. "upload_file_ftp failed for ftp_upload_file,remote_path:{0},today:{1},local_file_path:{2}".format(
  365. remote_path, today, local_file_path))
  366. ftp.ftp_close()
  367. return None
  368. ftp_url = "ftp://{host}/{path}".format(host=ftp.host, path=os.path.join(remote_path, remote_file_name))
  369. ftp.ftp_close()
  370. return ftp_url
  371. def collect_ir(self, kernel_name):
  372. if not os.path.exists(kernel_name):
  373. self._log.warning("not exist ir directory for :{kernel_name}".format(kernel_name=kernel_name))
  374. return None
  375. file_name = kernel_name + ".tar.gz"
  376. with tarfile.open(file_name, "w:gz") as tar:
  377. tar.add(kernel_name, arcname=os.path.basename(kernel_name))
  378. return file_name
  379. def collect_cce(self, kernel_name):
  380. file_name = kernel_name + ".cce"
  381. if not os.path.exists(file_name):
  382. self._log.warning("not exist cce file for :{file_name}".format(file_name=file_name))
  383. return None
  384. return file_name
  385. def collect_data(self, input, output, cce_file_name, ir_file_name, arg, kernel_name):
  386. ret_val = True
  387. # dump input and output
  388. dump_file_list = self.data_dump(input, output, arg)
  389. self._log.warning("dump input and output as follow:")
  390. if os.environ.get("FTP_HOST"):
  391. for dump_file in dump_file_list:
  392. ftp_url = self.upload_file_ftp("dump_shape", dump_file)
  393. if not ftp_url:
  394. self._log.error("upload_file_ftp failed for dump_file : {0}".format(dump_file))
  395. ret_val = False
  396. else:
  397. self._log.warning("dump_file ftp_url : {0}".format(ftp_url))
  398. # dump ir
  399. if not ir_file_name:
  400. self._log.error("collect_ir failed")
  401. ret_val = False
  402. else:
  403. ftp_url = self.upload_file_ftp("ir", ir_file_name)
  404. if not ftp_url:
  405. self._log.error("upload_file_ftp failed for ir_file_name : {0}".format(ir_file_name))
  406. ret_val = False
  407. else:
  408. self._log.warning("ir ftp_url : {0}".format(ftp_url))
  409. # dump cce
  410. if not cce_file_name:
  411. self._log.error("collect_cce failed")
  412. ret_val = False
  413. else:
  414. ftp_url = self.upload_file_ftp("cce", cce_file_name)
  415. if not ftp_url:
  416. self._log.error("upload_file_ftp failed for cce_file_name : {0}".format(cce_file_name))
  417. ret_val = False
  418. else:
  419. self._log.warning("cce ftp_url : {0}".format(ftp_url))
  420. else:
  421. case_failed_save_path = '/' + '/'.join(os.path.abspath(self.casename).split('/')[1:-1])
  422. self._log.warning("The input output data of failed use case log have been saved to the path : {0}/data/{1}"
  423. .format(case_failed_save_path, kernel_name))
  424. self._log.warning("The ir data of failed use case log have been saved to the path : {0}/{1}"
  425. .format(case_failed_save_path, ir_file_name))
  426. self._log.warning("The cce data of failed use case log have been saved to the path : {0}/{1}"
  427. .format(case_failed_save_path, cce_file_name))
  428. return ret_val
  429. def data_dump(self, input, output, arg):
  430. dump_file_list = []
  431. operator_name = str(arg[1]).split("_run")[0].split()[-1]
  432. data_dir = "./data/{0}/".format(operator_name)
  433. os.popen("mkdir -p %s" % data_dir)
  434. time.sleep(1)
  435. if not isinstance(input, list) and not isinstance(input, tuple):
  436. input = [input]
  437. if not isinstance(output, list) and not isinstance(output, tuple):
  438. output = [output]
  439. data_dict = {"input": input, "output": output}
  440. for kays in data_dict.keys():
  441. for index, i in enumerate(data_dict[kays]):
  442. seq = [operator_name, kays, str(index + 1)] + list(map(str, arg[2])) + [".t"]
  443. dump_file_name = "_".join(seq).replace("[", "").replace("]", "").replace(",", "-") \
  444. .replace(" ", "").replace("(", "").replace(")", "").replace("_.", ".")
  445. dump_file_name += str(time.time())
  446. dump_file = os.path.join(data_dir, dump_file_name)
  447. dump_file_list.append(dump_file)
  448. tensorio.dump_tensor(i, dump_file)
  449. return dump_file_list
  450. def get_rtol_atol(op_name, dtype, rtol=5e-03, atol=5e-03):
  451. run_mode = os.environ.get('RUNTIME_MODE')
  452. if run_mode in ("rpc_cloud", "air_cloud"):
  453. if dtype == "float16":
  454. rtol = atol = 1e-03
  455. else:
  456. rtol = atol = 1e-04
  457. return rtol, atol
  458. def get_splitted_cases(cases, split_nums, split_idx):
  459. if not isinstance(cases, (list, tuple)):
  460. raise TypeError("Argument cases must be of type list or tuple.")
  461. if not isinstance(split_nums, int) or not isinstance(split_idx, int):
  462. raise TypeError("Arguments split_nums and split_idx must be of type int.")
  463. if split_nums <= 0 or split_idx < 0 or split_idx >= split_nums:
  464. raise ValueError("Argument split_nums must > 0, split_idx must be in range [0, split_nums)")
  465. cases = list(cases)
  466. all_cases = len(cases)
  467. fragment = (all_cases + split_nums - 1) // split_nums
  468. start_idx = split_idx * fragment
  469. if start_idx >= all_cases:
  470. return []
  471. end_idx = start_idx + fragment
  472. if end_idx > all_cases:
  473. end_idx = all_cases
  474. return cases[start_idx:end_idx]

AKG(Auto Kernel Generator)对深度神经网络中的算子进行优化,并提供特定模式下的算子自动融合功能。AKG与MindSpore的图算融合功能协同工作,可提升在不同硬件后端上运行网络的性能。