From 2bda8c21e9ac926a13f0dba45e5f06800e7bc3c3 Mon Sep 17 00:00:00 2001
From: buxue
Date: Thu, 8 Apr 2021 22:16:38 +0800
Subject: [PATCH] support grad freeze

---
 .../ccsrc/backend/session/ascend_session.cc  |   6 +-
 .../ccsrc/debug/data_dump/cpu_e2e_dump.cc    |   1 -
 mindspore/core/abstract/prim_statement.cc    |   2 +-
 mindspore/nn/acc/__init__.py                 |   6 +-
 mindspore/nn/acc/grad_freeze.py              | 139 ++++++++++++++++++
 mindspore/nn/acc/less_batch_normalization.py |  10 +-
 mindspore/nn/optim/momentum.py               |   2 +
 mindspore/nn/optim/optimizer.py              |  14 +-
 mindspore/nn/wrap/cell_wrapper.py            |  92 +++++++++---
 mindspore/train/amp.py                       |   3 +-
 10 files changed, 240 insertions(+), 35 deletions(-)
 create mode 100644 mindspore/nn/acc/grad_freeze.py

diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 4b3d9a87d2..3d75912fdb 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -478,10 +478,8 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
   auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
   // Update Graph Dynamic Shape Attr
   UpdateAllGraphDynamicShapeAttr(all_graphs);
-  for (const auto &graph : all_graphs) {
-    UnifyMindIR(graph);
-  }
-  BackendOptimization(all_graphs);
+  UnifyMindIR(root_graph);
+  opt::BackendCommonOptimization(root_graph);
   // empty graph dont entry to backend
   if (root_graph->execution_order().empty()) {
     MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
diff --git a/mindspore/ccsrc/debug/data_dump/cpu_e2e_dump.cc b/mindspore/ccsrc/debug/data_dump/cpu_e2e_dump.cc
index 50deb98056..a4522223c0 100644
--- a/mindspore/ccsrc/debug/data_dump/cpu_e2e_dump.cc
+++ b/mindspore/ccsrc/debug/data_dump/cpu_e2e_dump.cc
@@ -19,7 +19,6 @@
 #include "backend/session/anf_runtime_algorithm.h"
 
 namespace mindspore {
-
 void CPUE2eDump::DumpCNodeData(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto &dump_json_parser = DumpJsonParser::GetInstance();
diff --git a/mindspore/core/abstract/prim_statement.cc b/mindspore/core/abstract/prim_statement.cc
index 8d27af803f..46778b6b41 100644
--- a/mindspore/core/abstract/prim_statement.cc
+++ b/mindspore/core/abstract/prim_statement.cc
@@ -92,7 +92,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
   for (size_t i = 0; i < branches.size(); i++) {
     MS_EXCEPTION_IF_NULL(branches[i]);
-    if (!branches[i]->isa()) {
+    if (!branches[i]->isa() && !branches[i]->isa()) {
       MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
                                << branches[i]->ToString() << " as the " << i << "th element.";
     }
   }
diff --git a/mindspore/nn/acc/__init__.py b/mindspore/nn/acc/__init__.py
index 8e07749c83..372343d7ce 100644
--- a/mindspore/nn/acc/__init__.py
+++ b/mindspore/nn/acc/__init__.py
@@ -17,6 +17,8 @@ Accelerating.
 
 Provide auto accelerating for network, such as Less BN.
""" -from .less_batch_normalization import LessBN +from .less_batch_normalization import * +from .grad_freeze import * -__all__ = ['LessBN'] +__all__ = ['LessBN', 'FreezeOpt', 'CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY', + 'split_parameters_groups', 'generate_freeze_index_sequence'] diff --git a/mindspore/nn/acc/grad_freeze.py b/mindspore/nn/acc/grad_freeze.py new file mode 100644 index 0000000000..e69688dd89 --- /dev/null +++ b/mindspore/nn/acc/grad_freeze.py @@ -0,0 +1,139 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""grad freeze""" + +import numpy as np + +from mindspore.nn.cell import Cell +from mindspore.nn.optim import Optimizer +from mindspore.common import Tensor, Parameter +from mindspore.common import dtype as mstype + +__all__ = ['CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY', + 'split_parameters_groups', 'generate_freeze_index_sequence', + 'FreezeOpt'] + +CONTINUOUS_STRATEGY = 0 +INTERVAL_STRATEGY = 1 + + +def split_parameters_groups(net, freeze_para_groups_number): + """Split parameter groups for gradients freezing training.""" + grouped_params = [] + tmp = [] + for para in net.trainable_params(): + name = para.name + # ensure 'bn' after 'conv' is not split + if 'bn' in name or 'bias' in name: + tmp.append(para) + elif len(tmp) >= 3: + grouped_params.append(tmp) + tmp = [para] + else: + tmp.append(para) + if tmp: + grouped_params.append(tmp) + stride = len(grouped_params) // freeze_para_groups_number + freeze_grouped_params = [sum(grouped_params[i * stride:], []) for i in range(freeze_para_groups_number)] + return freeze_grouped_params + + +def generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, steps_per_epoch, max_epoch): + """Generate index sequence for gradient freezing training.""" + total_step = steps_per_epoch * max_epoch * 1.01 + # local continuous freezing training strategy, as '00001234' + if freeze_strategy == CONTINUOUS_STRATEGY: + zero_cnt = int(freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5) + sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number)) + freeze_idxes = [] + while len(freeze_idxes) < total_step: + freeze_idxes += sub_idx + return freeze_idxes + # interval freezing training strategy, as '01020304' + if freeze_strategy == INTERVAL_STRATEGY: + index_all = list(range(1, parameter_groups_number)) + prob = [x / sum(index_all) for x in index_all] + freeze_idxes = [0] + zero_cnt = 1 + freeze_cnt = 0 + while len(freeze_idxes) < total_step: + freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt) + if freeze_p_cur < 1 - freeze_p: + freeze_idxes.append(int(np.random.choice(index_all[::-1], p=prob))) + freeze_cnt += 1 + else: + freeze_idxes.append(0) + zero_cnt += 1 + return freeze_idxes + raise ValueError(f"Unsupported freezing training strategy '{freeze_strategy}'") + + +class FreezeOpt(Cell): + """ + Optimizer that supports gradients freezing training. 
+
+    Args:
+        opt (Optimizer): A non-freeze optimizer instance, such as 'Momentum' or 'SGD'.
+        train_parameter_groups (Union[Tuple, List]): Groups of parameters for gradients freezing training.
+        train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
+        super(FreezeOpt, self).__init__()
+        if not isinstance(opt, Optimizer):
+            raise TypeError(f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
+        if train_strategy is not None and train_parameter_groups is None:
+            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
+                             "must also be specified")
+        opt_class = type(opt)
+        opt_init_args = opt.init_args
+        self.opts = []
+
+        if train_parameter_groups is None:
+            groups_num = 10
+            step = 6
+            parameters = opt.parameters
+            para_groups = (parameters[(i * step):] for i in range(groups_num))
+            self.opts = [opt_class(params=params, **opt_init_args) for params in para_groups]
+        else:
+            if not isinstance(train_parameter_groups, (tuple, list)):
+                raise TypeError("The specified 'train_parameter_groups' should be tuple or list")
+            for params in train_parameter_groups:
+                if not isinstance(params, (tuple, list)):
+                    raise TypeError("Each element of 'train_parameter_groups' should be a tuple or list "
+                                    "of Parameter")
+                for para in params:
+                    if not isinstance(para, Parameter):
+                        raise TypeError("Each element within a parameter group should be a Parameter")
+
+                # generate one-to-one opt corresponding to the parameter group
+                self.opts.append(opt_class(params=params, **opt_init_args))
+
+        if isinstance(train_strategy, (tuple, list)):
+            for ele in train_strategy:
+                if not isinstance(ele, int):
+                    raise ValueError("Each element of 'train_strategy' should be an int")
+            self.train_strategy = Tensor(train_strategy, mstype.int32)
+        elif isinstance(train_strategy, Tensor):
+            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
+                raise ValueError("When train_strategy is a Tensor, the dimension should be 1 and "
+                                 "the dtype should be int32")
+            self.train_strategy = train_strategy
+        elif train_strategy is None:
+            self.train_strategy = None
+        else:
+            raise TypeError("The specified 'train_strategy' should be None, tuple, list or Tensor")
diff --git a/mindspore/nn/acc/less_batch_normalization.py b/mindspore/nn/acc/less_batch_normalization.py
index ee6ddecafb..44fef9a8e2 100644
--- a/mindspore/nn/acc/less_batch_normalization.py
+++ b/mindspore/nn/acc/less_batch_normalization.py
@@ -14,12 +14,12 @@
 # ============================================================================
 """less batch normalization"""
 import numpy as np
-from mindspore import nn
+from mindspore.nn.cell import Cell
+from mindspore.nn.layer import Dense
 from mindspore.ops import operations as P
-from mindspore import Tensor, Parameter
-from mindspore import dtype as mstype
+from mindspore.common import Tensor, Parameter
+from mindspore.common import dtype as mstype
 from mindspore.common.initializer import initializer
-from ..cell import Cell
 
 __all__ = ["LessBN"]
 
@@ -126,7 +126,7 @@ class LessBN(Cell):
             subcell = cells[name]
            if subcell == net:
                continue
-            elif isinstance(subcell, (nn.Dense)):
+            elif isinstance(subcell, (Dense)):
                dense_name.append(name)
                dense_list.append(subcell)
            else:
diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py
index 4ef1019ecd..a3be233830 100755
--- a/mindspore/nn/optim/momentum.py
+++ b/mindspore/nn/optim/momentum.py
@@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator
 from .optimizer import Optimizer
+from .optimizer import opt_init_args_register
 
 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
 
@@ -147,6 +148,7 @@ class Momentum(Optimizer):
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     """
+    @opt_init_args_register
     def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
         super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
         Validator.check_value_type("momentum", momentum, [float], self.cls_name)
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 81218a9292..90fdca3be3 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """optimizer"""
+import inspect
 from typing import Iterable
 
 import numpy as np
@@ -33,7 +34,18 @@ from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule
 
-__all__ = ['Optimizer']
+__all__ = ['Optimizer', 'opt_init_args_register']
+
+
+def opt_init_args_register(fn):
+    def deco(self, *args, **kwargs):
+        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
+        bound_args.apply_defaults()
+        arguments = bound_args.arguments
+        arguments.pop('self')
+        arguments.pop('params')
+        setattr(self, 'init_args', arguments)
+        fn(self, *args, **kwargs)
+    return deco
 
 
 class Optimizer(Cell):
diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index bf2cccf3fa..843f2728f3 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -27,6 +27,7 @@ from ...ops import functional as F
 from ...ops import operations as P
 from ...ops.operations.comm_ops import _VirtualDataset
 from ..cell import Cell
+from ...nn import acc
 from .grad_reducer import DistributedGradReducer
 
 _get_datatype = C.MultitypeFuncGraph("_get_datatype")
@@ -292,6 +293,32 @@ class ForwardValueAndGrad(Cell):
         return loss, grads
 
 
+class _TrainFreezeCell(Cell):
+    """Gradient freezing training network."""
+    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, max_accumulation_step, optimizer):
+        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
+        self.net = net
+        self.grad = grad
+        self.grad_reducer = grad_reducer
+        self.opt = optimizer
+        self.parameters = optimizer.parameters
+        self.sens = sens
+        self.use_grad_accumulation = use_grad_accumulation
+        if use_grad_accumulation:
+            self.grad_accumulation = GradientAccumulation(max_accumulation_step, optimizer)
+
+    def construct(self, *inputs):
+        loss = self.net(*inputs)
+        sens = F.fill(loss.dtype, loss.shape, self.sens)
+        grads = self.grad(self.net, self.parameters)(*inputs, sens)
+        grads = self.grad_reducer(grads)
+        if self.use_grad_accumulation:
+            loss = self.grad_accumulation(loss, grads)
+        else:
+            loss = F.depend(loss, self.opt(grads))
+        return loss
+
+
 class TrainOneStepCell(Cell):
     r"""
     Network training package class.
@@ -302,7 +329,7 @@
 
     Args:
         network (Cell): The training network. The network only supports single output.
-        optimizer (Cell): Optimizer for updating the weights.
+        optimizer (Union[Cell]): Optimizer for updating the weights, a common optimizer or an acc.FreezeOpt.
         sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
 
     Inputs:
@@ -348,41 +375,66 @@
         super(TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
-        self.weights = optimizer.parameters
+        self.freeze = isinstance(optimizer, acc.FreezeOpt)
         self.optimizer = optimizer
+        if not self.freeze:
+            self.weights = self.optimizer.parameters
+        self.train_strategy = getattr(self.optimizer, 'train_strategy', None)
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
         self.grad_reducer = F.identity
         self.parallel_mode = _get_parallel_mode()
-        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
-            self.reducer_flag = True
-        if self.reducer_flag:
-            self.mean = _get_gradients_mean()
-            self.degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
-        self.use_grad_accumulation = False
-        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
-            self.use_grad_accumulation = True
+        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
+        self.use_grad_accumulation = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE)
+        self.max_accumulation_step = 1
         if self.use_grad_accumulation:
             self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
             if self.max_accumulation_step <= 1:
                 self.max_accumulation_step = 1
                 self.use_grad_accumulation = False
+        self.grad_accumulation = None
         if self.use_grad_accumulation:
             self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
+        if self.reducer_flag:
+            self.mean = _get_gradients_mean()
+            self.degree = _get_device_num()
+            if self.freeze:
+                self.grad_reducers = (DistributedGradReducer(opt.parameters, self.mean, self.degree)
+                                      for opt in self.optimizer.opts)
+                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, reducer,
+                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
+                                         for reducer, opt in zip(self.grad_reducers, self.optimizer.opts))
+            else:
+                self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, self.mean, self.degree)
+        else:
+            if self.freeze:
+                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, self.grad_reducer,
+                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
+                                         for opt in self.optimizer.opts)
+        self.step = Parameter(Tensor(0, dtype=mstype.int32))
 
     def construct(self, *inputs):
-        weights = self.weights
-        loss = self.network(*inputs)
-        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(*inputs, sens)
-        grads = self.grad_reducer(grads)
-
-        if self.use_grad_accumulation:
-            loss = self.grad_accumulation(loss, grads)
+        if self.freeze:
+            if self.train_strategy is None:
+                step = self.step
+                max_index = len(self.freeze_nets)
+            else:
+                step = self.train_strategy[self.step]
+                max_index = len(self.train_strategy)
+            loss = self.freeze_nets[step](*inputs)
+            if self.step + 1 >= max_index:
+                self.step = 0
+            else:
+                self.step += 1
         else:
-            loss = F.depend(loss, self.optimizer(grads))
+            loss = self.network(*inputs)
+            sens = F.fill(loss.dtype, loss.shape, self.sens)
+            grads = self.grad(self.network, self.weights)(*inputs, sens)
+            grads = self.grad_reducer(grads)
+            if self.use_grad_accumulation:
+                loss = self.grad_accumulation(loss, grads)
+            else:
+                loss = F.depend(loss, self.optimizer(grads))
         return loss
diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py
index 163daf31e8..d57717ba6d 100644
--- a/mindspore/train/amp.py
+++ b/mindspore/train/amp.py
@@ -19,6 +19,7 @@ from .. import nn
 from .._checkparam import Validator as validator
 from .._checkparam import Rel
 from ..common import dtype as mstype
+from ..nn import acc
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..ops import functional as F
 from ..parallel._utils import _get_parallel_mode
@@ -139,7 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         scale the loss by `LossScaleManager`. If set, overwrite the level setting.
     """
     validator.check_value_type('network', network, nn.Cell)
-    validator.check_value_type('optimizer', optimizer, nn.Optimizer)
+    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, acc.FreezeOpt))
     validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN)
 
     if level == "auto":
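
Usage sketch (not part of the patch): the snippet below is a minimal illustration of how the
pieces added above are expected to fit together. It splits the trainable parameters into
groups, builds a per-step freeze-index schedule, wraps a regular optimizer in FreezeOpt and
hands it to TrainOneStepCell, which then dispatches each step to one frozen-training
sub-network. `LeNet5` and the numeric values (group count, freeze_p, steps_per_epoch, epochs)
are placeholders for a user network and training setup, not symbols introduced by this patch.

    import mindspore.nn as nn
    from mindspore.nn.acc import (FreezeOpt, INTERVAL_STRATEGY,
                                  split_parameters_groups,
                                  generate_freeze_index_sequence)

    net = LeNet5()                                    # placeholder user-defined network
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_loss = nn.WithLossCell(net, loss_fn)

    # Split the trainable parameters into groups whose gradients can be frozen independently.
    groups = split_parameters_groups(net, 4)

    # Build a per-step schedule of which parameter group is updated at each training step.
    strategy = generate_freeze_index_sequence(parameter_groups_number=len(groups),
                                              freeze_strategy=INTERVAL_STRATEGY,
                                              freeze_p=0.7,
                                              steps_per_epoch=1875,
                                              max_epoch=10)

    # Momentum.__init__ is decorated with @opt_init_args_register in this patch, so FreezeOpt
    # can clone one optimizer per parameter group from the recorded init arguments.
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
    freeze_opt = FreezeOpt(opt, train_parameter_groups=groups, train_strategy=strategy)

    # TrainOneStepCell detects FreezeOpt and builds one frozen-training sub-network per group;
    # each call train_net(data, label) then updates only the group scheduled for that step.
    train_net = nn.TrainOneStepCell(net_with_loss, freeze_opt)
    train_net.set_train()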