# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runner that compiles and executes configs of an operator on device."""
import glob
import json
import logging
import multiprocessing
import os
import shutil
import subprocess
import time
from typing import NamedTuple

import numpy as np

from akg import composite
from akg.utils import custom_tiling as ct_util
from akg.utils import kernel_exec as utils
from akg.auto_tune.kernel_compiler import compile_kernel
from akg.auto_tune.data_generators import gen_data
from akg.auto_tune.kernel_compiler import get_matmul_cube_attrs

logger = logging.getLogger('fuzz.tune.autotuning.runner')

# Sentinel "run times" stored instead of a real measurement when a config
# fails. They are huge so that taking a minimum over run times never prefers
# a failed config (presumably far above any real measurement -- confirm units).
error_time_list = [
    9999999999.0,
    9999999998.0,
    9999999997.0,
    9999999996.0,
]

error_time_string = {
    error_time_list[0]: 'run_failed',
    error_time_list[1]: 'precision_error',
    error_time_list[2]: 'compile_failed',
    error_time_list[3]: 'timeout'
}

run_failed_time = error_time_list[0]
precision_error_time = error_time_list[1]
compile_fail_time = error_time_list[2]
timeout_time = error_time_list[3]

# Tolerances used when comparing kernel outputs against expected values.
_PRECISION_RTOL = 5e-03
_PRECISION_ATOL = 5e-03


def _config_desc(config):
    """Return a printable description of *config* for log messages.

    ``None`` (no explicit config) is reported as ``"origin"``.
    """
    return "origin" if config is None else str(config.input)


def _outputs_close(actual, expected):
    """Return True when *actual* matches *expected* within the precision tolerances."""
    return np.allclose(actual, expected, rtol=_PRECISION_RTOL, atol=_PRECISION_ATOL, equal_nan=True)


def _remove_glob(pattern):
    """Best-effort, shell-free equivalent of ``rm -rf <pattern>``.

    Every path matching *pattern* is removed: real directories recursively,
    files and symlinks (the link itself, never its target) by unlinking.
    Errors are ignored, matching ``rm -rf`` launched via ``subprocess.run``
    without ``check=True``.
    """
    for path in glob.glob(pattern):
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except OSError:
                pass


def get_attr_from_config(config, index_table):
    """Translate a tuning config into build attributes.

    Parameters
    ----------
    config: NamedTuple
        Tuning config. Fields whose name starts with ``'tiling'`` are tiling
        values; every other field is copied into the returned attrs unchanged.
    index_table: list
        The i-th entry is concatenated with ``[value, 1]`` of the i-th tiling
        field to form a tiling parameter (entries are presumably lists --
        confirm against callers).

    Returns
    -------
    dict
        Build attributes. When the config has tiling fields, ``'dim'`` holds
        the result of ``ct_util.set_dims``; otherwise a notice is printed and
        no ``'dim'`` is set.
    """
    tiling = []
    attrs = {}
    for key, value in config._asdict().items():
        if key.startswith('tiling'):
            tiling.append([value, 1])
        else:
            attrs[key] = value
    if tiling:
        tiling_param = [index_table[i] + element for i, element in enumerate(tiling)]
        attrs['dim'] = ct_util.set_dims(tuple(tiling_param))
    else:
        print("No tiling info. Use auto tiling.")
    return attrs


class KernelRunner:
    """Kernel runner.

    Compiles and executes configs of an operator and reports their running
    times (or one of the sentinel error times in ``error_time_list``).

    Parameters
    ----------
    op_type: str
        The name of the operator. ``"json"``, ``"extra_tune"`` and
        ``"matmul_json"`` select the composite-build paths; any other value
        goes through ``compile_kernel``.
    op_desc: NamedTuple
        The definition parameters of the operator (passed to ``gen_data``,
        ``compile_kernel`` or ``composite.build`` depending on ``op_type``).
    json_desc: str
        Description passed to ``composite.build`` for ``op_type == "matmul_json"``.
    index_table: list
        Tiling index table used to turn a config into tiling parameters.
    timeout: int
        Timeout for running one config (seconds, as passed to ``Process.join``).
    repeat_times: int
        Number of launches per config; the minimum run time is kept.
    input_data, expect:
        Pre-generated inputs and expected outputs. When ``input_data`` is
        None, both are produced by ``gen_data(op_type, op_desc)``.
    mod_output_param:
        Output parameter spec forwarded to ``utils.mod_launch``. Replaced by
        ``gen_data``'s ``'outputs'`` entry when it returns a dict.
    """

    def __init__(self, op_type: str, op_desc: NamedTuple, json_desc: str, index_table: list,
                 timeout: int = 600, repeat_times: int = 2, input_data=None, expect=None,
                 mod_output_param=None):
        self.op_type = op_type
        self.op_desc = op_desc
        self.json_desc = json_desc
        self._index_table = index_table
        self.run_kernel_time = 0.0
        self.timeout = timeout
        self.repeat_times = repeat_times
        self.mod_output_param = mod_output_param
        if input_data is None:
            self.input, self.expect = gen_data(op_type, op_desc)
            if isinstance(self.input, dict):
                self.input, self.mod_output_param = self.input['args'], self.input['outputs']
        else:
            self.input, self.expect = input_data, expect
        self.input_shape = [x.shape for x in self.input]

    def info(self):
        """Print the accumulated wall-clock time spent in :meth:`run`."""
        print('run kernel time:', self.run_kernel_time)

    def run_one_kernel(self, run_times, idx, config, best_time=np.inf, is_auto=False):
        """Compile and execute one config of the operator on device.

        Meant to run in a child process started by :meth:`run`. Nothing is
        returned; the outcome is written to ``run_times[idx]``: the best
        measured run time, or ``compile_fail_time`` / ``run_failed_time`` /
        ``precision_error_time``.

        Parameters
        ----------
        run_times: list-like
            Shared result buffer (a ``multiprocessing.Manager`` list in :meth:`run`).
        idx: int
            Index of this config; also offsets the device id when more than
            one device is available.
        config:
            Config whose ``input`` field holds the tuning values. Ignored for
            the ``json``/``extra_tune`` and ``compile_kernel`` paths when
            ``is_auto`` is True; always read by the ``matmul_json`` path.
        best_time: float
            Outputs are only checked for precision when a launch beats this time.
        is_auto: bool
            Build without the config's tiling (automatic tiling).
        """
        time_one_kernel_start = time.time()
        logger.debug('compile %dth kernel', idx)
        # get available device: a single device is shared, otherwise spread configs by idx
        if utils.get_available_devices_num() == 1:
            device_id = utils.get_device_id()
        else:
            device_id = idx + utils.get_device_id()
        os.environ['DEVICE_ID'] = str(device_id)
        logger.debug('run %dth kernel', idx)
        logger.debug('++++++++++++++++++++++=device_id')
        logger.debug(device_id)
        logger.debug('++++++++++++++++++++++=device_id')
        try:
            time_start_build = time.time()
            logger.debug(config)
            if self.op_type in ["json", "extra_tune"]:
                if is_auto:
                    attrs = {"kernel_name": "tuning_json_" + str(idx)}
                    mod = composite.build(self.op_desc, attrs)
                    if self.op_type == "extra_tune":
                        # pop() instead of del: a missing variable must not turn a
                        # successful build into a reported compile failure.
                        os.environ.pop('MS_GRAPH_KERNEL_TILING', None)
                else:
                    attrs = get_attr_from_config(config.input, self._index_table)
                    attrs["kernel_name"] = "tuning_json_" + str(idx)
                    if os.environ['RUNTIME_MODE'] == "gpu":
                        attrs['target'] = "cuda"
                    mod = composite.build(self.op_desc, attrs, use_repo=False)
            elif self.op_type == "matmul_json":
                attrs = get_matmul_cube_attrs(self.op_desc, config.input)
                logger.debug("matmul attrs: %s", attrs)
                mod = composite.build(self.json_desc, attrs, use_repo=False)
            else:
                mod = compile_kernel(self.op_type, self.op_desc, self.input_shape, self._index_table,
                                     None if is_auto else config.input, idx)
            time_end_build = time.time()
            logger.debug("build module time: %f", time_end_build - time_start_build)
            logger.debug('finished compile %dth kernel', idx)
        except BaseException as e:
            logger.warning("Compile Failed: [%s] : %s", "origin" if is_auto else _config_desc(config), str(e))
            run_times[idx] = compile_fail_time
            return

        # Compilation succeeded: the result stays "run failed" until a launch succeeds.
        run_times[idx] = run_failed_time
        try:
            for _ in range(self.repeat_times):
                run_time = self._launch_once(mod, device_id, best_time, config)
                run_times[idx] = np.minimum(run_times[idx], run_time)
        finally:
            logger.debug('end of %dth kernel', idx)
        time_one_kernel_end = time.time()
        logger.debug('run one kernel time: %f', time_one_kernel_end - time_one_kernel_start)
        return

    def _launch_once(self, mod, device_id, best_time, config):
        """Launch *mod* once; return its run time or a sentinel error time.

        Outputs are compared with ``self.expect`` only when the launch beat
        *best_time*; a mismatch yields ``precision_error_time``. Any exception
        raised while launching or checking yields ``run_failed_time``.
        """
        stat_info = {}
        try:
            time_start_launch = time.time()
            if self.mod_output_param is not None and len(self.mod_output_param) > 1:
                # multiple outputs: every output must match its expected array
                output, stat_info = utils.mod_launch(mod, list(self.input), self.mod_output_param,
                                                     tuning=True, device_id=device_id)
                if stat_info['run_time'] < best_time:
                    if not all(_outputs_close(x, y) for x, y in zip(output, self.expect)):
                        stat_info['run_time'] = precision_error_time
                        logger.warning("Precision Error: [%s]", _config_desc(config))
            else:
                if self.op_type in ["json", "extra_tune", "matmul_json"]:
                    output, stat_info = utils.mod_launch(mod, self.input, self.mod_output_param,
                                                         tuning=True, device_id=device_id)
                else:
                    output, stat_info = utils.mod_launch(mod, self.input, tuning=True, device_id=device_id)
                if stat_info['run_time'] < best_time:
                    if not _outputs_close(output, self.expect):
                        stat_info['run_time'] = precision_error_time
                        logger.warning("Precision Error: [%s]", _config_desc(config))
            time_end_launch = time.time()
            logger.debug("mod launch time: %f", time_end_launch - time_start_launch)
        except BaseException as e:
            # config may be None; _config_desc avoids raising inside the handler
            logger.warning("Run Failed: [%s] : %s", _config_desc(config), str(e))
            stat_info['run_time'] = run_failed_time
        return stat_info['run_time']

    def run(self, configs, best_time=np.inf, is_auto_set_dim=False, all_space=False):
        """Compile and execute a batch of configs of the operator on device.

        Each config runs :meth:`run_one_kernel` in its own process. Processes
        are joined in order with ``self.timeout``; after the first timeout the
        remaining still-running processes are not waited for any further but
        terminated and reported as ``timeout_time``.

        Parameters
        ----------
        configs: sequence
            Configs to evaluate; each must expose an ``input`` field.
        best_time: float
            Forwarded to :meth:`run_one_kernel` (precision-check threshold).
        is_auto_set_dim: bool
            Forwarded as ``is_auto`` to :meth:`run_one_kernel`.
        all_space: bool
            Not used by this method.

        Returns
        -------
        multiprocessing ListProxy
            One run time (or sentinel error time) per config, in order.
        """
        start = time.time()
        logger.debug("gen cce kernels batch: %d kernels", len(configs))
        _remove_glob("./jobs/JOB*")
        process_jobs = []
        run_times = multiprocessing.Manager().list(np.full((len(configs),), compile_fail_time))
        for idx, config in enumerate(configs):
            proc = multiprocessing.Process(target=self.run_one_kernel,
                                           args=(run_times, idx, config, best_time, is_auto_set_dim))
            process_jobs.append(proc)
            proc.start()
        timeout_error = False
        for idx, proc in enumerate(process_jobs):
            if not timeout_error:
                proc.join(timeout=self.timeout)
            if proc.is_alive():
                timeout_error = True
                logger.debug("Timeout Error: [%s]", _config_desc(configs[idx]))
                run_times[idx] = timeout_time
                proc.terminate()
        process_end = time.time()
        logger.debug("process time: %f", process_end - start)

        # clean the profiling directory
        if os.environ['RUNTIME_MODE'] == "gpu":
            _remove_glob("cuda_meta*")
        else:
            profiling_dir = str(os.environ['PROFILING_DIR'])
            if not profiling_dir or profiling_dir.isspace():
                # Never fall through to removal: an empty dir would target /JOB*.
                logger.error("The value about PROFILING_DIR shoud be setted correctly.")
            else:
                _remove_glob(os.path.join(profiling_dir, "JOB*"))

        end = time.time()
        logger.debug("run kernels time: %f", end - start)
        self.run_kernel_time += end - start

        for idx, config in enumerate(configs):
            run_time = run_times[idx]
            shown = error_time_string[run_time] if run_time in error_time_list else run_time
            logger.debug("KernelRunTime : [%s] : %s", _config_desc(config), str(shown))
        return run_times