You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_cost_model_context.py 25 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context of cost_model in auto_parallel"""
  16. import threading
  17. from mindspore._c_expression import CostModelContext
  18. from mindspore._checkparam import args_type_check
  19. class _CostModelContext:
  20. """
  21. _CostModelContext is the environment in which operations are executed
  22. Note:
  23. Creating a context through instantiating Context object is not recommended.
  24. Use cost_model_context() to get the context since Context is singleton.
  25. """
  26. _instance = None
  27. _instance_lock = threading.Lock()
  28. def __init__(self):
  29. self._context_handle = CostModelContext.get_instance()
  30. def __new__(cls):
  31. if cls._instance is None:
  32. cls._instance_lock.acquire()
  33. cls._instance = object.__new__(cls)
  34. cls._instance_lock.release()
  35. return cls._instance
  36. def set_device_memory_capacity(self, dev_mem_cap):
  37. """
  38. Set device memory capacity.
  39. Args:
  40. dev_mem_cap (float): The memory capacity for each device.
  41. Raises:
  42. ValueError: If context handle is none.
  43. """
  44. if self._context_handle is None:
  45. raise ValueError("Context handle is none in context!!!")
  46. self._context_handle.set_device_memory_capacity(dev_mem_cap)
  47. def get_device_memory_capacity(self):
  48. """
  49. Get device memory capacity.
  50. Raises:
  51. ValueError: If context handle is none.
  52. """
  53. if self._context_handle is None:
  54. raise ValueError("Context handle is none in context!!!")
  55. return self._context_handle.get_device_memory_capacity()
  56. def set_costmodel_alpha(self, alpha):
  57. """
  58. Set costmodel alpha.
  59. Args:
  60. alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
  61. Raises:
  62. ValueError: If context handle is none.
  63. """
  64. if self._context_handle is None:
  65. raise ValueError("Context handle is none in context!!!")
  66. self._context_handle.set_costmodel_alpha(alpha)
  67. def get_costmodel_alpha(self):
  68. """
  69. Get costmodel alpha.
  70. Raises:
  71. ValueError: If context handle is none.
  72. """
  73. if self._context_handle is None:
  74. raise ValueError("Context handle is none in context!!!")
  75. return self._context_handle.get_costmodel_alpha()
  76. def set_costmodel_beta(self, beta):
  77. """
  78. Set costmodel beta.
  79. Args:
  80. beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
  81. Raises:
  82. ValueError: If context handle is none.
  83. """
  84. if self._context_handle is None:
  85. raise ValueError("Context handle is none in context!!!")
  86. self._context_handle.set_costmodel_beta(beta)
  87. def get_costmodel_beta(self):
  88. """
  89. Get costmodel beta.
  90. Raises:
  91. ValueError: If context handle is none.
  92. """
  93. if self._context_handle is None:
  94. raise ValueError("Context handle is none in context!!!")
  95. return self._context_handle.get_costmodel_beta()
  96. def set_costmodel_gamma(self, gamma):
  97. """
  98. Set costmodel gamma.
  99. Args:
  100. gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
  101. Raises:
  102. ValueError: If context handle is none.
  103. """
  104. if self._context_handle is None:
  105. raise ValueError("Context handle is none in context!!!")
  106. self._context_handle.set_costmodel_gamma(gamma)
  107. def get_costmodel_gamma(self):
  108. """
  109. Get costmodel gamma.
  110. Raises:
  111. ValueError: If context handle is none.
  112. """
  113. if self._context_handle is None:
  114. raise ValueError("Context handle is none in context!!!")
  115. return self._context_handle.get_costmodel_gamma()
  116. def set_costmodel_communi_threshold(self, threshold):
  117. """
  118. Set costmodel communication threshold.
  119. Args:
  120. threshold (float): A parameter used in adjusting communication calculation for practice.
  121. Raises:
  122. ValueError: If context handle is none.
  123. """
  124. if self._context_handle is None:
  125. raise ValueError("Context handle is none in context!!!")
  126. self._context_handle.set_costmodel_communi_threshold(threshold)
  127. def get_costmodel_communi_threshold(self):
  128. """
  129. Get costmodel communication threshold.
  130. Raises:
  131. ValueError: If context handle is none.
  132. """
  133. if self._context_handle is None:
  134. raise ValueError("Context handle is none in context!!!")
  135. return self._context_handle.get_costmodel_communi_threshold()
  136. def set_costmodel_communi_const(self, communi_const):
  137. """
  138. Set costmodel communication const.
  139. Args:
  140. const (float): A parameter used in adjusting communication calculation for practice.
  141. Raises:
  142. ValueError: If context handle is none.
  143. """
  144. if self._context_handle is None:
  145. raise ValueError("Context handle is none in context!!!")
  146. self._context_handle.set_costmodel_communi_const(communi_const)
  147. def get_costmodel_communi_const(self):
  148. """
  149. Get costmodel communication const.
  150. Raises:
  151. ValueError: If context handle is none.
  152. """
  153. if self._context_handle is None:
  154. raise ValueError("Context handle is none in context!!!")
  155. return self._context_handle.get_costmodel_communi_const()
  156. def set_costmodel_communi_bias(self, communi_bias):
  157. """
  158. Set costmodel communication bias.
  159. Args:
  160. bias (float): A parameter used in adjusting communication calculation for practice.
  161. Raises:
  162. ValueError: If context handle is none.
  163. """
  164. if self._context_handle is None:
  165. raise ValueError("Context handle is none in context!!!")
  166. self._context_handle.set_costmodel_communi_bias(communi_bias)
  167. def get_costmodel_communi_bias(self):
  168. """
  169. Get costmodel communication bias.
  170. Raises:
  171. ValueError: If context handle is none.
  172. """
  173. if self._context_handle is None:
  174. raise ValueError("Context handle is none in context!!!")
  175. return self._context_handle.get_costmodel_communi_bias()
  176. def set_multi_subgraphs(self, multi_subgraph):
  177. """
  178. Set the flag of ANF graph containing multiple subgraphs.
  179. Args:
  180. multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
  181. Raises:
  182. ValueError: If context handle is none.
  183. """
  184. if self._context_handle is None:
  185. raise ValueError("Context handle is none in context!!!")
  186. self._context_handle.set_multi_subgraphs(multi_subgraph)
  187. def get_multi_subgraphs(self):
  188. """
  189. Get the flag of ANF graph containing multiple subgraphs.
  190. Raises:
  191. ValueError: If context handle is none.
  192. """
  193. if self._context_handle is None:
  194. raise ValueError("Context handle is none in context!!!")
  195. return self._context_handle.get_multi_subgraphs()
  196. def set_run_phase(self, phase):
  197. """
  198. Set the flag of running phase: training (0) or inference (1)
  199. Args:
  200. phase (int): A parameter indicating which phase is running.
  201. Raises:
  202. ValueError: If context handle is none, or phase is not in {0, 1}.
  203. """
  204. if self._context_handle is None:
  205. raise ValueError("Context handle is none in context!!!")
  206. if phase not in (0, 1):
  207. raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
  208. self._context_handle.set_run_phase(phase)
  209. def get_run_phase(self):
  210. """
  211. Get the flag of running phase.
  212. Raises:
  213. ValueError: If context handle is none.
  214. """
  215. if self._context_handle is None:
  216. raise ValueError("Context handle is none in context!!!")
  217. return self._context_handle.get_run_phase()
  218. def set_dp_algo_single_loop(self, single_loop):
  219. """
  220. Set the flag of generating a single suite of OperatorInfos in for-loop.
  221. Args:
  222. single_loop (bool): The parameter for the single loop flag.
  223. Raises:
  224. ValueError: If context handle is none.
  225. """
  226. if self._context_handle is None:
  227. raise ValueError("Context handle is none in context!!!")
  228. self._context_handle.set_dp_algo_single_loop(single_loop)
  229. def get_dp_algo_single_loop(self):
  230. """
  231. Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.
  232. Raises:
  233. ValueError: If context handle is none.
  234. """
  235. if self._context_handle is None:
  236. raise ValueError("Context handle is none in context!!!")
  237. return self._context_handle.get_dp_algo_single_loop()
  238. def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
  239. """
  240. Set costmodel allreduce fusion algorithm.
  241. Args:
  242. algorithm (int): The AllReduce fusion algorithm of parameter gradients.
  243. Raises:
  244. ValueError: If context handle is none.
  245. """
  246. if self._context_handle is None:
  247. raise ValueError("Context handle is none in context!!!")
  248. self._context_handle.set_costmodel_allreduce_fusion_algorithm(algorithm)
  249. def get_costmodel_allreduce_fusion_algorithm(self):
  250. """
  251. Get costmodel allreduce fusion algorithm.
  252. Raises:
  253. ValueError: If context handle is none.
  254. """
  255. if self._context_handle is None:
  256. raise ValueError("Context handle is none in context!!!")
  257. return self._context_handle.get_costmodel_allreduce_fusion_algorithm()
  258. def set_costmodel_allreduce_fusion_times(self, allreduce_fusion_times):
  259. """
  260. Set costmodel allreduce fusion times.
  261. Args:
  262. allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
  263. Raises:
  264. ValueError: If context handle is none.
  265. """
  266. if self._context_handle is None:
  267. raise ValueError("Context handle is none in context!!!")
  268. self._context_handle.set_costmodel_allreduce_fusion_times(allreduce_fusion_times)
  269. def get_costmodel_allreduce_fusion_times(self):
  270. """
  271. Get costmodel allreduce fusion times.
  272. Raises:
  273. ValueError: If context handle is none.
  274. """
  275. if self._context_handle is None:
  276. raise ValueError("Context handle is none in context!!!")
  277. return self._context_handle.get_costmodel_allreduce_fusion_times()
  278. def set_costmodel_allreduce_fusion_tail_percent(self, tail_percent):
  279. """
  280. Set costmodel allreduce fusion tail percent.
  281. Args:
  282. tail_percent (int): The percentage of backward computing time corresponding to the last parameter gradients
  283. AllReduce in the whole backward computing time.
  284. Raises:
  285. ValueError: If context handle is none.
  286. """
  287. if self._context_handle is None:
  288. raise ValueError("Context handle is none in context!!!")
  289. self._context_handle.set_costmodel_allreduce_fusion_tail_percent(tail_percent)
  290. def get_costmodel_allreduce_fusion_tail_percent(self):
  291. """
  292. Get costmodel allreduce fusion tail percent.
  293. Raises:
  294. ValueError: If context handle is none.
  295. """
  296. if self._context_handle is None:
  297. raise ValueError("Context handle is none in context!!!")
  298. return self._context_handle.get_costmodel_allreduce_fusion_tail_percent()
  299. def set_costmodel_allreduce_fusion_tail_time(self, tail_time):
  300. """
  301. Set costmodel allreduce fusion tail time.
  302. Args:
  303. tail_time (int): The tail time of the last parameter gradients AllReduce after the end of backward
  304. computation.
  305. Raises:
  306. ValueError: If context handle is none.
  307. """
  308. if self._context_handle is None:
  309. raise ValueError("Context handle is none in context!!!")
  310. self._context_handle.set_costmodel_allreduce_fusion_tail_time(tail_time)
  311. def get_costmodel_allreduce_fusion_tail_time(self):
  312. """
  313. Get costmodel allreduce fusion tail time.
  314. Raises:
  315. ValueError: If context handle is none.
  316. """
  317. if self._context_handle is None:
  318. raise ValueError("Context handle is none in context!!!")
  319. return self._context_handle.get_costmodel_allreduce_fusion_tail_time()
  320. def set_costmodel_allreduce_fusion_allreduce_inherent_time(self, allreduce_inherent_time):
  321. """
  322. Set costmodel allreduce fusion allreduce inherent time.
  323. Args:
  324. allreduce_inherent_time (int): The inherent cost time of AllReduce.
  325. Raises:
  326. ValueError: If context handle is none.
  327. """
  328. if self._context_handle is None:
  329. raise ValueError("Context handle is none in context!!!")
  330. self._context_handle.set_costmodel_allreduce_fusion_allreduce_inherent_time(allreduce_inherent_time)
  331. def get_costmodel_allreduce_fusion_allreduce_inherent_time(self):
  332. """
  333. Get costmodel allreduce fusion allreduce inherent time.
  334. Raises:
  335. ValueError: If context handle is none.
  336. """
  337. if self._context_handle is None:
  338. raise ValueError("Context handle is none in context!!!")
  339. return self._context_handle.get_costmodel_allreduce_fusion_allreduce_inherent_time()
  340. def set_costmodel_allreduce_fusion_allreduce_bandwidth(self, allreduce_bandwidth):
  341. """
  342. Set costmodel allreduce fusion allreduce bandwidth.
  343. Args:
  344. allreduce_bandwidth (int): The bandwidth of AllReduce.
  345. Raises:
  346. ValueError: If context handle is none.
  347. """
  348. if self._context_handle is None:
  349. raise ValueError("Context handle is none in context!!!")
  350. self._context_handle.set_costmodel_allreduce_fusion_allreduce_bandwidth(allreduce_bandwidth)
  351. def get_costmodel_allreduce_fusion_allreduce_bandwidth(self):
  352. """
  353. Get costmodel allreduce fusion allreduce bandwidth.
  354. Raises:
  355. ValueError: If context handle is none.
  356. """
  357. if self._context_handle is None:
  358. raise ValueError("Context handle is none in context!!!")
  359. return self._context_handle.get_costmodel_allreduce_fusion_allreduce_bandwidth()
  360. def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
  361. """
  362. Set costmodel allreduce fusion computation time parameter.
  363. Args:
  364. computation_time_parameter (int): The parameter used to compute backward computation time.
  365. Raises:
  366. ValueError: If context handle is none.
  367. """
  368. if self._context_handle is None:
  369. raise ValueError("Context handle is none in context!!!")
  370. self._context_handle.set_costmodel_allreduce_fusion_computation_time_parameter(computation_time_parameter)
  371. def get_costmodel_allreduce_fusion_computation_time_parameter(self):
  372. """
  373. Get costmodel allreduce fusion computation time parameter.
  374. Raises:
  375. ValueError: If context handle is none.
  376. """
  377. if self._context_handle is None:
  378. raise ValueError("Context handle is none in context!!!")
  379. return self._context_handle.get_costmodel_allreduce_fusion_computation_time_parameter()
  380. def reset_cost_model(self):
  381. """
  382. Reset cost model settings.
  383. Raises:
  384. ValueError: If context handle is none.
  385. """
  386. if self._context_handle is None:
  387. raise ValueError("Context handle is none in context!!!")
  388. self._context_handle.reset_cost_model()
  389. _cost_model_context = None
  390. def cost_model_context():
  391. """
  392. Get the global _cost_model_context. If it is not created, create a new one.
  393. Returns:
  394. The global cost_model context.
  395. """
  396. global _cost_model_context
  397. if _cost_model_context is None:
  398. _cost_model_context = _CostModelContext()
  399. return _cost_model_context
  400. set_cost_model_context_func_map = {
  401. "device_memory_capacity": cost_model_context().set_device_memory_capacity,
  402. "costmodel_alpha": cost_model_context().set_costmodel_alpha,
  403. "costmodel_beta": cost_model_context().set_costmodel_beta,
  404. "costmodel_gamma": cost_model_context().set_costmodel_gamma,
  405. "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
  406. "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
  407. "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
  408. "run_phase": cost_model_context().set_run_phase,
  409. "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
  410. "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
  411. "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
  412. "costmodel_allreduce_fusion_tail_time": cost_model_context().set_costmodel_allreduce_fusion_tail_time,
  413. "costmodel_allreduce_fusion_allreduce_inherent_time":
  414. cost_model_context().set_costmodel_allreduce_fusion_allreduce_inherent_time,
  415. "costmodel_allreduce_fusion_allreduce_bandwidth":
  416. cost_model_context().set_costmodel_allreduce_fusion_allreduce_bandwidth,
  417. "costmodel_allreduce_fusion_computation_time_parameter":
  418. cost_model_context().set_costmodel_allreduce_fusion_computation_time_parameter}
  419. get_cost_model_context_func_map = {
  420. "device_memory_capacity": cost_model_context().get_device_memory_capacity,
  421. "costmodel_alpha": cost_model_context().get_costmodel_alpha,
  422. "costmodel_beta": cost_model_context().get_costmodel_beta,
  423. "costmodel_gamma": cost_model_context().get_costmodel_gamma,
  424. "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
  425. "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
  426. "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
  427. "run_phase": cost_model_context().get_run_phase,
  428. "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
  429. "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
  430. "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
  431. "costmodel_allreduce_fusion_tail_time": cost_model_context().get_costmodel_allreduce_fusion_tail_time,
  432. "costmodel_allreduce_fusion_allreduce_inherent_time":
  433. cost_model_context().get_costmodel_allreduce_fusion_allreduce_inherent_time,
  434. "costmodel_allreduce_fusion_allreduce_bandwidth":
  435. cost_model_context().get_costmodel_allreduce_fusion_allreduce_bandwidth,
  436. "costmodel_allreduce_fusion_computation_time_parameter":
  437. cost_model_context().get_costmodel_allreduce_fusion_computation_time_parameter}
  438. @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
  439. costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
  440. multi_subgraphs=bool, run_phase=int,
  441. costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
  442. costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
  443. costmodel_allreduce_fusion_allreduce_inherent_time=float,
  444. costmodel_allreduce_fusion_allreduce_bandwidth=float,
  445. costmodel_allreduce_fusion_computation_time_parameter=float)
  446. def set_cost_model_context(**kwargs):
  447. """
  448. Set cost model context.
  449. Note:
  450. Attribute name is needed.
  451. Args:
  452. device_memory_capacity (float): The memory capacity for each device.
  453. costmodel_alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
  454. costmodel_beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
  455. costmodel_gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
  456. costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
  457. costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
  458. costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
  459. run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
  460. costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
  461. 0: bypass allreduce fusion;
  462. 1: only use backward computation time to group allreduce;
  463. 2: use backward computation time and parameter gradient allreduce time to group allreduce.
  464. costmodel_allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
  465. costmodel_allreduce_fusion_tail_percent (float): A parameter used in allreduce fusion algorithm. The percentage
  466. of backward computing time corresponding to the last parameter gradients AllReduce in the whole backward
  467. computing time.
  468. costmodel_allreduce_fusion_tail_time (float): A parameter used in allreduce fusion algorithm. The tail time of
  469. the last parameter gradients AllReduce after the end of backward computation.
  470. costmodel_allreduce_fusion_allreduce_inherent_time (float): A parameter used in allreduce fusion algorithm. The
  471. inherent cost time of AllReduce.
  472. costmodel_allreduce_fusion_allreduce_bandwidth (float): A parameter used in allreduce fusion algorithm. The
  473. bandwidth of AllReduce.
  474. costmodel_allreduce_fusion_computation_time_parameter (float): A parameter used in allreduce fusion algorithm.
  475. The parameter used to compute backward computation time.
  476. Raises:
  477. ValueError: If context keyword is not recognized.
  478. """
  479. for key, value in kwargs.items():
  480. if key not in set_cost_model_context_func_map:
  481. raise ValueError("Set context keyword %s is not recognized!" % key)
  482. set_func = set_cost_model_context_func_map[key]
  483. set_func(value)
  484. def get_cost_model_context(attr_key):
  485. """
  486. Get cost model context attributes.
  487. Note:
  488. Return value according to the attribute value.
  489. Args:
  490. attr_key (str): The key of the attribute.
  491. Raises:
  492. ValueError: If context keyword is not recognized.
  493. """
  494. if attr_key not in get_cost_model_context_func_map:
  495. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  496. get_func = get_cost_model_context_func_map[attr_key]
  497. return get_func()
  498. def reset_cost_model_context():
  499. """Reset cost model context attributes."""
  500. cost_model_context().reset_cost_model()
  501. def _set_multi_subgraphs(multi_subgraph=True):
  502. """
  503. Set the flag of ANF graph containing multiple subgraphs.
  504. Args:
  505. multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
  506. """
  507. cost_model_context().set_multi_subgraphs(multi_subgraph)
  508. def _get_multi_subgraphs():
  509. """
  510. Get the flag of ANF graph containing multiple subgraphs.
  511. """
  512. return cost_model_context().get_multi_subgraphs()
  513. def _set_algo_single_loop(single_loop=True):
  514. """
  515. Set the flag of generating a single suite of OperatorInfos in for-loop.
  516. Args:
  517. single_loop (bool): The parameter for the single loop flag.
  518. """
  519. cost_model_context().set_dp_algo_single_loop(single_loop)
  520. def _get_algo_single_loop():
  521. """
  522. Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.
  523. """
  524. return cost_model_context().get_dp_algo_single_loop()