You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_cost_model_context.py 25 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context of cost_model in auto_parallel"""
  16. import threading
  17. from mindspore._c_expression import CostModelContext
  18. from mindspore._checkparam import args_type_check
  19. class _CostModelContext:
  20. """
  21. _CostModelContext is the environment in which operations are executed
  22. Note:
  23. Creating a context through instantiating Context object is not recommended.
  24. Use cost_model_context() to get the context since Context is singleton.
  25. """
  26. _instance = None
  27. _instance_lock = threading.Lock()
  28. def __init__(self):
  29. self._context_handle = CostModelContext.get_instance()
  30. def __new__(cls):
  31. if cls._instance is None:
  32. cls._instance_lock.acquire()
  33. cls._instance = object.__new__(cls)
  34. cls._instance_lock.release()
  35. return cls._instance
  36. def set_device_memory_capacity(self, dev_mem_cap):
  37. """
  38. Set device memory capacity.
  39. Args:
  40. dev_mem_cap (float): The memory capacity for each device.
  41. Raises:
  42. ValueError: If context handle is none.
  43. """
  44. if self._context_handle is None:
  45. raise ValueError("Context handle is none in context!!!")
  46. self._context_handle.set_device_memory_capacity(dev_mem_cap)
  47. def get_device_memory_capacity(self):
  48. """
  49. Get device memory capacity.
  50. Raises:
  51. ValueError: If context handle is none.
  52. """
  53. if self._context_handle is None:
  54. raise ValueError("Context handle is none in context!!!")
  55. return self._context_handle.get_device_memory_capacity()
  56. def set_costmodel_alpha(self, alpha):
  57. """
  58. Set costmodel alpha.
  59. Args:
  60. alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
  61. Raises:
  62. ValueError: If context handle is none.
  63. """
  64. if self._context_handle is None:
  65. raise ValueError("Context handle is none in context!!!")
  66. self._context_handle.set_costmodel_alpha(alpha)
  67. def get_costmodel_alpha(self):
  68. """
  69. Get costmodel alpha.
  70. Raises:
  71. ValueError: If context handle is none.
  72. """
  73. if self._context_handle is None:
  74. raise ValueError("Context handle is none in context!!!")
  75. return self._context_handle.get_costmodel_alpha()
  76. def set_costmodel_beta(self, beta):
  77. """
  78. Set costmodel beta.
  79. Args:
  80. beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
  81. Raises:
  82. ValueError: If context handle is none.
  83. """
  84. if self._context_handle is None:
  85. raise ValueError("Context handle is none in context!!!")
  86. self._context_handle.set_costmodel_beta(beta)
  87. def get_costmodel_beta(self):
  88. """
  89. Get costmodel beta.
  90. Raises:
  91. ValueError: If context handle is none.
  92. """
  93. if self._context_handle is None:
  94. raise ValueError("Context handle is none in context!!!")
  95. return self._context_handle.get_costmodel_beta()
  96. def set_costmodel_gamma(self, gamma):
  97. """
  98. Set costmodel gamma.
  99. Args:
  100. gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
  101. Raises:
  102. ValueError: If context handle is none.
  103. """
  104. if self._context_handle is None:
  105. raise ValueError("Context handle is none in context!!!")
  106. self._context_handle.set_costmodel_gamma(gamma)
  107. def get_costmodel_gamma(self):
  108. """
  109. Get costmodel gamma.
  110. Raises:
  111. ValueError: If context handle is none.
  112. """
  113. if self._context_handle is None:
  114. raise ValueError("Context handle is none in context!!!")
  115. return self._context_handle.get_costmodel_gamma()
  116. def set_costmodel_communi_threshold(self, threshold):
  117. """
  118. Set costmodel communication threshold.
  119. Args:
  120. threshold (float): A parameter used in adjusting communication calculation for practice.
  121. Raises:
  122. ValueError: If context handle is none.
  123. """
  124. if self._context_handle is None:
  125. raise ValueError("Context handle is none in context!!!")
  126. self._context_handle.set_costmodel_communi_threshold(threshold)
  127. def get_costmodel_communi_threshold(self):
  128. """
  129. Get costmodel communication threshold.
  130. Raises:
  131. ValueError: If context handle is none.
  132. """
  133. if self._context_handle is None:
  134. raise ValueError("Context handle is none in context!!!")
  135. return self._context_handle.get_costmodel_communi_threshold()
  136. def set_costmodel_communi_const(self, communi_const):
  137. """
  138. Set costmodel communication const.
  139. Args:
  140. const (float): A parameter used in adjusting communication calculation for practice.
  141. Raises:
  142. ValueError: If context handle is none.
  143. """
  144. if self._context_handle is None:
  145. raise ValueError("Context handle is none in context!!!")
  146. self._context_handle.set_costmodel_communi_const(communi_const)
  147. def get_costmodel_communi_const(self):
  148. """
  149. Get costmodel communication const.
  150. Raises:
  151. ValueError: If context handle is none.
  152. """
  153. if self._context_handle is None:
  154. raise ValueError("Context handle is none in context!!!")
  155. return self._context_handle.get_costmodel_communi_const()
  156. def set_costmodel_communi_bias(self, communi_bias):
  157. """
  158. Set costmodel communication bias.
  159. Args:
  160. communi_bias (float): A parameter used in adjusting communication calculation for practice.
  161. Raises:
  162. ValueError: If context handle is none.
  163. """
  164. if self._context_handle is None:
  165. raise ValueError("Context handle is none in context!!!")
  166. self._context_handle.set_costmodel_communi_bias(communi_bias)
  167. def get_costmodel_communi_bias(self):
  168. """
  169. Get costmodel communication bias.
  170. Raises:
  171. ValueError: If context handle is none.
  172. """
  173. if self._context_handle is None:
  174. raise ValueError("Context handle is none in context!!!")
  175. return self._context_handle.get_costmodel_communi_bias()
  176. def set_multi_subgraphs(self, multi_subgraph):
  177. """
  178. Set the flag of ANF graph containing multiple subgraphs.
  179. Args:
  180. multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
  181. Raises:
  182. ValueError: If context handle is none.
  183. """
  184. if self._context_handle is None:
  185. raise ValueError("Context handle is none in context!!!")
  186. self._context_handle.set_multi_subgraphs(multi_subgraph)
  187. def get_multi_subgraphs(self):
  188. """
  189. Get the flag of ANF graph containing multiple subgraphs.
  190. Raises:
  191. ValueError: If context handle is none.
  192. """
  193. if self._context_handle is None:
  194. raise ValueError("Context handle is none in context!!!")
  195. return self._context_handle.get_multi_subgraphs()
  196. def set_run_phase(self, phase):
  197. """
  198. Set the flag of running phase: training (0) or inference (1)
  199. Args:
  200. phase (int): A parameter indicating which phase is running.
  201. Raises:
  202. ValueError: If context handle is none, or phase is not in {0, 1}.
  203. """
  204. if not isinstance(phase, int) or isinstance(phase, bool):
  205. raise TypeError(f"The type of communi_const must be int, but got {type(phase)}.")
  206. if self._context_handle is None:
  207. raise ValueError("Context handle is none in context!!!")
  208. if phase not in (0, 1):
  209. raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
  210. self._context_handle.set_run_phase(phase)
  211. def get_run_phase(self):
  212. """
  213. Get the flag of running phase.
  214. Raises:
  215. ValueError: If context handle is none.
  216. """
  217. if self._context_handle is None:
  218. raise ValueError("Context handle is none in context!!!")
  219. return self._context_handle.get_run_phase()
  220. def set_dp_algo_single_loop(self, single_loop):
  221. """
  222. Set the flag of generating a single suite of OperatorInfos in for-loop.
  223. Args:
  224. single_loop (bool): The parameter for the single loop flag.
  225. Raises:
  226. ValueError: If context handle is none.
  227. """
  228. if not isinstance(single_loop, bool):
  229. raise TypeError(f"The type of single_loop must be bool, but got {type(single_loop)}.")
  230. if self._context_handle is None:
  231. raise ValueError("Context handle is none in context!!!")
  232. self._context_handle.set_dp_algo_single_loop(single_loop)
  233. def get_dp_algo_single_loop(self):
  234. """
  235. Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.
  236. Raises:
  237. ValueError: If context handle is none.
  238. """
  239. if self._context_handle is None:
  240. raise ValueError("Context handle is none in context!!!")
  241. return self._context_handle.get_dp_algo_single_loop()
  242. def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
  243. """
  244. Set costmodel allreduce fusion algorithm.
  245. Args:
  246. algorithm (int): The AllReduce fusion algorithm of parameter gradients.
  247. Raises:
  248. ValueError: If context handle is none.
  249. """
  250. if self._context_handle is None:
  251. raise ValueError("Context handle is none in context!!!")
  252. self._context_handle.set_costmodel_allreduce_fusion_algorithm(algorithm)
  253. def get_costmodel_allreduce_fusion_algorithm(self):
  254. """
  255. Get costmodel allreduce fusion algorithm.
  256. Raises:
  257. ValueError: If context handle is none.
  258. """
  259. if self._context_handle is None:
  260. raise ValueError("Context handle is none in context!!!")
  261. return self._context_handle.get_costmodel_allreduce_fusion_algorithm()
  262. def set_costmodel_allreduce_fusion_times(self, allreduce_fusion_times):
  263. """
  264. Set costmodel allreduce fusion times.
  265. Args:
  266. allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
  267. Raises:
  268. ValueError: If context handle is none.
  269. """
  270. if self._context_handle is None:
  271. raise ValueError("Context handle is none in context!!!")
  272. self._context_handle.set_costmodel_allreduce_fusion_times(allreduce_fusion_times)
  273. def get_costmodel_allreduce_fusion_times(self):
  274. """
  275. Get costmodel allreduce fusion times.
  276. Raises:
  277. ValueError: If context handle is none.
  278. """
  279. if self._context_handle is None:
  280. raise ValueError("Context handle is none in context!!!")
  281. return self._context_handle.get_costmodel_allreduce_fusion_times()
  282. def set_costmodel_allreduce_fusion_tail_percent(self, tail_percent):
  283. """
  284. Set costmodel allreduce fusion tail percent.
  285. Args:
  286. tail_percent (int): The percentage of backward computing time corresponding to the last parameter gradients
  287. AllReduce in the whole backward computing time.
  288. Raises:
  289. ValueError: If context handle is none.
  290. """
  291. if self._context_handle is None:
  292. raise ValueError("Context handle is none in context!!!")
  293. self._context_handle.set_costmodel_allreduce_fusion_tail_percent(tail_percent)
  294. def get_costmodel_allreduce_fusion_tail_percent(self):
  295. """
  296. Get costmodel allreduce fusion tail percent.
  297. Raises:
  298. ValueError: If context handle is none.
  299. """
  300. if self._context_handle is None:
  301. raise ValueError("Context handle is none in context!!!")
  302. return self._context_handle.get_costmodel_allreduce_fusion_tail_percent()
  303. def set_costmodel_allreduce_fusion_tail_time(self, tail_time):
  304. """
  305. Set costmodel allreduce fusion tail time.
  306. Args:
  307. tail_time (int): The tail time of the last parameter gradients AllReduce after the end of backward
  308. computation.
  309. Raises:
  310. ValueError: If context handle is none.
  311. """
  312. if self._context_handle is None:
  313. raise ValueError("Context handle is none in context!!!")
  314. self._context_handle.set_costmodel_allreduce_fusion_tail_time(tail_time)
  315. def get_costmodel_allreduce_fusion_tail_time(self):
  316. """
  317. Get costmodel allreduce fusion tail time.
  318. Raises:
  319. ValueError: If context handle is none.
  320. """
  321. if self._context_handle is None:
  322. raise ValueError("Context handle is none in context!!!")
  323. return self._context_handle.get_costmodel_allreduce_fusion_tail_time()
  324. def set_costmodel_allreduce_fusion_allreduce_inherent_time(self, allreduce_inherent_time):
  325. """
  326. Set costmodel allreduce fusion allreduce inherent time.
  327. Args:
  328. allreduce_inherent_time (int): The inherent cost time of AllReduce.
  329. Raises:
  330. ValueError: If context handle is none.
  331. """
  332. if self._context_handle is None:
  333. raise ValueError("Context handle is none in context!!!")
  334. self._context_handle.set_costmodel_allreduce_fusion_allreduce_inherent_time(allreduce_inherent_time)
  335. def get_costmodel_allreduce_fusion_allreduce_inherent_time(self):
  336. """
  337. Get costmodel allreduce fusion allreduce inherent time.
  338. Raises:
  339. ValueError: If context handle is none.
  340. """
  341. if self._context_handle is None:
  342. raise ValueError("Context handle is none in context!!!")
  343. return self._context_handle.get_costmodel_allreduce_fusion_allreduce_inherent_time()
  344. def set_costmodel_allreduce_fusion_allreduce_bandwidth(self, allreduce_bandwidth):
  345. """
  346. Set costmodel allreduce fusion allreduce bandwidth.
  347. Args:
  348. allreduce_bandwidth (int): The bandwidth of AllReduce.
  349. Raises:
  350. ValueError: If context handle is none.
  351. """
  352. if self._context_handle is None:
  353. raise ValueError("Context handle is none in context!!!")
  354. self._context_handle.set_costmodel_allreduce_fusion_allreduce_bandwidth(allreduce_bandwidth)
  355. def get_costmodel_allreduce_fusion_allreduce_bandwidth(self):
  356. """
  357. Get costmodel allreduce fusion allreduce bandwidth.
  358. Raises:
  359. ValueError: If context handle is none.
  360. """
  361. if self._context_handle is None:
  362. raise ValueError("Context handle is none in context!!!")
  363. return self._context_handle.get_costmodel_allreduce_fusion_allreduce_bandwidth()
  364. def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
  365. """
  366. Set costmodel allreduce fusion computation time parameter.
  367. Args:
  368. computation_time_parameter (int): The parameter used to compute backward computation time.
  369. Raises:
  370. ValueError: If context handle is none.
  371. """
  372. if self._context_handle is None:
  373. raise ValueError("Context handle is none in context!!!")
  374. self._context_handle.set_costmodel_allreduce_fusion_computation_time_parameter(computation_time_parameter)
  375. def get_costmodel_allreduce_fusion_computation_time_parameter(self):
  376. """
  377. Get costmodel allreduce fusion computation time parameter.
  378. Raises:
  379. ValueError: If context handle is none.
  380. """
  381. if self._context_handle is None:
  382. raise ValueError("Context handle is none in context!!!")
  383. return self._context_handle.get_costmodel_allreduce_fusion_computation_time_parameter()
  384. def reset_cost_model(self):
  385. """
  386. Reset cost model settings.
  387. Raises:
  388. ValueError: If context handle is none.
  389. """
  390. if self._context_handle is None:
  391. raise ValueError("Context handle is none in context!!!")
  392. self._context_handle.reset_cost_model()
  393. _cost_model_context = None
  394. def cost_model_context():
  395. """
  396. Get the global _cost_model_context. If it is not created, create a new one.
  397. Returns:
  398. The global cost_model context.
  399. """
  400. global _cost_model_context
  401. if _cost_model_context is None:
  402. _cost_model_context = _CostModelContext()
  403. return _cost_model_context
  404. set_cost_model_context_func_map = {
  405. "device_memory_capacity": cost_model_context().set_device_memory_capacity,
  406. "costmodel_alpha": cost_model_context().set_costmodel_alpha,
  407. "costmodel_beta": cost_model_context().set_costmodel_beta,
  408. "costmodel_gamma": cost_model_context().set_costmodel_gamma,
  409. "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
  410. "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
  411. "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
  412. "run_phase": cost_model_context().set_run_phase,
  413. "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
  414. "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
  415. "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
  416. "costmodel_allreduce_fusion_tail_time": cost_model_context().set_costmodel_allreduce_fusion_tail_time,
  417. "costmodel_allreduce_fusion_allreduce_inherent_time":
  418. cost_model_context().set_costmodel_allreduce_fusion_allreduce_inherent_time,
  419. "costmodel_allreduce_fusion_allreduce_bandwidth":
  420. cost_model_context().set_costmodel_allreduce_fusion_allreduce_bandwidth,
  421. "costmodel_allreduce_fusion_computation_time_parameter":
  422. cost_model_context().set_costmodel_allreduce_fusion_computation_time_parameter}
  423. get_cost_model_context_func_map = {
  424. "device_memory_capacity": cost_model_context().get_device_memory_capacity,
  425. "costmodel_alpha": cost_model_context().get_costmodel_alpha,
  426. "costmodel_beta": cost_model_context().get_costmodel_beta,
  427. "costmodel_gamma": cost_model_context().get_costmodel_gamma,
  428. "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
  429. "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
  430. "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
  431. "run_phase": cost_model_context().get_run_phase,
  432. "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
  433. "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
  434. "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
  435. "costmodel_allreduce_fusion_tail_time": cost_model_context().get_costmodel_allreduce_fusion_tail_time,
  436. "costmodel_allreduce_fusion_allreduce_inherent_time":
  437. cost_model_context().get_costmodel_allreduce_fusion_allreduce_inherent_time,
  438. "costmodel_allreduce_fusion_allreduce_bandwidth":
  439. cost_model_context().get_costmodel_allreduce_fusion_allreduce_bandwidth,
  440. "costmodel_allreduce_fusion_computation_time_parameter":
  441. cost_model_context().get_costmodel_allreduce_fusion_computation_time_parameter}
  442. @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
  443. costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
  444. multi_subgraphs=bool, run_phase=int,
  445. costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
  446. costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
  447. costmodel_allreduce_fusion_allreduce_inherent_time=float,
  448. costmodel_allreduce_fusion_allreduce_bandwidth=float,
  449. costmodel_allreduce_fusion_computation_time_parameter=float)
  450. def set_cost_model_context(**kwargs):
  451. """
  452. Set cost model context.
  453. Note:
  454. Attribute name is needed.
  455. Args:
  456. device_memory_capacity (float): The memory capacity for each device.
  457. costmodel_alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
  458. costmodel_beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
  459. costmodel_gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
  460. costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
  461. costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
  462. costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
  463. run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
  464. costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
  465. 0: bypass allreduce fusion;
  466. 1: only use backward computation time to group allreduce;
  467. 2: use backward computation time and parameter gradient allreduce time to group allreduce.
  468. costmodel_allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
  469. costmodel_allreduce_fusion_tail_percent (float): A parameter used in allreduce fusion algorithm. The percentage
  470. of backward computing time corresponding to the last parameter gradients AllReduce in the whole backward
  471. computing time.
  472. costmodel_allreduce_fusion_tail_time (float): A parameter used in allreduce fusion algorithm. The tail time of
  473. the last parameter gradients AllReduce after the end of backward computation.
  474. costmodel_allreduce_fusion_allreduce_inherent_time (float): A parameter used in allreduce fusion algorithm. The
  475. inherent cost time of AllReduce.
  476. costmodel_allreduce_fusion_allreduce_bandwidth (float): A parameter used in allreduce fusion algorithm. The
  477. bandwidth of AllReduce.
  478. costmodel_allreduce_fusion_computation_time_parameter (float): A parameter used in allreduce fusion algorithm.
  479. The parameter used to compute backward computation time.
  480. Raises:
  481. ValueError: If context keyword is not recognized.
  482. """
  483. for key, value in kwargs.items():
  484. if key not in set_cost_model_context_func_map:
  485. raise ValueError("Set context keyword %s is not recognized!" % key)
  486. set_func = set_cost_model_context_func_map[key]
  487. set_func(value)
  488. def get_cost_model_context(attr_key):
  489. """
  490. Get cost model context attributes.
  491. Note:
  492. Return value according to the attribute value.
  493. Args:
  494. attr_key (str): The key of the attribute.
  495. Raises:
  496. ValueError: If context keyword is not recognized.
  497. """
  498. if attr_key not in get_cost_model_context_func_map:
  499. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  500. get_func = get_cost_model_context_func_map[attr_key]
  501. return get_func()
  502. def reset_cost_model_context():
  503. """Reset cost model context attributes."""
  504. cost_model_context().reset_cost_model()
  505. def _set_multi_subgraphs(multi_subgraph=True):
  506. """
  507. Set the flag of ANF graph containing multiple subgraphs.
  508. Args:
  509. multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
  510. """
  511. cost_model_context().set_multi_subgraphs(multi_subgraph)
  512. def _get_multi_subgraphs():
  513. """
  514. Get the flag of ANF graph containing multiple subgraphs.
  515. """
  516. return cost_model_context().get_multi_subgraphs()
  517. def _set_algo_single_loop(single_loop=True):
  518. """
  519. Set the flag of generating a single suite of OperatorInfos in for-loop.
  520. Args:
  521. single_loop (bool): The parameter for the single loop flag.
  522. """
  523. cost_model_context().set_dp_algo_single_loop(single_loop)
  524. def _get_algo_single_loop():
  525. """
  526. Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.
  527. """
  528. return cost_model_context().get_dp_algo_single_loop()