|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Test Base class"""
- import os
- import sys
- import time
- import tarfile
- import datetime
- import collections
- import numpy as np
- from akg import dim
- from akg.utils.result_analysis import count_unequal_element
- from tests.common import tensorio
- from tests.common.ftp_handel import ftpHandle
- from tests.common.log import Log
-
- PERFORMANCE_TEST = "PERFORMANCE_TEST"
-
-
- class TestBase(object):
- pandora_logger_ = None
-
- def params_init(self, case_name, case_path, max_retry=3):
- self.casename = case_name
- self.caselog_path = case_path
- self.max_retry = max_retry
- # Define the log storage location, which is stored in case_log by default.
- self.case_result = True
- if TestBase.pandora_logger_ is None:
- TestBase.pandora_logger_ = Log(case_name, case_path)
- self._log = TestBase.pandora_logger_.log
- self.test_args = []
- self.caseresult = True
- self._exception = None
-
- def setup(self):
- self._log.info("TestBase:{0} Setup case".format(self.casename))
- return True
-
- def teardown(self):
- self._log.info("TestBase:{0} Teardown".format(self.casename))
- return
-
- def run_test_arg_func(self, test_args=[], attr=None):
- if not attr:
- self._log.info("attr is None")
- return False
- run_mode = self.get_env_var("RUNTIME_MODE")
- if run_mode in ["compile_cloud", "compile_mini"]:
- mode = "compile"
- else:
- mode = "execute"
- for arg in test_args:
- self._log.info(arg)
- if attr in arg[-1]:
- case_result, exception = self.common_run([arg[0:-1]], mode=mode)
- if not case_result:
- self._log.info("{0} run failed".format(arg))
- return False
- return True
-
- def print_args(self):
- for index, arg in enumerate(self.test_args):
- print("{0} {1}".format(index, arg[0]))
-
- def ana_args(self, arg, is_conv=False):
- caseflag, func, args = arg[0:3]
- kwargs = {}
- attrs = self.get_dim_info(arg, is_conv)
- if self.get_env_var(PERFORMANCE_TEST):
- attrs["record_core"] = True
- if attrs is not None:
- if len(arg) == 5 and not arg[-1]:
- args = list(args)
- args.append(attrs)
- args.append(arg[-1])
- kwargs = {}
- else:
- args = list(args)
- kwargs = {"attrs": attrs}
- return caseflag, func, args, kwargs
-
- def get_dim_info(self, arg, is_conv=False):
- info = dim.Dim()
- tile_dims = []
- dims = None
- enable_multicore = None
- dynamic = False
- partial_dynamic = False
- bypass_l1 = False
- if "dynamic" in arg:
- dynamic = True
- if isinstance(arg, tuple):
- arg = list(arg)
- arg.remove("dynamic")
- arg = tuple(arg)
- else:
- arg.remove("dynamic")
- if "partial_dynamic" in arg:
- partial_dynamic = True
- arg.remove("partial_dynamic")
- if "bypassL1" in arg:
- bypass_l1 = True
- arg.remove("bypassL1")
- if is_conv:
- dy = dynamic or partial_dynamic
- if len(arg) == 4:
- conv_tile = arg[3]
- if len(conv_tile) > 0:
- if not dy:
- return {
- "dim": str(info),
- "conv_tile": conv_tile,
- "enable_multicore": True,
- "bypass": 1 if bypass_l1 else 0,
- }
- else:
- return {
- "dim": str(info),
- "conv_tile": conv_tile,
- "dynamic": dynamic,
- "partial_dynamic": partial_dynamic,
- "bypass": 1 if bypass_l1 else 0,
- }
- elif dy and len(arg) == 3:
- return {
- "dynamic": dynamic,
- "partial_dynamic": partial_dynamic,
- "bypass": 1 if bypass_l1 else 0,
- }
-
- if len(arg) == 5 and not arg[-1]:
- dims = arg[3]
- for d in range(len(dims)):
- tile_dims.append(dims[d][0])
- elif (len(arg) == 5 and arg[-1]) or len(arg) == 4:
- if isinstance(arg[3], (bool, int)): # only multicore info
- enable_multicore = arg[3]
- elif isinstance(arg[3][-1], (bool, int)): # dim info and multicore info
- enable_multicore = arg[3][-1]
- dims = arg[3][0]
- else: # only dim info
- dims = arg[3]
- if dims is not None:
- for i in range(len(dims)):
- if (isinstance(dims[i][0], int)):
- # only one index, ((l1,l0),(l1,l0),...)
- i_dims = dims
- else:
- # multiple indices, (((l1,l0),(l1,l0),...), ((l1,l0),(l1,l0),...))
- i_dims = dims[i]
-
- for d in range(len(i_dims)):
- info.setdim(index=i,
- axis=d,
- tilel1=i_dims[d][0],
- tilel0=i_dims[d][1])
-
- if len(arg) == 5 and not arg[-1]:
- return {"tile": tile_dims}
- else:
- res = {"dim": str(info), "dynamic": dynamic}
- if enable_multicore:
- res["enable_multicore"] = enable_multicore
- return res
-
- def get_env_var(self, env_key=None):
- env_dic = os.environ
- env_var = env_dic.get(env_key)
- if env_var:
- return env_var
- return None
-
- def translate_func_name(self, arg):
- args_list = []
- args_list.append(arg[0])
- func = arg[1]
- if isinstance(func, str):
- args_list.append(func)
- else:
- args_list.append(func.__name__)
- for i in range(2, len(arg)):
- args_list.append(arg[i])
- return tuple(args_list)
-
- def import_get_func(self, func, mode):
- """
- from test_run.tile_run import tile_compile
- :param func: function name
- :param mode: case mode
- :return:
- """
- func_fromlist = "tests.common.test_run." + func
- try:
- new_func = func
- func_py = __import__(func_fromlist, fromlist=func)
- run_func = getattr(func_py, new_func)
- except (ImportError, AttributeError) as e:
- new_func = func.split("_run")[0] + "_" + mode
- func_py = __import__(func_fromlist, fromlist=new_func)
- run_func = getattr(func_py, new_func)
- return run_func
-
- def common_run(self, args, dtype_list=None, mode="execute", is_conv=False, raise_exception=True):
- """
- :param dtype_list:operator program data type
- :param mode: operator run mode: such as rpc_cloud/aicmodel
- :param raise_exception: By default, when an exception occurs in the compilation,
- the assert is used to interrupt the program.
- :return:
- """
- for arg in args:
- starttime = datetime.datetime.now()
- caseflag, func, args, kwargs = self.ana_args(arg, is_conv)
-
- if dtype_list:
- if not self.set_args_dtype(args, func, dtype_list):
- self._log.error("common_run failed for set_args_dtype")
- return False
-
- if isinstance(func, str):
- self._log.info("common_run :: run {funcname} with args:{args}".format(funcname=func, args=args))
- func = self.import_get_func(func, mode)
- else:
- self._log.info("common_run :: run {funcname} with args:{args}".format(funcname=func.__name__, args=args))
-
- mod = None
- if mode == "compile":
- try:
- mod = func(*args, **kwargs)
- except Exception as e:
- TestBase.pandora_logger_.traceback()
- self._exception = e
- finally:
- if (not mod) or self._exception:
- self._log.error("common_run :: circle {0} fail !".format(self.translate_func_name(arg)))
- self._log.error("common_run :: compile failed !")
- self.case_result = False
-
- elif mode == "execute":
- input, output, expect, runres = func(*args, **kwargs)
- rtol = atol = 0
- compare_res = []
- if isinstance(runres, list):
- if isinstance(runres[-1], (list, tuple)):
- rtol = runres[-1][0]
- atol = runres[-1][1]
- runres = list(runres[:-1])
- compare_res = runres
- runres = all(runres)
- elif isinstance(runres, collections.Iterable):
- compare_res = list(runres)
- else:
- compare_res = [runres]
-
- kernel_name = self.get_kernel_name(args, func)
- cce_file_name = self.collect_cce(kernel_name)
- ir_file_name = self.collect_ir(kernel_name)
- if not runres:
- runtime_mode = os.environ.get("RUNTIME_MODE")
- if runtime_mode in ["rpc", "rpc_cloud", "air", "air_cloud"]:
- for retry in range(self.max_retry):
- self._log.error("Case result is incorrect, but RPC server occasionally produce incorrect "
- "output. Retry it before reporting failure. Retry count: " + str(retry + 1))
- input, output, expect, runres = func(*args, **kwargs)
-
- if isinstance(runres, list):
- if isinstance(runres[-1], (list, tuple)):
- rtol = runres[-1][0]
- atol = runres[-1][1]
- runres = list(runres[:-1])
- compare_res = runres
- runres = all(runres)
- elif isinstance(runres, collections.Iterable):
- compare_res = list(runres)
- else:
- compare_res = [runres]
-
- if runres:
- break
- if not runres:
- self._log.error("common_run :: circle {0} fail !".format(self.translate_func_name(arg)))
- self._log.error("common_run :: CompareResult: %s", str(compare_res))
-
- if rtol == 0:
- self._log.error("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
- self._log.error("Caution: the 'rtol' and 'atol' is default $$$$$1e-4$$$$$")
- self._log.error("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
- rtol = atol = 1e-4
-
- if isinstance(expect, (tuple, list)):
- for i, tmp in enumerate(expect):
- count_unequal_element(tmp, output[i], rtol, atol)
- else:
- if not isinstance(expect, np.ndarray):
- expect = np.atleast_1d(expect)
- count_unequal_element(expect, output, rtol, atol)
-
- if not self.collect_data(input, output, cce_file_name, ir_file_name, arg, kernel_name):
- self._log.error("common_run :: collect data failed")
- self.case_result = False
- else:
- self._log.info("common_run :: circle {0} pass !".format(self.translate_func_name(arg)))
- if cce_file_name and os.path.exists(cce_file_name):
- os.remove(cce_file_name)
- if ir_file_name and os.path.exists(ir_file_name):
- os.remove(ir_file_name)
- self.case_result &= True
-
- endtime = datetime.datetime.now()
- self._log.info("{0} testcase use ***Running Time*** is: {1}s. "
- .format(caseflag, (endtime - starttime).seconds))
- self._log.info(self.case_result)
- '''
- use assert in the common_run function:
- Because the common_run function in the use cases does not verify the return value, the result cannot be
- printed normally after the program ends, so the execution result needs to be judged in the common_run function.
- '''
- if (not self.case_result) and raise_exception:
- assert self.case_result
- return self.case_result, self._exception
-
- def get_args_dtype(self, input_args_names):
- """
- Get the dtype of the function input parameter, return its index
- :param input_args_names: Test operator method
- :return: kernel_name
- """
- return tuple([index for index, name in enumerate(input_args_names) if str(name).__contains__("dtype")])
-
- def get_kernel_name(self, args, func):
- func_input_args_names = func.__code__.co_varnames
- kernel_name = func.__name__.split('_run')[0].split('_execute')[0]
- for index, name in enumerate(func_input_args_names):
- if str(name).__contains__("kernel_name"):
- kernel_name = func.__name__.split('_run')[0].split('_execute')[0]
- break
- return kernel_name
-
- def replace_args_dtype(self, args, input_args_names, dtype_list):
- """
- replace the dtype field of args
- """
- dtype_index_list = self.get_args_dtype(input_args_names)
- if not dtype_index_list or len(dtype_index_list) > len(dtype_list):
- self._log.error("replace_args_dtype :: dtype_index_list failed, dtype_index_list:{0},dtype_list:{1}".format(
- dtype_index_list, dtype_list))
- return False
-
- input_dtype_index = 0
- for index in dtype_index_list:
- args[index] = dtype_list[input_dtype_index]
- input_dtype_index += 1
-
- return True
-
- def set_args_dtype(self, args, func, dtype_list):
- """
- Set the dtype field of the use case parameter list
- """
- if not args or not dtype_list:
- self._log.error("set_args_dtype failed for test_arg_list:{0},dtype_list:{1}".format(args, dtype_list))
- return True
-
- func_input_args_names = func.__code__.co_varnames
- if not func_input_args_names:
- self._log.error("function : {0} args list is None".format(func))
- return True
-
- return self.replace_args_dtype(args, func_input_args_names, dtype_list)
-
- def upload_file_ftp(self, upload_type, local_file_path):
- if upload_type not in ("csvs", "cce", "ir", "dump_shape", "logs",):
- self._log.error("upload_file_ftp failed :: not support for upload_type:{0}".format(upload_type))
- return None
-
- today = str(datetime.date.today())
- ftp = ftpHandle(self._log)
- if not ftp.ftp_login():
- self._log.error("upload_file_ftp failed for ftp_login")
- return None
-
- remote_path = os.path.join("/auto_tensor", upload_type)
- if not ftp.ftp_mkdir(remote_path, today):
- self._log.error("upload_file_ftp failed for ftp_mkdir,remote_path:{0},today:{1}".format(remote_path, today))
- ftp.ftp_close()
- return None
- remote_path = os.path.join(remote_path, today)
- remote_file_name = str(local_file_path).split("/")[-1]
-
- if not ftp.ftp_upload_file(remote_path, remote_file_name, local_file_path):
- self._log.error(
- "upload_file_ftp failed for ftp_upload_file,remote_path:{0},today:{1},local_file_path:{2}".format(
- remote_path, today, local_file_path))
- ftp.ftp_close()
- return None
- ftp_url = "ftp://{host}/{path}".format(host=ftp.host, path=os.path.join(remote_path, remote_file_name))
-
- ftp.ftp_close()
- return ftp_url
-
- def collect_ir(self, kernel_name):
- if not os.path.exists(kernel_name):
- self._log.warning("not exist ir directory for :{kernel_name}".format(kernel_name=kernel_name))
- return None
-
- file_name = kernel_name + ".tar.gz"
- with tarfile.open(file_name, "w:gz") as tar:
- tar.add(kernel_name, arcname=os.path.basename(kernel_name))
- return file_name
-
- def collect_cce(self, kernel_name):
- file_name = kernel_name + ".cce"
- if not os.path.exists(file_name):
- self._log.warning("not exist cce file for :{file_name}".format(file_name=file_name))
- return None
- return file_name
-
- def collect_data(self, input, output, cce_file_name, ir_file_name, arg, kernel_name):
- ret_val = True
- # dump input and output
- dump_file_list = self.data_dump(input, output, arg)
- self._log.warning("dump input and output as follow:")
- if os.environ.get("FTP_HOST"):
- for dump_file in dump_file_list:
- ftp_url = self.upload_file_ftp("dump_shape", dump_file)
- if not ftp_url:
- self._log.error("upload_file_ftp failed for dump_file : {0}".format(dump_file))
- ret_val = False
- else:
- self._log.warning("dump_file ftp_url : {0}".format(ftp_url))
- # dump ir
- if not ir_file_name:
- self._log.error("collect_ir failed")
- ret_val = False
- else:
- ftp_url = self.upload_file_ftp("ir", ir_file_name)
- if not ftp_url:
- self._log.error("upload_file_ftp failed for ir_file_name : {0}".format(ir_file_name))
- ret_val = False
- else:
- self._log.warning("ir ftp_url : {0}".format(ftp_url))
-
- # dump cce
- if not cce_file_name:
- self._log.error("collect_cce failed")
- ret_val = False
- else:
- ftp_url = self.upload_file_ftp("cce", cce_file_name)
- if not ftp_url:
- self._log.error("upload_file_ftp failed for cce_file_name : {0}".format(cce_file_name))
- ret_val = False
- else:
- self._log.warning("cce ftp_url : {0}".format(ftp_url))
- else:
- case_failed_save_path = '/' + '/'.join(os.path.abspath(self.casename).split('/')[1:-1])
- self._log.warning("The input output data of failed use case log have been saved to the path : {0}/data/{1}"
- .format(case_failed_save_path, kernel_name))
- self._log.warning("The ir data of failed use case log have been saved to the path : {0}/{1}"
- .format(case_failed_save_path, ir_file_name))
- self._log.warning("The cce data of failed use case log have been saved to the path : {0}/{1}"
- .format(case_failed_save_path, cce_file_name))
- return ret_val
-
- def data_dump(self, input, output, arg):
- dump_file_list = []
- operator_name = str(arg[1]).split("_run")[0].split()[-1]
- data_dir = "./data/{0}/".format(operator_name)
- os.popen("mkdir -p %s" % data_dir)
- time.sleep(1)
- if not isinstance(input, list) and not isinstance(input, tuple):
- input = [input]
- if not isinstance(output, list) and not isinstance(output, tuple):
- output = [output]
- data_dict = {"input": input, "output": output}
- for kays in data_dict.keys():
- for index, i in enumerate(data_dict[kays]):
- seq = [operator_name, kays, str(index + 1)] + list(map(str, arg[2])) + [".t"]
- dump_file_name = "_".join(seq).replace("[", "").replace("]", "").replace(",", "-") \
- .replace(" ", "").replace("(", "").replace(")", "").replace("_.", ".")
- dump_file_name += str(time.time())
- dump_file = os.path.join(data_dir, dump_file_name)
- dump_file_list.append(dump_file)
- tensorio.dump_tensor(i, dump_file)
- return dump_file_list
-
-
- def get_rtol_atol(op_name, dtype, rtol=5e-03, atol=5e-03):
- run_mode = os.environ.get('RUNTIME_MODE')
- if run_mode in ("rpc_cloud", "air_cloud"):
- if dtype == "float16":
- rtol = atol = 1e-03
- else:
- rtol = atol = 1e-04
- return rtol, atol
-
-
- def get_splitted_cases(cases, split_nums, split_idx):
- if not isinstance(cases, (list, tuple)):
- raise TypeError("Argument cases must be of type list or tuple.")
- if not isinstance(split_nums, int) or not isinstance(split_idx, int):
- raise TypeError("Arguments split_nums and split_idx must be of type int.")
- if split_nums <= 0 or split_idx < 0 or split_idx >= split_nums:
- raise ValueError("Argument split_nums must > 0, split_idx must be in range [0, split_nums)")
-
- cases = list(cases)
- all_cases = len(cases)
- fragment = (all_cases + split_nums - 1) // split_nums
-
- start_idx = split_idx * fragment
- if start_idx >= all_cases:
- return []
-
- end_idx = start_idx + fragment
- if end_idx > all_cases:
- end_idx = all_cases
-
- return cases[start_idx:end_idx]
|