You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

boost.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """boost"""
  16. import threading
  17. from .less_batch_normalization import LessBN
  18. from .grad_freeze import GradientFreeze
  19. from .base import OptimizerProcess, ParameterProcess
  20. __all__ = ["AutoBoost"]
  21. _boost_config_mode = ["auto", "manual", "enable_all", "disable_all"]
  22. _boost_config_level = {
  23. "O0": {
  24. "less_bn": False,
  25. "grad_freeze": False,
  26. "adasum": False,
  27. "grad_accumulation": False},
  28. "O1": {
  29. "less_bn": True,
  30. "grad_freeze": True,
  31. "adasum": False,
  32. "grad_accumulation": False},
  33. "O2": {
  34. "less_bn": True,
  35. "grad_freeze": True,
  36. "adasum": True,
  37. "grad_accumulation": False}}
  38. class AutoBoost:
  39. r"""
  40. Provide auto accelerating for network.
  41. Args:
  42. level (str): Boost config level. Default: "O0".
  43. boost_config_dict (dict): User config hyperparameter dict, recommended config format:
  44. .. code-block::
  45. {
  46. "boost": {
  47. "//": "suggest mode: ["auto", "manual", "enable_all", "disable_all"]",
  48. "mode": "auto",
  49. "less_bn": false,
  50. "grad_freeze": false,
  51. "adasum": false,
  52. "grad_accumulation": false
  53. },
  54. "common": {
  55. "gradient_split_groups": [50, 100]
  56. },
  57. "less_bn": {
  58. "fn_flag": true,
  59. "gc_flag": true
  60. },
  61. "grad_freeze": {
  62. "param_groups": 10,
  63. "freeze_type": 1,
  64. "freeze_p": 0.7,
  65. "total_steps": 65536
  66. },
  67. "adasum": {
  68. "device_number": 8
  69. },
  70. "grad_accumulation": {
  71. "grad_accumulation_step": 1
  72. }
  73. }
  74. User can load the config through the JSON file or use the dictionary directly.
  75. The unconfigured parameters will adopt the default values. Default: "".
  76. Raises:
  77. ValueError: The boost mode not in ["auto", "manual", "enable_all", "disable_all"].
  78. Supported Platforms:
  79. ``Ascend``
  80. Examples:
  81. >>> from mindspore.boost import AutoBoost
  82. >>> #1) when configuring the dict directly:
  83. >>> boost_config_dict = {"boost": {"mode": "auto"}}
  84. >>> boost = AutoBoost("O1", boost_config_dict)
  85. >>>
  86. >>> #2) when loading the dict from a json file:
  87. >>> import json
  88. >>> boost_json = "/path/boost_config.json"
  89. >>> with open(boost_json, 'r') as fp:
  90. >>> boost_config_dict = json.load(fp)
  91. >>> boost = AutoBoost("O1", boost_config_dict)
  92. """
  93. _instance_lock = threading.Lock()
  94. _instance = None
  95. def __init__(self, level="O0", boost_config_dict=""):
  96. if level not in _boost_config_level.keys():
  97. level = "O0"
  98. if self._instance.level is None:
  99. self.level = level
  100. self.boost_config_dict = boost_config_dict
  101. self._fn_flag = True
  102. self._gc_flag = True
  103. self._param_groups = 10
  104. self._freeze_type = 1
  105. self._freeze_p = 0.7
  106. self._total_steps = 65536
  107. self.gradient_groups = None
  108. self.device_number = 8
  109. self.grad_accumulation_step = 1
  110. self.boost_config = self._get_configuration(level, self.boost_config_dict)
  111. self._param_processer = ParameterProcess()
  112. # pylint: disable=unused-argument
  113. def __new__(cls, *args, **kwargs):
  114. if AutoBoost._instance is None:
  115. with AutoBoost._instance_lock:
  116. if AutoBoost._instance is None:
  117. AutoBoost._instance = object.__new__(cls)
  118. AutoBoost._instance.level = None
  119. AutoBoost._instance.boost_config_dict = None
  120. return AutoBoost._instance
  121. def network_auto_process_train(self, network, optimizer):
  122. r"""
  123. Boost network train.
  124. Args:
  125. network (Cell): The training network.
  126. optimizer (Cell): Optimizer for updating the weights.
  127. """
  128. if self.boost_config["less_bn"]:
  129. network = LessBN(network, fn_flag=self._fn_flag)
  130. optimizer_process = OptimizerProcess(optimizer)
  131. group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
  132. self.gradient_groups)
  133. optimizer_process.origin_params = \
  134. self._param_processer.generate_group_params(group_params, optimizer_process.origin_params)
  135. if self._gc_flag:
  136. optimizer_process.add_grad_centralization(network)
  137. optimizer = optimizer_process.generate_new_optimizer()
  138. if self.boost_config["grad_freeze"]:
  139. freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
  140. self._freeze_p, self._total_steps)
  141. network, optimizer = freeze_processer.freeze_generate(network, optimizer)
  142. if self.boost_config["adasum"]:
  143. setattr(optimizer, "adasum", True)
  144. return network, optimizer
  145. def network_auto_process_eval(self, network):
  146. r"""
  147. Boost network eval.
  148. Args:
  149. network (Cell): The inference network.
  150. """
  151. if self.boost_config["less_bn"]:
  152. network = LessBN(network)
  153. return network
  154. def set_fn_flag(self, fn_flag):
  155. self._fn_flag = fn_flag
  156. def set_gc_flag(self, gc_flag):
  157. self._gc_flag = gc_flag
  158. def set_param_groups(self, param_groups):
  159. self._param_groups = param_groups
  160. def set_freeze_type(self, freeze_type):
  161. self._freeze_type = freeze_type
  162. def set_freeze_p(self, freeze_p):
  163. self._freeze_p = freeze_p
  164. def set_total_steps(self, total_steps):
  165. self._total_steps = total_steps
  166. def set_device_number(self, device_number):
  167. self.device_number = device_number
  168. def set_grad_accumulation_step(self, grad_accumulation_step):
  169. self.grad_accumulation_step = grad_accumulation_step
  170. def set_gradient_split_groups(self, gradient_groups):
  171. if not isinstance(gradient_groups, (list, int)):
  172. raise ValueError(f"gradient_groups `{gradient_groups}` is not in (list, int)")
  173. if isinstance(gradient_groups, int):
  174. gradient_groups = list(gradient_groups)
  175. self.gradient_groups = gradient_groups
  176. def _get_configuration(self, level, boost_config_dict):
  177. """Get configuration."""
  178. level_config = _boost_config_level[level]
  179. if not boost_config_dict:
  180. return level_config
  181. mode = "auto"
  182. if 'boost' in boost_config_dict and 'mode' in boost_config_dict['boost']:
  183. mode = boost_config_dict['boost']['mode']
  184. if mode not in _boost_config_mode:
  185. raise ValueError("The boost mode must be in {}, but got {}".format(_boost_config_mode, mode))
  186. if mode == "manual":
  187. for key, value in boost_config_dict["boost"].items():
  188. if key in level_config:
  189. level_config[key] = value
  190. elif mode == "enable_all":
  191. level_config = {key: True for key in level_config}
  192. elif mode == "disable_all":
  193. level_config = {key: False for key in level_config}
  194. for key, boost_each_mode_config in boost_config_dict.items():
  195. if key in level_config.keys() and level_config[key] or key == "common":
  196. for key_s in boost_each_mode_config.keys():
  197. if key_s in self._boost_config_func_map:
  198. self._boost_config_func_map[key_s](self, boost_each_mode_config[key_s])
  199. return level_config
  200. _boost_config_func_map = {
  201. "fn_flag": set_fn_flag,
  202. "gc_flag": set_gc_flag,
  203. "param_groups": set_param_groups,
  204. "freeze_type": set_freeze_type,
  205. "freeze_p": set_freeze_p,
  206. "total_steps": set_total_steps,
  207. "device_number": set_device_number,
  208. "gradient_split_groups": set_gradient_split_groups,
  209. "grad_accumulation_step": set_grad_accumulation_step
  210. }