
base.py 11 kB

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""base process"""
import copy
import mindspore.nn as nn
from mindspore.nn.optim import LARS
from mindspore import log as logger
from mindspore.common import Parameter
from .less_batch_normalization import CommonHeadLastFN

__all__ = ["OptimizerProcess", "ParameterProcess"]

class OptimizerProcess:
    r"""
    Process optimizer for Boost. Currently, this class supports adding GC (grad centralization) tags
    and creating new optimizers.

    Args:
        opt (Cell): Optimizer used.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> import mindspore.ops as ops
        >>> from mindspore.boost import OptimizerProcess
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> optimizer_process = OptimizerProcess(optimizer)
        >>> optimizer_process.add_grad_centralization(network)
        >>> optimizer = optimizer_process.generate_new_optimizer()
    """
    def __init__(self, opt):
        if isinstance(opt, LARS):
            self.is_lars = True
            self.opt_class = type(opt.opt)
            self.opt_init_args = opt.opt.init_args
            self.lars_init_args = opt.init_args
        else:
            self.is_lars = False
            self.opt_class = type(opt)
            self.opt_init_args = opt.init_args
        self.origin_params = opt.init_params["params"]

    def build_params_dict(self, network):
        r"""
        Build the parameter's dict of the network.

        Args:
            network (Cell): The training network.
        """
        cells = network.cells_and_names()
        params_dict = {}
        for _, cell in cells:
            for par in cell.get_parameters(expand=False):
                params_dict[id(par)] = cell
        return params_dict
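
    # A minimal sketch of what build_params_dict produces (illustrative,
    # not part of the original file): the dict maps each Parameter's id()
    # to the Cell that owns it, so later passes can check the owning
    # cell's type. E.g., for a network with a single nn.Dense cell `fc`:
    #   params_dict == {id(fc.weight): fc, id(fc.bias): fc}
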
    def build_gc_params_group(self, params_dict, parameters):
        r"""
        Build the parameter's group with grad centralization.

        Args:
            params_dict (dict): The network's parameter dict.
            parameters (list): The network's parameter list.
        """
        group_params = []
        for group_param in parameters:
            if 'order_params' in group_param.keys():
                group_params.append(group_param)
                continue
            params_gc_value = []
            params_value = []
            for param in group_param['params']:
                if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
                    param_cell = params_dict[id(param)]
                    if (isinstance(param_cell, nn.Conv2d) and param_cell.group > 1) or \
                            isinstance(param_cell, CommonHeadLastFN):
                        params_value.append(param)
                    else:
                        params_gc_value.append(param)
                else:
                    params_value.append(param)
            if params_gc_value:
                new_group_param = copy.deepcopy(group_param)
                new_group_param['params'] = params_gc_value
                new_group_param['grad_centralization'] = True
                group_params.append(new_group_param)
            if params_value:
                new_group_param = copy.deepcopy(group_param)
                new_group_param['params'] = params_value
                group_params.append(new_group_param)
        return group_params
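
    # Illustrative sketch of the grouping (an assumed example, not part of
    # the original file): given a single origin group
    #   [{'params': [conv.weight, bn.gamma, bn.beta], 'lr': 0.1}]
    # the method splits it by eligibility, tagging only the first result:
    #   [{'params': [conv.weight], 'lr': 0.1, 'grad_centralization': True},
    #    {'params': [bn.gamma, bn.beta], 'lr': 0.1}]
    # Grouped convolutions (group > 1) and CommonHeadLastFN weights are
    # deliberately kept in the untagged group.
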
    def add_grad_centralization(self, network):
        r"""
        Add gradient centralization.

        Args:
            network (Cell): The training network.
        """
        params_dict = self.build_params_dict(network)

        parameters = self.origin_params
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(parameters[0], Parameter):
            logger.warning("Only group parameters support gradient centralization.")
            return

        self.origin_params = self.build_gc_params_group(params_dict, parameters)
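
    # Hedged usage note (illustrative, not part of the original file): a
    # flat list such as network.trainable_params() contains Parameter
    # objects rather than group dicts, so this call only logs a warning
    # and returns without changing origin_params:
    #   process = OptimizerProcess(nn.Momentum(network.trainable_params(), 0.1, 0.9))
    #   process.add_grad_centralization(network)  # warns; origin_params unchanged
    # GC tags are applied only when the optimizer was built with group params.
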
    def generate_new_optimizer(self):
        """Generate new optimizer."""
        if not self.is_lars:
            opt = self.opt_class(params=self.origin_params, **self.opt_init_args)
        else:
            opt = LARS(self.opt_class(params=self.origin_params, **self.opt_init_args),
                       **self.lars_init_args)
        return opt

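
# Hedged usage sketch for generate_new_optimizer (illustrative, not part of
# the original file): a LARS-wrapped optimizer is rebuilt inside-out from the
# init arguments captured in __init__, e.g.
#   lars_opt = LARS(nn.Momentum(network.trainable_params(), 0.1, 0.9))
#   process = OptimizerProcess(lars_opt)
#   new_opt = process.generate_new_optimizer()
# new_opt is a fresh LARS(Momentum(...)) built over self.origin_params, which
# may have been regrouped by add_grad_centralization in the meantime.
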
class ParameterProcess:
    r"""
    Process parameter for Boost. Currently, this class supports creating group parameters
    and automatically setting the gradient segmentation point.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn
        >>> import mindspore.ops as ops
        >>> from mindspore.boost import ParameterProcess
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.weight2 = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                  name='weight2')
        ...         self.matmul = ops.MatMul()
        ...         self.matmul2 = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         output2 = self.matmul2(x, self.weight2)
        ...         return output + output2
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> new_parameter = network.trainable_params()[:1]
        >>> parameter_process = ParameterProcess()
        >>> group_params = parameter_process.generate_group_params(new_parameter, network.trainable_params())
    """
    def __init__(self):
        self._parameter_indices = 1

    def assign_parameter_group(self, parameters, split_point=None):
        r"""
        Assign parameter group.

        Args:
            parameters (list): The network's parameter list.
            split_point (list): The gradient split point of this network. Default: None.
        """
        if not isinstance(parameters, (list, tuple)) or not parameters:
            return parameters

        parameter_len = len(parameters)
        if split_point:
            split_parameter_index = split_point
        else:
            split_parameter_index = [parameter_len // 2]
        for i in range(parameter_len):
            if i in split_parameter_index:
                self._parameter_indices += 1
            parameters[i].comm_fusion = self._parameter_indices
        return parameters
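
    # Minimal sketch of the fusion split (illustrative, not part of the
    # original file): on a fresh ParameterProcess, with four parameters and
    # no split_point, the default split index is [len // 2] == [2], so
    #   parameters[0].comm_fusion == 1, parameters[1].comm_fusion == 1,
    #   parameters[2].comm_fusion == 2, parameters[3].comm_fusion == 2,
    # i.e. gradient all-reduce is fused into two groups around the midpoint.
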
    def generate_group_params(self, parameters, origin_params):
        r"""
        Generate group parameters.

        Args:
            parameters (list): The network's parameter list.
            origin_params (list): The network's origin parameter list.
        """
        origin_params_copy = origin_params
        if origin_params_copy is not None:
            if not isinstance(origin_params_copy, list):
                origin_params_copy = list(origin_params_copy)

        if not origin_params_copy:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(origin_params_copy[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(origin_params_copy[0], Parameter):
            group_params = [{"params": parameters}]
            return group_params

        group_params = []
        params_name = [param.name for param in parameters]
        new_params_count = copy.deepcopy(params_name)
        new_params_clone = {}
        max_key_number = 0
        for group_param in origin_params_copy:
            if 'order_params' in group_param.keys():
                new_group_param = copy.deepcopy(group_param)
                new_group_param['order_params'] = parameters
                group_params.append(new_group_param)
                continue
            params_value = []
            for param in group_param['params']:
                if param.name in params_name:
                    index = params_name.index(param.name)
                    params_value.append(parameters[index])
                    new_params_count.remove(param.name)
            new_group_param = copy.deepcopy(group_param)
            new_group_param['params'] = params_value
            group_params.append(new_group_param)
            if len(group_param.keys()) > max_key_number:
                max_key_number = len(group_param.keys())
                new_params_clone = copy.deepcopy(group_param)
        if new_params_count:
            params_value = []
            for param in new_params_count:
                index = params_name.index(param)
                params_value.append(parameters[index])
            if new_params_clone:
                new_params_clone['params'] = params_value
                group_params.append(new_params_clone)
            else:
                group_params.append({"params": params_value})
        return group_params
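
# Hedged worked example for generate_group_params (illustrative; the names
# w1/w2/b1 and their '_new' counterparts are hypothetical):
#
#   origin_params = [{'params': [w1, w2], 'lr': 0.1},
#                    {'params': [b1], 'lr': 0.01}]
#   parameters = [w1_new, w2_new, b1_new]   # same .name values as w1/w2/b1
#
#   parameter_process.generate_group_params(parameters, origin_params)
#   # -> [{'params': [w1_new, w2_new], 'lr': 0.1},
#   #     {'params': [b1_new], 'lr': 0.01}]
#
# Each new parameter replaces the origin parameter that shares its name,
# while the extra group keys (here 'lr') survive via deepcopy; parameters
# whose names match nothing are appended as a final group cloned from the
# origin group with the most keys, or as a plain {'params': ...} group.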