Browse Source

!25858 [Boost] Add boost config dict.

Merge pull request !25858 from linqingke/boost
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
2fa352a115
6 changed files with 220 additions and 100 deletions
  1. +13
    -13
      mindspore/boost/base.py
  2. +135
    -39
      mindspore/boost/boost.py
  3. +19
    -6
      mindspore/boost/boost_cell_wrapper.py
  4. +22
    -22
      mindspore/boost/grad_freeze.py
  5. +24
    -17
      mindspore/train/amp.py
  6. +7
    -3
      mindspore/train/model.py

+ 13
- 13
mindspore/boost/base.py View File

@@ -71,8 +71,8 @@ class OptimizerProcess:
r"""
Build the parameter's dict of the network.

Inputs:
- **network** (Cell) - The training network.
Args:
network (Cell) - The training network.
"""
cells = network.cells_and_names()
params_dict = {}
@@ -85,9 +85,9 @@ class OptimizerProcess:
r"""
Build the parameter's group with grad centralization.

Inputs:
- **params_dict** (dict) - The network's parameter dict.
- **parameters** (list) - The network's parameter list.
Args:
params_dict (dict) - The network's parameter dict.
parameters (list) - The network's parameter list.
"""
group_params = []
for group_param in parameters:
@@ -121,8 +121,8 @@ class OptimizerProcess:
r"""
Add gradient centralization.

Inputs:
- **network** (Cell) - The training network.
Args:
network (Cell) - The training network.
"""
params_dict = self.build_params_dict(network)

@@ -190,9 +190,9 @@ class ParameterProcess:
r"""
Assign parameter group.

Inputs:
- **parameters** (list) - The network's parameter list.
- **split_point** (list) - The gradient split point of this network. default: None.
Args:
parameters (list) - The network's parameter list.
split_point (list) - The gradient split point of this network. default: None.
"""
if not isinstance(parameters, (list, tuple)) or not parameters:
return parameters
@@ -212,9 +212,9 @@ class ParameterProcess:
r"""
Generate group parameters.

Inputs:
- **parameters** (list) - The network's parameter list.
- **origin_params** (list) - The network's origin parameter list.
Args:
parameters (list) - The network's parameter list.
origin_params (list) - The network's origin parameter list.
"""
origin_params_copy = origin_params
if origin_params_copy is not None:


+ 135
- 39
mindspore/boost/boost.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""boost"""
import threading
from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze
from .base import OptimizerProcess, ParameterProcess
@@ -20,20 +21,23 @@ from .base import OptimizerProcess, ParameterProcess
__all__ = ["AutoBoost"]
_boost_config_mode = ["auto", "manual", "enable_all", "disable_all"]
_boost_config_level = {
"O0": {
"less_bn": False,
"grad_freeze": False,
"adasum": False},
"adasum": False,
"grad_accumulation": False},
"O1": {
"less_bn": True,
"grad_freeze": True,
"adasum": False},
"adasum": False,
"grad_accumulation": False},
"O2": {
"less_bn": True,
"grad_freeze": True,
"adasum": True}}
"adasum": True,
"grad_accumulation": False}}
class AutoBoost:
@@ -41,57 +45,115 @@ class AutoBoost:
Provide auto accelerating for network.
Args:
level (str): boost config level.
kwargs (any): Additional configuration parameters related to boost.
level (str): Boost config level. Default: "O0".
boost_config_dict (dict): User config hyperparameter dict, recommended config format:
{
"boost": {
"//": "suggest mode: ["auto", "manual", "enable_all", "disable_all"]",
"mode": "auto",
"less_bn": false,
"grad_freeze": false,
"adasum": false,
"grad_accumulation": false
},
"common": {
"gradient_split_groups": [50, 100]
},
"less_bn": {
"fn_flag": true,
"gc_flag": true
},
"grad_freeze": {
"param_groups": 10,
"freeze_type": 1,
"freeze_p": 0.7,
"total_steps": 65536
},
"adasum": {
"device_number": 8
},
"grad_accumulation": {
"grad_accumulation_step": 1
}
}
User can load the config through the JSON file or use the dictionary directly.
The unconfigured parameters will adopt the default values. Default: "".
Raises:
ValueError: The boost mode not in ["auto", "manual", "enable_all", "disable_all"].
Supported Platforms:
``Ascend``
Examples:
>>> from mindspore.boost import AutoBoost
>>> #1) when configuring the dict directly:
>>> boost_config_dict = {"boost": {"mode": "auto"}}
>>> boost = AutoBoost("O1", boost_config_dict)
>>>
>>> #2) when loading the dict from a json file:
>>> import json
>>> boost_json = "/path/boost_config.json"
>>> with open(boost_json, 'r') as fp:
>>> boost_config_dict = json.load(fp)
>>> boost = AutoBoost("O1", boost_config_dict)
"""
def __init__(self, level, kwargs):
_instance_lock = threading.Lock()
_instance = None
def __init__(self, level="O0", boost_config_dict=""):
if level not in _boost_config_level.keys():
level = 'O0'
self.level = level
boost_config = _boost_config_level[level]
self._boost_config = boost_config
self._fn_flag = True
self._gc_flag = True
self._param_groups = 10
self._freeze_type = 1
self._freeze_p = 0.7
self._total_steps = 65536
self._gradient_groups = None
self._get_configuration(kwargs)
self._param_processer = ParameterProcess()
def _get_configuration(self, kwargs):
"""Get configuration."""
for key, val in kwargs.items():
if key not in self._boost_config_func_map.keys():
continue
self._boost_config_func_map[key](self, val)
level = "O0"
if self._instance.level is None:
self.level = level
self.boost_config_dict = boost_config_dict
self._fn_flag = True
self._gc_flag = True
self._param_groups = 10
self._freeze_type = 1
self._freeze_p = 0.7
self._total_steps = 65536
self.gradient_groups = None
self.device_number = 8
self.grad_accumulation_step = 1
self.boost_config = self._get_configuration(level, self.boost_config_dict)
self._param_processer = ParameterProcess()
# pylint: disable=unused-argument
def __new__(cls, *args, **kwargs):
if AutoBoost._instance is None:
with AutoBoost._instance_lock:
if AutoBoost._instance is None:
AutoBoost._instance = object.__new__(cls)
AutoBoost._instance.level = None
AutoBoost._instance.boost_config_dict = None
return AutoBoost._instance
def network_auto_process_train(self, network, optimizer):
r"""
Boost network train.
Inputs:
- **network** (Cell) - The training network.
- **optimizer** (Cell) - Optimizer for updating the weights.
Args:
network (Cell) - The training network.
optimizer (Cell) - Optimizer for updating the weights.
"""
if self._boost_config["less_bn"]:
if self.boost_config["less_bn"]:
network = LessBN(network, fn_flag=self._fn_flag)
optimizer_process = OptimizerProcess(optimizer)
group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
self._gradient_groups)
self.gradient_groups)
optimizer_process.origin_params = \
self._param_processer.generate_group_params(group_params, optimizer_process.origin_params)
if self._gc_flag:
optimizer_process.add_grad_centralization(network)
optimizer = optimizer_process.generate_new_optimizer()
if self._boost_config["grad_freeze"]:
if self.boost_config["grad_freeze"]:
freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
self._freeze_p, self._total_steps)
network, optimizer = freeze_processer.freeze_generate(network, optimizer)
if self._boost_config["adasum"]:
if self.boost_config["adasum"]:
setattr(optimizer, "adasum", True)
return network, optimizer
@@ -100,9 +162,9 @@ class AutoBoost:
Boost network eval.
Args:
- **network** (Cell) - The inference network.
network (Cell) - The inference network.
"""
if self._boost_config["less_bn"]:
if self.boost_config["less_bn"]:
network = LessBN(network)
return network
@@ -125,12 +187,44 @@ class AutoBoost:
def set_total_steps(self, total_steps):
self._total_steps = total_steps
def set_gradient_groups(self, gradient_groups):
def set_device_number(self, device_number):
self.device_number = device_number
def set_grad_accumulation_step(self, grad_accumulation_step):
self.grad_accumulation_step = grad_accumulation_step
def set_gradient_split_groups(self, gradient_groups):
if not isinstance(gradient_groups, (list, int)):
raise ValueError(f"gradient_groups `{gradient_groups}` is not in (list, int)")
if isinstance(gradient_groups, int):
gradient_groups = list(gradient_groups)
self._gradient_groups = gradient_groups
self.gradient_groups = gradient_groups
def _get_configuration(self, level, boost_config_dict):
"""Get configuration."""
level_config = _boost_config_level[level]
if not boost_config_dict:
return level_config
mode = "auto"
if 'boost' in boost_config_dict and 'mode' in boost_config_dict['boost']:
mode = boost_config_dict['boost']['mode']
if mode not in _boost_config_mode:
raise ValueError("The boost mode must be in {}, but got {}".format(_boost_config_mode, mode))
if mode == "manual":
for key, value in boost_config_dict["boost"].items():
if key in level_config:
level_config[key] = value
elif mode == "enable_all":
level_config = {key: True for key in level_config}
elif mode == "disable_all":
level_config = {key: False for key in level_config}
for key, boost_each_mode_config in boost_config_dict.items():
if key in level_config.keys() and level_config[key] or key == "common":
for key_s in boost_each_mode_config.keys():
if key_s in self._boost_config_func_map:
self._boost_config_func_map[key_s](self, boost_each_mode_config[key_s])
return level_config
_boost_config_func_map = {
"fn_flag": set_fn_flag,
@@ -139,5 +233,7 @@ class AutoBoost:
"freeze_type": set_freeze_type,
"freeze_p": set_freeze_p,
"total_steps": set_total_steps,
"gradient_groups": set_gradient_groups
"device_number": set_device_number,
"gradient_split_groups": set_gradient_split_groups,
"grad_accumulation_step": set_grad_accumulation_step
}

+ 19
- 6
mindspore/boost/boost_cell_wrapper.py View File

@@ -15,7 +15,7 @@
"""Boost Mode Cell Wrapper."""
from mindspore.nn.wrap import TrainOneStepCell
import mindspore.context as context
from mindspore.context import ParallelMode, get_auto_parallel_context
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_gradients_mean
from mindspore.communication.management import get_group_size, create_group
from mindspore.nn.cell import Cell
@@ -26,6 +26,7 @@ from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .boost import AutoBoost
from .grad_freeze import FreezeOpt, freeze_cell
from .adasum import AdaSum
from .grad_accumulation import gradient_accumulation_op, gradient_clear_op
@@ -142,9 +143,12 @@ class BoostTrainOneStepCell(TrainOneStepCell):
self.weights = self.optimizer.parameters
self.train_strategy = getattr(self.optimizer, 'train_strategy', None)

auto_boost = AutoBoost()
self.use_grad_accumulation = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE)
self.use_grad_accumulation = self.use_grad_accumulation & auto_boost.boost_config["grad_accumulation"]
self.max_accumulation_step = 1
if self.use_grad_accumulation:
self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
self.max_accumulation_step = auto_boost.grad_accumulation_step
if self.max_accumulation_step <= 1:
self.max_accumulation_step = 1
self.use_grad_accumulation = False
@@ -170,7 +174,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
if self.enable_adasum:
_rank = _get_global_rank()
_rank_size = get_group_size()
_device_number = 8
_device_number = auto_boost.device_number
self.device_number = _device_number
group_number = _rank_size // _device_number

@@ -214,6 +218,9 @@ class BoostTrainOneStepCell(TrainOneStepCell):

Inputs:
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

Outputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
"""
if self.train_strategy is None:
step = self.step
@@ -235,6 +242,9 @@ class BoostTrainOneStepCell(TrainOneStepCell):
Inputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **grads** (Tuple(Tensor)) - Tuple of gradient tensors.

Outputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
"""
loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self.max_accumulation_step),
self.grad_accumulation, grads))
@@ -259,6 +269,9 @@ class BoostTrainOneStepCell(TrainOneStepCell):
Inputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **grads** (Tuple(Tensor)) - Tuple of gradient tensors.

Outputs:
- **loss** (Tensor) - Tensor with shape :math:`()`.
"""
loss = F.depend(loss, self.optimizer(grads))
rank_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
@@ -281,9 +294,9 @@ class BoostTrainOneStepCell(TrainOneStepCell):
r"""
Check adasum enable.

Inputs:
- **optimizer** (Union[Cell]) - Optimizer for updating the weights.
- **reducer_flag** (bool) - Reducer flag.
Args:
optimizer (Union[Cell]) - Optimizer for updating the weights.
reducer_flag (bool) - Reducer flag.
"""
if not getattr(optimizer, "adasum", None) or not reducer_flag:
return False


+ 22
- 22
mindspore/boost/grad_freeze.py View File

@@ -168,7 +168,7 @@ class GradientFreeze:
total_steps (numbers.Number): Steps of the whole training.

Examples:
>>> gradient_freeze_class = acc.GradientFreeze(10, 1, 0.5, 2000)
>>> gradient_freeze_class = boost.GradientFreeze(10, 1, 0.5, 2000)
>>> network, optimizer = gradient_freeze_class.freeze_generate(network, optimizer)
"""
def __init__(self, param_groups, freeze_type, freeze_p, total_steps):
@@ -183,9 +183,9 @@ class GradientFreeze:
r"""
Split parameter groups for gradients freezing training.

Inputs:
- **net** (Cell) - The training network.
- **freeze_para_groups_number** (int) - The number of gradient freeze groups.
Args:
net (Cell) - The training network.
freeze_para_groups_number (int) - The number of gradient freeze groups.
"""
grouped_params = []
tmp = []
@@ -210,11 +210,11 @@ class GradientFreeze:
r"""
Generate index sequence for gradient freezing training.

Inputs:
- **parameter_groups_number** (int) - The number of parameter groups.
- **freeze_strategy** (int) - Gradient freeze grouping strategy, select from [0, 1].
- **freeze_p** (float) - Gradient freezing probability.
- **total_steps** (int) - Total training steps.
Args:
parameter_groups_number (int) - The number of parameter groups.
freeze_strategy (int) - Gradient freeze grouping strategy, select from [0, 1].
freeze_p (float) - Gradient freezing probability.
total_steps (int) - Total training steps.
"""
total_step = int(total_steps * 1.01)
if parameter_groups_number <= 1:
@@ -252,9 +252,9 @@ class GradientFreeze:
r"""
Generate freeze network and optimizer.

Inputs:
- **network** (Cell) - The training network.
- **optimizer** (Cell) - Optimizer for updating the weights.
Args:
network (Cell) - The training network.
optimizer (Cell) - Optimizer for updating the weights.
"""
train_para_groups = self.split_parameters_groups(
network, self._param_groups)
@@ -273,16 +273,16 @@ def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulat
r"""
Generate freeze network and optimizer.

Inputs:
- **reducer_flag** (bool) - Reducer flag.
- **network** (Cell) - The training network.
- **optimizer** (Cell) - Optimizer for updating the weights.
- **sens** (Tensor) - Tensor with shape :math:`()`
- **grad** (Tuple(Tensor)) - Tuple of gradient tensors.
- **use_grad_accumulation** (bool) - Use gradient accumulation flag.
- **mean** (bool) - Gradients mean flag. default: None.
- **degree** (int) - Device number. default: None.
- **max_accumulation_step** (int) - Max accumulation steps. default: 1.
Args:
reducer_flag (bool) - Reducer flag.
network (Cell) - The training network.
optimizer (Cell) - Optimizer for updating the weights.
sens (Tensor) - Tensor with shape :math:`()`
grad (Tuple(Tensor)) - Tuple of gradient tensors.
use_grad_accumulation (bool) - Use gradient accumulation flag.
mean (bool) - Gradients mean flag. default: None.
degree (int) - Device number. default: None.
max_accumulation_step (int) - Max accumulation steps. default: 1.

Examples:
>>> import numpy as np


+ 24
- 17
mindspore/train/amp.py View File

@@ -87,6 +87,29 @@ def _check_kwargs(key_words):
validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager)


def _check_level(level, boost_level):
"""Check level."""
if not isinstance(level, str):
raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \
but got type {}.".format(type(level)))
validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)

if level == "auto":
device_target = context.get_context('device_target')
if device_target == "GPU":
level = "O2"
elif device_target == "Ascend":
level = "O3"
else:
raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")

enable_boost = False
if boost_level in ["O1", "O2"]:
enable_boost = True

return level, enable_boost

def _add_loss_network(network, loss_fn, cast_model_type):
"""Add loss network."""

@@ -159,20 +182,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
"""
validator.check_value_type('network', network, nn.Cell)
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt))
if not isinstance(level, str):
raise TypeError(f"The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], "
f"but got type {str(type(level))}.")
validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)

if level == "auto":
device_target = context.get_context('device_target')
if device_target == "GPU":
level = "O2"
elif device_target == "Ascend":
level = "O3"
else:
raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
level, enable_boost = _check_level(level, boost_level)

_check_kwargs(kwargs)
config = dict(_config_level[level], **kwargs)
@@ -189,10 +200,6 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)

enable_boost = False
if boost_level in ["O1", "O2"]:
enable_boost = True

loss_scale = 1.0
if config["loss_scale_manager"] is not None:
loss_scale_manager = config["loss_scale_manager"]


+ 7
- 3
mindspore/train/model.py View File

@@ -97,6 +97,7 @@ class Model:
the accuracy is the same as the original accuracy.
- O2: Enable the boost mode, the performance is improved by about 30%, and
the accuracy is reduced by less than 3%.
If you want to config boost mode by yourself, you can set boost_config_dict as `boost.py`.
Examples:
>>> from mindspore import Model, nn
>>>
@@ -187,9 +188,9 @@ class Model:

def _check_kwargs(self, kwargs):
for arg in kwargs:
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32', 'boost_config_dict']:
raise ValueError(f"The argument in 'kwargs' should be 'loss_scale_manager' or "
f"'keep_batchnorm_fp32', but got '{arg}'.")
f"'keep_batchnorm_fp32' or 'boost_config_dict', but got '{arg}'.")

def _check_reuse_dataset(self, dataset):
if not hasattr(dataset, '__model_hash__'):
@@ -199,7 +200,10 @@ class Model:

def _build_boost_network(self, kwargs):
"""Build the boost network."""
processor = AutoBoost(self._boost_level, kwargs)
boost_config_dict = ""
if 'boost_config_dict' in kwargs:
boost_config_dict = kwargs['boost_config_dict']
processor = AutoBoost(self._boost_level, boost_config_dict)
if processor.level not in ["O1", "O2"]:
return
if self._optimizer is None:


Loading…
Cancel
Save