You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tuner.py 13 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tuner for finding best config for operators"""
  15. import logging
  16. import time
  17. import json
  18. import os
  19. import numpy as np
  20. from multiprocessing import Process
  21. from tvm.autotvm.tuner.xgboost_cost_model import XgbCostModel
  22. from tvm.autotvm.tuner.sa_model_optimizer import SimulatedAnnealingOptimizer
  23. from akg.auto_tune.space import ConfigSpace
  24. from akg.auto_tune.runner import KernelRunner
  25. logger = logging.getLogger('fuzz.tune.autotuning.tuner')
  26. class Tuner:
  27. """Basic tuner class
  28. Parameters
  29. ----------
  30. runner: KernelRunner
  31. This is for run kernels in physical device
  32. config_space: ConfigSpace
  33. The space of configs
  34. n_parallel: int
  35. How many kernels are processed in a turn
  36. """
  37. def __init__(self, runner: KernelRunner, index_table: list, config_space: ConfigSpace, n_parallel: int = 1):
  38. self._runner = runner
  39. self._index_table = index_table
  40. self._space = config_space
  41. self._n_parallel = n_parallel
  42. # trial plan
  43. self._trials = []
  44. self._trial_pt = 0
  45. self._visited = set()
  46. # observed samples
  47. self._xs = []
  48. self._ys = []
  49. # keep the current best
  50. self._best_config = None # type: ConfigEntity
  51. self._index_table = list() # used to parse best config into attrs
  52. self._best_time = np.inf
  53. self._best_iter = 0
  54. self._tuning_time = 0.0
  55. self._original_time = np.inf
  56. @property
  57. def best_config(self):
  58. return self._best_config
  59. @property
  60. def best_time(self):
  61. return self._best_time
  62. @property
  63. def best_iter(self):
  64. return self._best_iter
  65. @property
  66. def tuning_time(self):
  67. return self._tuning_time
  68. @property
  69. def original_time(self):
  70. return self._original_time
  71. @property
  72. def xs(self):
  73. return self._xs
  74. @property
  75. def ys(self):
  76. return self._ys
  77. def info(self):
  78. print('space size:', self._space.length)
  79. print('best config:', self._best_config)
  80. print('best time:', self._best_time)
  81. print('best_iter:', self._best_iter)
  82. print('tuning time:', self._tuning_time, 'secs')
  83. def next_batch(self, batch_size: int, is_add_visited=True):
  84. """extract next batch with xgboost model"""
  85. ret = []
  86. counter = 0
  87. if not is_add_visited:
  88. return [self._space.get(index) for index in range(min(batch_size, self._space.length))]
  89. while counter < batch_size and self._space.has_next():
  90. index = 0
  91. while self._trial_pt < len(self._trials):
  92. index = self._trials[self._trial_pt]
  93. if index not in self._visited:
  94. break
  95. self._trial_pt += 1
  96. if self._trial_pt >= len(self._trials):
  97. # if the trial list is empty choose randomly
  98. index = self._space.fetch_index()
  99. ret.append(self._space.get(index))
  100. self._visited.add(index)
  101. counter += 1
  102. return ret
  103. def next_config(self, batch_size: int):
  104. """extract next config orderly"""
  105. ret = []
  106. counter = 0
  107. while counter < batch_size and self._space.has_next():
  108. index = self._space.fetch_next_index()
  109. ret.append(self._space.get(index))
  110. self._visited.add(index)
  111. counter += 1
  112. return ret
  113. def export_configs(self, configs: list, output_file: str, append: bool = True, desc=""):
  114. """export configs"""
  115. mode = "a" if append else "w"
  116. with open(output_file, mode) as f:
  117. for x, y in configs:
  118. f.write("{} | {} | {}\n".format(desc, json.dumps(x._asdict()), y))
  119. def export_dim_configs(self, configs, output_file: str, append: bool = True, key=""):
  120. """export dim configs"""
  121. mode = "a" if append else "w"
  122. data = {}
  123. try:
  124. if os.path.isfile(output_file):
  125. with open(output_file, 'r') as f:
  126. data = json.load(f)
  127. except IOError as e:
  128. logger.debug("get dim info from [%s] failed: %s", output_file, str(e))
  129. with open(output_file, mode) as f:
  130. import re
  131. data[key] = configs
  132. s = json.dumps(data, sort_keys=True)
  133. s = re.sub(r',\s*"', ',\n"', s)
  134. s = '{\n' + s[1:-1] + '\n}'
  135. f.write(s)
  136. def export_dim_configs_for_keys(self, configs, output_file: str, append: bool = True, keys=[]):
  137. """export dim configs"""
  138. mode = "a" if append else "w"
  139. data = {}
  140. try:
  141. if os.path.isfile(output_file):
  142. with open(output_file, 'r') as f:
  143. data = json.load(f)
  144. except IOError as e:
  145. logger.debug("get dim info from [%s] failed: %s", output_file, str(e))
  146. with open(output_file, mode) as f:
  147. import copy
  148. data_tmp = copy.deepcopy(data)
  149. res_key = []
  150. for key in keys:
  151. if key in data_tmp:
  152. data_tmp = data_tmp[key]
  153. res_key.append(key)
  154. tmp = copy.deepcopy(configs)
  155. info = {}
  156. for key in reversed(keys):
  157. if not key in res_key:
  158. info = {key: tmp}
  159. tmp = copy.deepcopy(info)
  160. data_change = data
  161. for key in res_key:
  162. data_change = data_change[key]
  163. data_change.update(**info)
  164. s = json.dumps(data, sort_keys=True, indent=4)
  165. f.write(s)
  166. def load_configs(self, input_file: str):
  167. """load configs"""
  168. configs = []
  169. file_path = os.path.realpath(input_file)
  170. if os.path.isfile(file_path):
  171. with open(file_path, "r") as f:
  172. for line in f:
  173. x, y, _ = line.split('|')
  174. configs.append((self._space.input_type(**json.loads(x)), np.float64(y)))
  175. return configs
  176. def tune(self, least_try_times: int, output_file: str = None):
  177. """grid search all configs"""
  178. i = 0
  179. while i < least_try_times:
  180. if not self._space.has_next():
  181. break
  182. configs = self.next_config(min(self._n_parallel, least_try_times - i))
  183. run_times = self._runner.run(configs, self._best_time)
  184. results = []
  185. for idx, conf in enumerate(configs):
  186. results.append((conf.input_id, run_times[idx]))
  187. # keep best config
  188. if self.best_time > run_times[idx]:
  189. self._best_time = run_times[idx]
  190. self._best_iter = i + idx
  191. self._best_config = conf
  192. i += len(results)
  193. # update
  194. for res in results:
  195. self._xs.append(res[0])
  196. self._ys.append(res[1])
  197. if output_file:
  198. configs = [(self._space.get(res[0]).input, res[1]) for res in results]
  199. self.export_configs(configs, output_file)
  200. return run_times
class ModelBasedTuner(Tuner):
    """Model based tuner.

    This tuner will fit a cost model and use an optimizer to find the
    maximums of the cost model as next trials.

    Parameters
    ----------
    plan_size: int
        Tuner will re-fit model per `plan_size` new measure samples
    pre_model: CostModel
        The cost model that predicts the speed of a config (IR)
    """

    def __init__(self, runner, index_table, config_space, n_parallel=1, plan_size=32, pre_model=None):
        super(ModelBasedTuner, self).__init__(runner, index_table, config_space, n_parallel)
        self.__plan_size = plan_size
        if pre_model is not None:
            # reuse a pre-trained cost model, re-bound to this config space
            self.__cost_model = pre_model
            self.__cost_model.reset_space(self._space)
        else:
            self.__cost_model = XgbCostModel(self._space)
        self.__model_optimizer = SimulatedAnnealingOptimizer(self._space)
        # number of times the cost model has been (re-)fitted
        self.__train_ct = 0
        # the first loop iteration of tune() runs with auto-set dims to
        # measure a baseline time before searching the space
        self.__is_auto_set_dim = True
        # time to leave (remaining budget; written in tune(), not read here)
        self.__ttl = None
        self.__least_try_times = None
        self.__early_stopping = None
        # cumulative seconds spent fitting/querying the cost model
        self.__model_run_time = 0.0

    def info(self):
        """Print the base summary plus the model's cumulative run time."""
        super(ModelBasedTuner, self).info()
        print('model run time:', self.__model_run_time, 'secs')

    def model_res(self):
        """Fit the cost model on observed samples and plan the next trials.

        NOTE(review): tune() invokes this via multiprocessing.Process, so the
        assignment to self._trials happens in the child process's memory and
        does not propagate back to the parent — confirm this is intended, as
        it would leave the parent's trial plan empty.
        """
        self.__cost_model.fit(self._xs, self._ys, self.__plan_size)
        best_configs = self.__model_optimizer.find_best(
            self.__cost_model, self.__plan_size, self._visited)
        self._trials = best_configs

    def tune(self, least_try_times: int, output_file: str = None):
        """Run model-guided tuning with several early-stopping conditions.

        Parameters
        ----------
        least_try_times: int
            Minimum number of configs to measure; also used as the
            early-stopping window. Up to 3x this budget is spent while no
            config beats the auto-set-dim baseline.
        output_file: str
            If given, measured (config, time) pairs are appended there.
        """
        early_stopping = least_try_times
        self.__least_try_times = least_try_times
        self.__early_stopping = early_stopping
        # remember the logger level so the error_ct branch can restore it
        old_level = logger.level
        i = 0
        error_ct = 0
        tuning_start = time.time()
        # keep searching until the budget is used up; extend to 3x the budget
        # while the best time has not improved on the baseline by at least 0.9
        while (i < self._space.length and (i < least_try_times
                                           or (self._best_time > self._original_time - 0.9
                                               and i < least_try_times * 3))):
            if not self._space.has_next():
                break
            iter_start = time.time()
            if not self.__is_auto_set_dim:
                configs = self.next_batch(min(self._n_parallel, self._space.length - i))
            else:
                # baseline pass: take the leading configs without marking them visited
                configs = self.next_batch(min(self._n_parallel, self._space.length - i), False)
            logger.debug('--indexes: %s', str([x.input_id for x in configs]))
            run_times = self._runner.run(configs, self._best_time, self.__is_auto_set_dim)
            if self.__is_auto_set_dim:
                # profiling start fail occasionally; drop the sentinel values
                run_fail = 9999999999.0
                run_times = [x for x in run_times if x != run_fail]
                if len(run_times) == 0:
                    self._original_time = run_fail
                else:
                    from operator import add
                    from functools import reduce
                    # baseline = mean of the successful auto-set-dim runs
                    self._original_time = reduce(add, run_times) / len(run_times)
                self._best_time = self._original_time
                self._best_iter = -1
                self._best_config = None
                run_times = None
                self.__is_auto_set_dim = False
                # baseline measured; restart the loop in normal search mode
                continue
            results = []
            for idx, conf in enumerate(configs):
                results.append((conf.input_id, run_times[idx]))
                # keep best config; a candidate must beat the best by more
                # than 600 (time units) to be accepted
                if self._best_time - 600 > run_times[idx]:
                    self._best_time = run_times[idx]
                    self._best_iter = i + idx
                    self._best_config = conf
            i += len(results)
            # remaining trial budget before the early-stopping window closes
            self.__ttl = min(early_stopping + self.best_iter, self._space.length) - i
            start = time.time()
            # update observed samples
            for res in results:
                self._xs.append(res[0])
                self._ys.append(res[1])
            if output_file:
                configs = [(self._space.get(res[0]).input, res[1]) for res in results]
                desc = str(self._runner.op_desc)
                self.export_configs(configs, output_file, desc=desc)
            # if we have enough new training samples, re-fit the cost model
            if len(self._xs) >= self.__plan_size * (self.__train_ct + 1):
                # NOTE(review): model_res mutates self._trials inside the child
                # process; the parent's copy is unchanged after join() — verify.
                p = Process(target=self.model_res)
                p.start()
                p.join()
                self._trial_pt = 0
                self.__train_ct += 1
            end = time.time()
            logger.debug('model running time: %f seconds', end - start)
            self.__model_run_time += end - start
            iter_end = time.time()
            logger.debug('iter time: %f seconds', iter_end - iter_start)
            # early stop: no improvement within the early-stopping window
            if self._best_iter > 0 and i >= self.best_iter + early_stopping:
                logger.warning('Early stopped. Best iter: %d', self._best_iter)
                return
            # early stop: presumably time units where < 1000 means a tiny
            # kernel not worth further tuning — confirm the unit with runner
            if self._best_time < 1000:
                logger.warning('Early stopped for this is a small shape. Best iter: %d', self._best_iter)
                return
            logger.debug("tuning time already, %f", time.time() - tuning_start)
            # hard wall-clock cap of 7200 s (2 h) for the whole session
            if time.time() - tuning_start > 7200:
                logger.warning('Early stopped because of too long time. Best iter: %d', self._best_iter)
                return
            # NOTE(review): error_ct is never incremented in this method, so
            # this debug-mode branch appears unreachable as written — confirm.
            if error_ct > 150:
                logging.warning('Too many errors happen in the tuning. Now is in debug mode')
                logger.setLevel(logging.DEBUG)
            else:
                logger.setLevel(old_level)
        self._tuning_time += time.time() - tuning_start