You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_cost_model_context.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context of cost_model in auto_parallel"""
  16. import threading
  17. from mindspore._c_expression import CostModelContext
  18. from mindspore._checkparam import args_type_check
  19. class _CostModelContext:
  20. """
  21. _CostModelContext is the environment in which operations are executed
  22. Note:
  23. Creating a context through instantiating Context object is not recommended.
  24. Use cost_model_context() to get the context since Context is singleton.
  25. """
  26. _instance = None
  27. _instance_lock = threading.Lock()
  28. def __init__(self):
  29. self._context_handle = CostModelContext.get_instance()
  30. def __new__(cls):
  31. if cls._instance is None:
  32. cls._instance_lock.acquire()
  33. cls._instance = object.__new__(cls)
  34. cls._instance_lock.release()
  35. return cls._instance
  36. def set_device_memory_capacity(self, dev_mem_cap):
  37. """
  38. Set device memory capacity.
  39. Args:
  40. dev_mem_cap (float): The memory capacity for each device.
  41. Raises:
  42. ValueError: If context handle is none.
  43. """
  44. if self._context_handle is None:
  45. raise ValueError("Context handle is none in context!!!")
  46. self._context_handle.set_device_memory_capacity(dev_mem_cap)
  47. def get_device_memory_capacity(self):
  48. """
  49. Get device memory capacity.
  50. Raises:
  51. ValueError: If context handle is none.
  52. """
  53. if self._context_handle is None:
  54. raise ValueError("Context handle is none in context!!!")
  55. return self._context_handle.get_device_memory_capacity()
  56. def set_costmodel_alpha(self, alpha):
  57. """
  58. Set costmodel alpha.
  59. Args:
  60. alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
  61. Raises:
  62. ValueError: If context handle is none.
  63. """
  64. if self._context_handle is None:
  65. raise ValueError("Context handle is none in context!!!")
  66. self._context_handle.set_costmodel_alpha(alpha)
  67. def get_costmodel_alpha(self):
  68. """
  69. Get costmodel alpha.
  70. Raises:
  71. ValueError: If context handle is none.
  72. """
  73. if self._context_handle is None:
  74. raise ValueError("Context handle is none in context!!!")
  75. return self._context_handle.get_costmodel_alpha()
  76. def set_costmodel_beta(self, beta):
  77. """
  78. Set costmodel beta.
  79. Args:
  80. beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
  81. Raises:
  82. ValueError: If context handle is none.
  83. """
  84. if self._context_handle is None:
  85. raise ValueError("Context handle is none in context!!!")
  86. self._context_handle.set_costmodel_beta(beta)
  87. def get_costmodel_beta(self):
  88. """
  89. Get costmodel beta.
  90. Raises:
  91. ValueError: If context handle is none.
  92. """
  93. if self._context_handle is None:
  94. raise ValueError("Context handle is none in context!!!")
  95. return self._context_handle.get_costmodel_beta()
  96. def set_costmodel_gamma(self, gamma):
  97. """
  98. Set costmodel gamma.
  99. Args:
  100. gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
  101. Raises:
  102. ValueError: If context handle is none.
  103. """
  104. if self._context_handle is None:
  105. raise ValueError("Context handle is none in context!!!")
  106. self._context_handle.set_costmodel_gamma(gamma)
  107. def get_costmodel_gamma(self):
  108. """
  109. Get costmodel gamma.
  110. Raises:
  111. ValueError: If context handle is none.
  112. """
  113. if self._context_handle is None:
  114. raise ValueError("Context handle is none in context!!!")
  115. return self._context_handle.get_costmodel_gamma()
  116. def set_costmodel_communi_threshold(self, threshold):
  117. """
  118. Set costmodel communication threshold.
  119. Args:
  120. threshold (float): A parameter used in adjusting communication calculation for practice.
  121. Raises:
  122. ValueError: If context handle is none.
  123. """
  124. if self._context_handle is None:
  125. raise ValueError("Context handle is none in context!!!")
  126. self._context_handle.set_costmodel_communi_threshold(threshold)
  127. def get_costmodel_communi_threshold(self):
  128. """
  129. Get costmodel communication threshold.
  130. Raises:
  131. ValueError: If context handle is none.
  132. """
  133. if self._context_handle is None:
  134. raise ValueError("Context handle is none in context!!!")
  135. return self._context_handle.get_costmodel_communi_threshold()
  136. def set_costmodel_communi_const(self, communi_const):
  137. """
  138. Set costmodel communication const.
  139. Args:
  140. const (float): A parameter used in adjusting communication calculation for practice.
  141. Raises:
  142. ValueError: If context handle is none.
  143. """
  144. if self._context_handle is None:
  145. raise ValueError("Context handle is none in context!!!")
  146. self._context_handle.set_costmodel_communi_const(communi_const)
  147. def get_costmodel_communi_const(self):
  148. """
  149. Get costmodel communication const.
  150. Raises:
  151. ValueError: If context handle is none.
  152. """
  153. if self._context_handle is None:
  154. raise ValueError("Context handle is none in context!!!")
  155. return self._context_handle.get_costmodel_communi_const()
  156. def set_costmodel_communi_bias(self, communi_bias):
  157. """
  158. Set costmodel communication bias.
  159. Args:
  160. bias (float): A parameter used in adjusting communication calculation for practice.
  161. Raises:
  162. ValueError: If context handle is none.
  163. """
  164. if self._context_handle is None:
  165. raise ValueError("Context handle is none in context!!!")
  166. self._context_handle.set_costmodel_communi_bias(communi_bias)
  167. def get_costmodel_communi_bias(self):
  168. """
  169. Get costmodel communication bias.
  170. Raises:
  171. ValueError: If context handle is none.
  172. """
  173. if self._context_handle is None:
  174. raise ValueError("Context handle is none in context!!!")
  175. return self._context_handle.get_costmodel_communi_bias()
  176. def set_multi_subgraphs(self, multi_subgraph):
  177. """
  178. Set the flag of ANF graph containing multiple subgraphs.
  179. Args:
  180. multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
  181. Raises:
  182. ValueError: If context handle is none.
  183. """
  184. if self._context_handle is None:
  185. raise ValueError("Context handle is none in context!!!")
  186. self._context_handle.set_multi_subgraphs(multi_subgraph)
  187. def get_multi_subgraphs(self):
  188. """
  189. Get the flag of ANF graph containing multiple subgraphs.
  190. Raises:
  191. ValueError: If context handle is none.
  192. """
  193. if self._context_handle is None:
  194. raise ValueError("Context handle is none in context!!!")
  195. return self._context_handle.get_multi_subgraphs()
  196. def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
  197. """
  198. Set costmodel allreduce fusion algorithm.
  199. Args:
  200. algorithm (int): The AllReduce fusion algorithm of parameter gradients.
  201. Raises:
  202. ValueError: If context handle is none.
  203. """
  204. if self._context_handle is None:
  205. raise ValueError("Context handle is none in context!!!")
  206. self._context_handle.set_costmodel_allreduce_fusion_algorithm(algorithm)
  207. def get_costmodel_allreduce_fusion_algorithm(self):
  208. """
  209. Get costmodel allreduce fusion algorithm.
  210. Raises:
  211. ValueError: If context handle is none.
  212. """
  213. if self._context_handle is None:
  214. raise ValueError("Context handle is none in context!!!")
  215. return self._context_handle.get_costmodel_allreduce_fusion_algorithm()
  216. def set_costmodel_allreduce_fusion_times(self, allreduce_fusion_times):
  217. """
  218. Set costmodel allreduce fusion times.
  219. Args:
  220. allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
  221. Raises:
  222. ValueError: If context handle is none.
  223. """
  224. if self._context_handle is None:
  225. raise ValueError("Context handle is none in context!!!")
  226. self._context_handle.set_costmodel_allreduce_fusion_times(allreduce_fusion_times)
  227. def get_costmodel_allreduce_fusion_times(self):
  228. """
  229. Get costmodel allreduce fusion times.
  230. Raises:
  231. ValueError: If context handle is none.
  232. """
  233. if self._context_handle is None:
  234. raise ValueError("Context handle is none in context!!!")
  235. return self._context_handle.get_costmodel_allreduce_fusion_times()
  236. def set_costmodel_allreduce_fusion_tail_percent(self, tail_percent):
  237. """
  238. Set costmodel allreduce fusion tail percent.
  239. Args:
  240. tail_percent (int): The percentage of backward computing time corresponding to the last parameter gradients
  241. AllReduce in the whole backward computing time.
  242. Raises:
  243. ValueError: If context handle is none.
  244. """
  245. if self._context_handle is None:
  246. raise ValueError("Context handle is none in context!!!")
  247. self._context_handle.set_costmodel_allreduce_fusion_tail_percent(tail_percent)
  248. def get_costmodel_allreduce_fusion_tail_percent(self):
  249. """
  250. Get costmodel allreduce fusion tail percent.
  251. Raises:
  252. ValueError: If context handle is none.
  253. """
  254. if self._context_handle is None:
  255. raise ValueError("Context handle is none in context!!!")
  256. return self._context_handle.get_costmodel_allreduce_fusion_tail_percent()
  257. def set_costmodel_allreduce_fusion_tail_time(self, tail_time):
  258. """
  259. Set costmodel allreduce fusion tail time.
  260. Args:
  261. tail_time (int): The tail time of the last parameter gradients AllReduce after the end of backward
  262. computation.
  263. Raises:
  264. ValueError: If context handle is none.
  265. """
  266. if self._context_handle is None:
  267. raise ValueError("Context handle is none in context!!!")
  268. self._context_handle.set_costmodel_allreduce_fusion_tail_time(tail_time)
  269. def get_costmodel_allreduce_fusion_tail_time(self):
  270. """
  271. Get costmodel allreduce fusion tail time.
  272. Raises:
  273. ValueError: If context handle is none.
  274. """
  275. if self._context_handle is None:
  276. raise ValueError("Context handle is none in context!!!")
  277. return self._context_handle.get_costmodel_allreduce_fusion_tail_time()
  278. def set_costmodel_allreduce_fusion_allreduce_inherent_time(self, allreduce_inherent_time):
  279. """
  280. Set costmodel allreduce fusion allreduce inherent time.
  281. Args:
  282. allreduce_inherent_time (int): The inherent cost time of AllReduce.
  283. Raises:
  284. ValueError: If context handle is none.
  285. """
  286. if self._context_handle is None:
  287. raise ValueError("Context handle is none in context!!!")
  288. self._context_handle.set_costmodel_allreduce_fusion_allreduce_inherent_time(allreduce_inherent_time)
  289. def get_costmodel_allreduce_fusion_allreduce_inherent_time(self):
  290. """
  291. Get costmodel allreduce fusion allreduce inherent time.
  292. Raises:
  293. ValueError: If context handle is none.
  294. """
  295. if self._context_handle is None:
  296. raise ValueError("Context handle is none in context!!!")
  297. return self._context_handle.get_costmodel_allreduce_fusion_allreduce_inherent_time()
  298. def set_costmodel_allreduce_fusion_allreduce_bandwidth(self, allreduce_bandwidth):
  299. """
  300. Set costmodel allreduce fusion allreduce bandwidth.
  301. Args:
  302. allreduce_bandwidth (int): The bancwidth of AllReduce.
  303. Raises:
  304. ValueError: If context handle is none.
  305. """
  306. if self._context_handle is None:
  307. raise ValueError("Context handle is none in context!!!")
  308. self._context_handle.set_costmodel_allreduce_fusion_allreduce_bandwidth(allreduce_bandwidth)
  309. def get_costmodel_allreduce_fusion_allreduce_bandwidth(self):
  310. """
  311. Get costmodel allreduce fusion allreduce bandwidth.
  312. Raises:
  313. ValueError: If context handle is none.
  314. """
  315. if self._context_handle is None:
  316. raise ValueError("Context handle is none in context!!!")
  317. return self._context_handle.get_costmodel_allreduce_fusion_allreduce_bandwidth()
  318. def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
  319. """
  320. Set costmodel allreduce fusion computation time parameter.
  321. Args:
  322. computation_time_parameter (int): The parameter used to compute backward computation time.
  323. Raises:
  324. ValueError: If context handle is none.
  325. """
  326. if self._context_handle is None:
  327. raise ValueError("Context handle is none in context!!!")
  328. self._context_handle.set_costmodel_allreduce_fusion_computation_time_parameter(computation_time_parameter)
  329. def get_costmodel_allreduce_fusion_computation_time_parameter(self):
  330. """
  331. Get costmodel allreduce fusion computation time parameter.
  332. Raises:
  333. ValueError: If context handle is none.
  334. """
  335. if self._context_handle is None:
  336. raise ValueError("Context handle is none in context!!!")
  337. return self._context_handle.get_costmodel_allreduce_fusion_computation_time_parameter()
  338. def reset_cost_model(self):
  339. """
  340. Reset cost model settings.
  341. Raises:
  342. ValueError: If context handle is none.
  343. """
  344. if self._context_handle is None:
  345. raise ValueError("Context handle is none in context!!!")
  346. self._context_handle.reset_cost_model()
  347. _cost_model_context = None
  348. def cost_model_context():
  349. """
  350. Get the global _cost_model_context. If it is not created, create a new one.
  351. Returns:
  352. The global cost_model context.
  353. """
  354. global _cost_model_context
  355. if _cost_model_context is None:
  356. _cost_model_context = _CostModelContext()
  357. return _cost_model_context
  358. set_cost_model_context_func_map = {
  359. "device_memory_capacity": cost_model_context().set_device_memory_capacity,
  360. "costmodel_alpha": cost_model_context().set_costmodel_alpha,
  361. "costmodel_beta": cost_model_context().set_costmodel_beta,
  362. "costmodel_gamma": cost_model_context().set_costmodel_gamma,
  363. "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
  364. "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
  365. "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
  366. "multi_subgraphs": cost_model_context().set_multi_subgraphs,
  367. "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
  368. "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
  369. "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
  370. "costmodel_allreduce_fusion_tail_time": cost_model_context().set_costmodel_allreduce_fusion_tail_time,
  371. "costmodel_allreduce_fusion_allreduce_inherent_time":
  372. cost_model_context().set_costmodel_allreduce_fusion_allreduce_inherent_time,
  373. "costmodel_allreduce_fusion_allreduce_bandwidth":
  374. cost_model_context().set_costmodel_allreduce_fusion_allreduce_bandwidth,
  375. "costmodel_allreduce_fusion_computation_time_parameter":
  376. cost_model_context().set_costmodel_allreduce_fusion_computation_time_parameter}
  377. get_cost_model_context_func_map = {
  378. "device_memory_capacity": cost_model_context().get_device_memory_capacity,
  379. "costmodel_alpha": cost_model_context().get_costmodel_alpha,
  380. "costmodel_beta": cost_model_context().get_costmodel_beta,
  381. "costmodel_gamma": cost_model_context().get_costmodel_gamma,
  382. "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
  383. "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
  384. "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
  385. "multi_subgraphs": cost_model_context().get_multi_subgraphs(),
  386. "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
  387. "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
  388. "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
  389. "costmodel_allreduce_fusion_tail_time": cost_model_context().get_costmodel_allreduce_fusion_tail_time,
  390. "costmodel_allreduce_fusion_allreduce_inherent_time":
  391. cost_model_context().get_costmodel_allreduce_fusion_allreduce_inherent_time,
  392. "costmodel_allreduce_fusion_allreduce_bandwidth":
  393. cost_model_context().get_costmodel_allreduce_fusion_allreduce_bandwidth,
  394. "costmodel_allreduce_fusion_computation_time_parameter":
  395. cost_model_context().get_costmodel_allreduce_fusion_computation_time_parameter}
  396. @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
  397. costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
  398. multi_subgraphs=bool,
  399. costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
  400. costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
  401. costmodel_allreduce_fusion_allreduce_inherent_time=float,
  402. costmodel_allreduce_fusion_allreduce_bandwidth=float,
  403. costmodel_allreduce_fusion_computation_time_parameter=float)
  404. def set_cost_model_context(**kwargs):
  405. """
  406. Set cost model context.
  407. Note:
  408. Attribute name is needed.
  409. Args:
  410. device_memory_capacity (float): The memory capacity for each device.
  411. costmodel_alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
  412. costmodel_beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
  413. costmodel_gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
  414. costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
  415. costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
  416. costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
  417. multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
  418. costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
  419. 0: bypass allreduce fusion;
  420. 1: only use backward computation time to group allreduce;
  421. 2: use backward computation time and parameter gradient allreduce time to group allreduce.
  422. costmodel_allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
  423. costmodel_allreduce_fusion_tail_percent (float): A parameter used in allreduce fusion algorithm. The percentage
  424. of backward computing time corresponding to the last parameter gradients AllReduce in the whole backward
  425. computing time.
  426. costmodel_allreduce_fusion_tail_time (float): A parameter used in allreduce fusion algorithm. The tail time of
  427. the last parameter gradients AllReduce after the end of backward computation.
  428. costmodel_allreduce_fusion_allreduce_inherent_time (float): A parameter used in allreduce fusion algorithm. The
  429. inherent cost time of AllReduce.
  430. costmodel_allreduce_fusion_allreduce_bandwidth (float): A parameter used in allreduce fusion algorithm. The
  431. bandwidth of AllReduce.
  432. costmodel_allreduce_fusion_computation_time_parameter (float): A parameter used in allreduce fusion algorithm.
  433. The parameter used to compute backward computation time.
  434. Raises:
  435. ValueError: If context keyword is not recognized.
  436. """
  437. for key, value in kwargs.items():
  438. if key not in set_cost_model_context_func_map:
  439. raise ValueError("Set context keyword %s is not recognized!" % key)
  440. set_func = set_cost_model_context_func_map[key]
  441. set_func(value)
  442. def get_cost_model_context(attr_key):
  443. """
  444. Get cost model context attributes.
  445. Note:
  446. Return value according to the attribute value.
  447. Args:
  448. attr_key (str): The key of the attribute.
  449. Raises:
  450. ValueError: If context keyword is not recognized.
  451. """
  452. if attr_key not in get_cost_model_context_func_map:
  453. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  454. get_func = get_cost_model_context_func_map[attr_key]
  455. return get_func()
  456. def reset_cost_model_context():
  457. """Reset cost model context attributes."""
  458. cost_model_context().reset_cost_model()