You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_cost_model_context.py 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context of cost_model in auto_parallel"""
  16. import threading
  17. from mindspore._c_expression import CostModelContext
  18. from mindspore._checkparam import args_type_check
  19. class _CostModelContext:
  20. """
  21. _CostModelContext is the environment in which operations are executed
  22. Note:
  23. Creating a context through instantiating Context object is not recommended.
  24. Use cost_model_context() to get the context since Context is singleton.
  25. """
  26. _instance = None
  27. _instance_lock = threading.Lock()
  28. def __init__(self):
  29. self._context_handle = CostModelContext.get_instance()
  30. def __new__(cls):
  31. if cls._instance is None:
  32. cls._instance_lock.acquire()
  33. cls._instance = object.__new__(cls)
  34. cls._instance_lock.release()
  35. return cls._instance
  36. def set_device_memory_capacity(self, dev_mem_cap):
  37. """
  38. Set device memory capacity.
  39. Args:
  40. dev_mem_cap (float): The memory capacity for each device.
  41. Raises:
  42. ValueError: If context handle is none.
  43. """
  44. if self._context_handle is None:
  45. raise ValueError("Context handle is none in context!!!")
  46. self._context_handle.set_device_memory_capacity(dev_mem_cap)
  47. def get_device_memory_capacity(self):
  48. """
  49. Get device memory capacity.
  50. Raises:
  51. ValueError: If context handle is none.
  52. """
  53. if self._context_handle is None:
  54. raise ValueError("Context handle is none in context!!!")
  55. return self._context_handle.get_device_memory_capacity()
  56. def set_costmodel_alpha(self, alpha):
  57. """
  58. Set costmodel alpha.
  59. Args:
  60. alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
  61. Raises:
  62. ValueError: If context handle is none.
  63. """
  64. if self._context_handle is None:
  65. raise ValueError("Context handle is none in context!!!")
  66. self._context_handle.set_costmodel_alpha(alpha)
  67. def get_costmodel_alpha(self):
  68. """
  69. Get costmodel alpha.
  70. Raises:
  71. ValueError: If context handle is none.
  72. """
  73. if self._context_handle is None:
  74. raise ValueError("Context handle is none in context!!!")
  75. return self._context_handle.get_costmodel_alpha()
  76. def set_costmodel_beta(self, beta):
  77. """
  78. Set costmodel beta.
  79. Args:
  80. beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
  81. Raises:
  82. ValueError: If context handle is none.
  83. """
  84. if self._context_handle is None:
  85. raise ValueError("Context handle is none in context!!!")
  86. self._context_handle.set_costmodel_beta(beta)
  87. def get_costmodel_beta(self):
  88. """
  89. Get costmodel beta.
  90. Raises:
  91. ValueError: If context handle is none.
  92. """
  93. if self._context_handle is None:
  94. raise ValueError("Context handle is none in context!!!")
  95. return self._context_handle.get_costmodel_beta()
  96. def set_costmodel_gamma(self, gamma):
  97. """
  98. Set costmodel gamma.
  99. Args:
  100. gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
  101. Raises:
  102. ValueError: If context handle is none.
  103. """
  104. if self._context_handle is None:
  105. raise ValueError("Context handle is none in context!!!")
  106. self._context_handle.set_costmodel_gamma(gamma)
  107. def get_costmodel_gamma(self):
  108. """
  109. Get costmodel gamma.
  110. Raises:
  111. ValueError: If context handle is none.
  112. """
  113. if self._context_handle is None:
  114. raise ValueError("Context handle is none in context!!!")
  115. return self._context_handle.get_costmodel_gamma()
  116. def set_costmodel_communi_threshold(self, threshold):
  117. """
  118. Set costmodel communication threshold.
  119. Args:
  120. threshold (float): A parameter used in adjusting communication calculation for practice.
  121. Raises:
  122. ValueError: If context handle is none.
  123. """
  124. if self._context_handle is None:
  125. raise ValueError("Context handle is none in context!!!")
  126. self._context_handle.set_costmodel_communi_threshold(threshold)
  127. def get_costmodel_communi_threshold(self):
  128. """
  129. Get costmodel communication threshold.
  130. Raises:
  131. ValueError: If context handle is none.
  132. """
  133. if self._context_handle is None:
  134. raise ValueError("Context handle is none in context!!!")
  135. return self._context_handle.get_costmodel_communi_threshold()
  136. def set_costmodel_communi_const(self, communi_const):
  137. """
  138. Set costmodel communication const.
  139. Args:
  140. const (float): A parameter used in adjusting communication calculation for practice.
  141. Raises:
  142. ValueError: If context handle is none.
  143. """
  144. if self._context_handle is None:
  145. raise ValueError("Context handle is none in context!!!")
  146. self._context_handle.set_costmodel_communi_const(communi_const)
  147. def get_costmodel_communi_const(self):
  148. """
  149. Get costmodel communication const.
  150. Raises:
  151. ValueError: If context handle is none.
  152. """
  153. if self._context_handle is None:
  154. raise ValueError("Context handle is none in context!!!")
  155. return self._context_handle.get_costmodel_communi_const()
  156. def set_costmodel_communi_bias(self, communi_bias):
  157. """
  158. Set costmodel communication bias.
  159. Args:
  160. bias (float): A parameter used in adjusting communication calculation for practice.
  161. Raises:
  162. ValueError: If context handle is none.
  163. """
  164. if self._context_handle is None:
  165. raise ValueError("Context handle is none in context!!!")
  166. self._context_handle.set_costmodel_communi_bias(communi_bias)
  167. def get_costmodel_communi_bias(self):
  168. """
  169. Get costmodel communication bias.
  170. Raises:
  171. ValueError: If context handle is none.
  172. """
  173. if self._context_handle is None:
  174. raise ValueError("Context handle is none in context!!!")
  175. return self._context_handle.get_costmodel_communi_bias()
  176. def set_multi_subgraphs(self, multi_subgraph):
  177. """
  178. Set the flag of ANF graph containing multiple subgraphs.
  179. Args:
  180. multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
  181. Raises:
  182. ValueError: If context handle is none.
  183. """
  184. if self._context_handle is None:
  185. raise ValueError("Context handle is none in context!!!")
  186. self._context_handle.set_multi_subgraphs(multi_subgraph)
  187. def get_multi_subgraphs(self):
  188. """
  189. Get the flag of ANF graph containing multiple subgraphs.
  190. Raises:
  191. ValueError: If context handle is none.
  192. """
  193. if self._context_handle is None:
  194. raise ValueError("Context handle is none in context!!!")
  195. return self._context_handle.get_multi_subgraphs()
  196. def set_run_phase(self, phase):
  197. """
  198. Set the flag of running phase: training (0) or inference (1)
  199. Args:
  200. phase (int): A parameter indicating which phase is running.
  201. Raises:
  202. ValueError: If context handle is none, or phase is not in {0, 1}.
  203. """
  204. if self._context_handle is None:
  205. raise ValueError("Context handle is none in context!!!")
  206. if phase not in (0, 1):
  207. raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
  208. self._context_handle.set_run_phase(phase)
  209. def get_run_phase(self):
  210. """
  211. Get the flag of running phase.
  212. Raises:
  213. ValueError: If context handle is none.
  214. """
  215. if self._context_handle is None:
  216. raise ValueError("Context handle is none in context!!!")
  217. return self._context_handle.get_run_phase()
  218. def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
  219. """
  220. Set costmodel allreduce fusion algorithm.
  221. Args:
  222. algorithm (int): The AllReduce fusion algorithm of parameter gradients.
  223. Raises:
  224. ValueError: If context handle is none.
  225. """
  226. if self._context_handle is None:
  227. raise ValueError("Context handle is none in context!!!")
  228. self._context_handle.set_costmodel_allreduce_fusion_algorithm(algorithm)
  229. def get_costmodel_allreduce_fusion_algorithm(self):
  230. """
  231. Get costmodel allreduce fusion algorithm.
  232. Raises:
  233. ValueError: If context handle is none.
  234. """
  235. if self._context_handle is None:
  236. raise ValueError("Context handle is none in context!!!")
  237. return self._context_handle.get_costmodel_allreduce_fusion_algorithm()
  238. def set_costmodel_allreduce_fusion_times(self, allreduce_fusion_times):
  239. """
  240. Set costmodel allreduce fusion times.
  241. Args:
  242. allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
  243. Raises:
  244. ValueError: If context handle is none.
  245. """
  246. if self._context_handle is None:
  247. raise ValueError("Context handle is none in context!!!")
  248. self._context_handle.set_costmodel_allreduce_fusion_times(allreduce_fusion_times)
  249. def get_costmodel_allreduce_fusion_times(self):
  250. """
  251. Get costmodel allreduce fusion times.
  252. Raises:
  253. ValueError: If context handle is none.
  254. """
  255. if self._context_handle is None:
  256. raise ValueError("Context handle is none in context!!!")
  257. return self._context_handle.get_costmodel_allreduce_fusion_times()
  258. def set_costmodel_allreduce_fusion_tail_percent(self, tail_percent):
  259. """
  260. Set costmodel allreduce fusion tail percent.
  261. Args:
  262. tail_percent (int): The percentage of backward computing time corresponding to the last parameter gradients
  263. AllReduce in the whole backward computing time.
  264. Raises:
  265. ValueError: If context handle is none.
  266. """
  267. if self._context_handle is None:
  268. raise ValueError("Context handle is none in context!!!")
  269. self._context_handle.set_costmodel_allreduce_fusion_tail_percent(tail_percent)
  270. def get_costmodel_allreduce_fusion_tail_percent(self):
  271. """
  272. Get costmodel allreduce fusion tail percent.
  273. Raises:
  274. ValueError: If context handle is none.
  275. """
  276. if self._context_handle is None:
  277. raise ValueError("Context handle is none in context!!!")
  278. return self._context_handle.get_costmodel_allreduce_fusion_tail_percent()
  279. def set_costmodel_allreduce_fusion_tail_time(self, tail_time):
  280. """
  281. Set costmodel allreduce fusion tail time.
  282. Args:
  283. tail_time (int): The tail time of the last parameter gradients AllReduce after the end of backward
  284. computation.
  285. Raises:
  286. ValueError: If context handle is none.
  287. """
  288. if self._context_handle is None:
  289. raise ValueError("Context handle is none in context!!!")
  290. self._context_handle.set_costmodel_allreduce_fusion_tail_time(tail_time)
  291. def get_costmodel_allreduce_fusion_tail_time(self):
  292. """
  293. Get costmodel allreduce fusion tail time.
  294. Raises:
  295. ValueError: If context handle is none.
  296. """
  297. if self._context_handle is None:
  298. raise ValueError("Context handle is none in context!!!")
  299. return self._context_handle.get_costmodel_allreduce_fusion_tail_time()
  300. def set_costmodel_allreduce_fusion_allreduce_inherent_time(self, allreduce_inherent_time):
  301. """
  302. Set costmodel allreduce fusion allreduce inherent time.
  303. Args:
  304. allreduce_inherent_time (int): The inherent cost time of AllReduce.
  305. Raises:
  306. ValueError: If context handle is none.
  307. """
  308. if self._context_handle is None:
  309. raise ValueError("Context handle is none in context!!!")
  310. self._context_handle.set_costmodel_allreduce_fusion_allreduce_inherent_time(allreduce_inherent_time)
  311. def get_costmodel_allreduce_fusion_allreduce_inherent_time(self):
  312. """
  313. Get costmodel allreduce fusion allreduce inherent time.
  314. Raises:
  315. ValueError: If context handle is none.
  316. """
  317. if self._context_handle is None:
  318. raise ValueError("Context handle is none in context!!!")
  319. return self._context_handle.get_costmodel_allreduce_fusion_allreduce_inherent_time()
  320. def set_costmodel_allreduce_fusion_allreduce_bandwidth(self, allreduce_bandwidth):
  321. """
  322. Set costmodel allreduce fusion allreduce bandwidth.
  323. Args:
  324. allreduce_bandwidth (int): The bancwidth of AllReduce.
  325. Raises:
  326. ValueError: If context handle is none.
  327. """
  328. if self._context_handle is None:
  329. raise ValueError("Context handle is none in context!!!")
  330. self._context_handle.set_costmodel_allreduce_fusion_allreduce_bandwidth(allreduce_bandwidth)
  331. def get_costmodel_allreduce_fusion_allreduce_bandwidth(self):
  332. """
  333. Get costmodel allreduce fusion allreduce bandwidth.
  334. Raises:
  335. ValueError: If context handle is none.
  336. """
  337. if self._context_handle is None:
  338. raise ValueError("Context handle is none in context!!!")
  339. return self._context_handle.get_costmodel_allreduce_fusion_allreduce_bandwidth()
  340. def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
  341. """
  342. Set costmodel allreduce fusion computation time parameter.
  343. Args:
  344. computation_time_parameter (int): The parameter used to compute backward computation time.
  345. Raises:
  346. ValueError: If context handle is none.
  347. """
  348. if self._context_handle is None:
  349. raise ValueError("Context handle is none in context!!!")
  350. self._context_handle.set_costmodel_allreduce_fusion_computation_time_parameter(computation_time_parameter)
  351. def get_costmodel_allreduce_fusion_computation_time_parameter(self):
  352. """
  353. Get costmodel allreduce fusion computation time parameter.
  354. Raises:
  355. ValueError: If context handle is none.
  356. """
  357. if self._context_handle is None:
  358. raise ValueError("Context handle is none in context!!!")
  359. return self._context_handle.get_costmodel_allreduce_fusion_computation_time_parameter()
  360. def reset_cost_model(self):
  361. """
  362. Reset cost model settings.
  363. Raises:
  364. ValueError: If context handle is none.
  365. """
  366. if self._context_handle is None:
  367. raise ValueError("Context handle is none in context!!!")
  368. self._context_handle.reset_cost_model()
  369. _cost_model_context = None
  370. def cost_model_context():
  371. """
  372. Get the global _cost_model_context. If it is not created, create a new one.
  373. Returns:
  374. The global cost_model context.
  375. """
  376. global _cost_model_context
  377. if _cost_model_context is None:
  378. _cost_model_context = _CostModelContext()
  379. return _cost_model_context
  380. set_cost_model_context_func_map = {
  381. "device_memory_capacity": cost_model_context().set_device_memory_capacity,
  382. "costmodel_alpha": cost_model_context().set_costmodel_alpha,
  383. "costmodel_beta": cost_model_context().set_costmodel_beta,
  384. "costmodel_gamma": cost_model_context().set_costmodel_gamma,
  385. "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
  386. "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
  387. "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
  388. "multi_subgraphs": cost_model_context().set_multi_subgraphs,
  389. "run_phase": cost_model_context().set_run_phase,
  390. "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
  391. "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
  392. "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
  393. "costmodel_allreduce_fusion_tail_time": cost_model_context().set_costmodel_allreduce_fusion_tail_time,
  394. "costmodel_allreduce_fusion_allreduce_inherent_time":
  395. cost_model_context().set_costmodel_allreduce_fusion_allreduce_inherent_time,
  396. "costmodel_allreduce_fusion_allreduce_bandwidth":
  397. cost_model_context().set_costmodel_allreduce_fusion_allreduce_bandwidth,
  398. "costmodel_allreduce_fusion_computation_time_parameter":
  399. cost_model_context().set_costmodel_allreduce_fusion_computation_time_parameter}
  400. get_cost_model_context_func_map = {
  401. "device_memory_capacity": cost_model_context().get_device_memory_capacity,
  402. "costmodel_alpha": cost_model_context().get_costmodel_alpha,
  403. "costmodel_beta": cost_model_context().get_costmodel_beta,
  404. "costmodel_gamma": cost_model_context().get_costmodel_gamma,
  405. "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
  406. "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
  407. "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
  408. "multi_subgraphs": cost_model_context().get_multi_subgraphs,
  409. "run_phase": cost_model_context().get_run_phase,
  410. "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
  411. "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
  412. "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
  413. "costmodel_allreduce_fusion_tail_time": cost_model_context().get_costmodel_allreduce_fusion_tail_time,
  414. "costmodel_allreduce_fusion_allreduce_inherent_time":
  415. cost_model_context().get_costmodel_allreduce_fusion_allreduce_inherent_time,
  416. "costmodel_allreduce_fusion_allreduce_bandwidth":
  417. cost_model_context().get_costmodel_allreduce_fusion_allreduce_bandwidth,
  418. "costmodel_allreduce_fusion_computation_time_parameter":
  419. cost_model_context().get_costmodel_allreduce_fusion_computation_time_parameter}
  420. @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
  421. costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
  422. multi_subgraphs=bool, run_phase=int,
  423. costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
  424. costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
  425. costmodel_allreduce_fusion_allreduce_inherent_time=float,
  426. costmodel_allreduce_fusion_allreduce_bandwidth=float,
  427. costmodel_allreduce_fusion_computation_time_parameter=float)
  428. def set_cost_model_context(**kwargs):
  429. """
  430. Set cost model context.
  431. Note:
  432. Attribute name is needed.
  433. Args:
  434. device_memory_capacity (float): The memory capacity for each device.
  435. costmodel_alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
  436. costmodel_beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
  437. costmodel_gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
  438. costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
  439. costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
  440. costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
  441. multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
  442. run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
  443. costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
  444. 0: bypass allreduce fusion;
  445. 1: only use backward computation time to group allreduce;
  446. 2: use backward computation time and parameter gradient allreduce time to group allreduce.
  447. costmodel_allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
  448. costmodel_allreduce_fusion_tail_percent (float): A parameter used in allreduce fusion algorithm. The percentage
  449. of backward computing time corresponding to the last parameter gradients AllReduce in the whole backward
  450. computing time.
  451. costmodel_allreduce_fusion_tail_time (float): A parameter used in allreduce fusion algorithm. The tail time of
  452. the last parameter gradients AllReduce after the end of backward computation.
  453. costmodel_allreduce_fusion_allreduce_inherent_time (float): A parameter used in allreduce fusion algorithm. The
  454. inherent cost time of AllReduce.
  455. costmodel_allreduce_fusion_allreduce_bandwidth (float): A parameter used in allreduce fusion algorithm. The
  456. bandwidth of AllReduce.
  457. costmodel_allreduce_fusion_computation_time_parameter (float): A parameter used in allreduce fusion algorithm.
  458. The parameter used to compute backward computation time.
  459. Raises:
  460. ValueError: If context keyword is not recognized.
  461. """
  462. for key, value in kwargs.items():
  463. if key not in set_cost_model_context_func_map:
  464. raise ValueError("Set context keyword %s is not recognized!" % key)
  465. set_func = set_cost_model_context_func_map[key]
  466. set_func(value)
  467. def get_cost_model_context(attr_key):
  468. """
  469. Get cost model context attributes.
  470. Note:
  471. Return value according to the attribute value.
  472. Args:
  473. attr_key (str): The key of the attribute.
  474. Raises:
  475. ValueError: If context keyword is not recognized.
  476. """
  477. if attr_key not in get_cost_model_context_func_map:
  478. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  479. get_func = get_cost_model_context_func_map[attr_key]
  480. return get_func()
  481. def reset_cost_model_context():
  482. """Reset cost model context attributes."""
  483. cost_model_context().reset_cost_model()