
support grad freeze

pull/12898/head
buxue linqingke 4 years ago
parent commit 2bda8c21e9
10 changed files with 240 additions and 35 deletions
  1. mindspore/ccsrc/backend/session/ascend_session.cc   +2   -4
  2. mindspore/ccsrc/debug/data_dump/cpu_e2e_dump.cc      +0   -1
  3. mindspore/core/abstract/prim_statement.cc            +1   -1
  4. mindspore/nn/acc/__init__.py                         +4   -2
  5. mindspore/nn/acc/grad_freeze.py                      +139 -0
  6. mindspore/nn/acc/less_batch_normalization.py         +5   -5
  7. mindspore/nn/optim/momentum.py                       +2   -0
  8. mindspore/nn/optim/optimizer.py                      +13  -1
  9. mindspore/nn/wrap/cell_wrapper.py                    +72  -20
  10. mindspore/train/amp.py                              +2   -1

mindspore/ccsrc/backend/session/ascend_session.cc  (+2 -4)

@@ -478,10 +478,8 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
   auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
   // Update Graph Dynamic Shape Attr
   UpdateAllGraphDynamicShapeAttr(all_graphs);
-  for (const auto &graph : all_graphs) {
-    UnifyMindIR(graph);
-  }
-  BackendOptimization(all_graphs);
+  UnifyMindIR(root_graph);
+  opt::BackendCommonOptimization(root_graph);
   // empty graph dont entry to backend
   if (root_graph->execution_order().empty()) {
     MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";


mindspore/ccsrc/debug/data_dump/cpu_e2e_dump.cc  (+0 -1)

@@ -19,7 +19,6 @@
 #include "backend/session/anf_runtime_algorithm.h"

 namespace mindspore {
-
 void CPUE2eDump::DumpCNodeData(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto &dump_json_parser = DumpJsonParser::GetInstance();


mindspore/core/abstract/prim_statement.cc  (+1 -1)

@@ -92,7 +92,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
   for (size_t i = 0; i < branches.size(); i++) {
     MS_EXCEPTION_IF_NULL(branches[i]);
-    if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
+    if (!branches[i]->isa<FuncGraphAbstractClosure>() && !branches[i]->isa<PartialAbstractClosure>()) {
       MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
                                << branches[i]->ToString() << " as the " << i << "th element.";
     }


mindspore/nn/acc/__init__.py  (+4 -2)

@@ -17,6 +17,8 @@ Accelerating.

 Provide auto accelerating for network, such as Less BN.
 """
-from .less_batch_normalization import LessBN
+from .less_batch_normalization import *
+from .grad_freeze import *

-__all__ = ['LessBN']
+__all__ = ['LessBN', 'FreezeOpt', 'CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
+           'split_parameters_groups', 'generate_freeze_index_sequence']

mindspore/nn/acc/grad_freeze.py  (+139 -0, new file)

@@ -0,0 +1,139 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad freeze"""

import numpy as np

from mindspore.nn.cell import Cell
from mindspore.nn.optim import Optimizer
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype

__all__ = ['CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence',
           'FreezeOpt']

CONTINUOUS_STRATEGY = 0
INTERVAL_STRATEGY = 1


def split_parameters_groups(net, freeze_para_groups_number):
    """Split parameter groups for gradients freezing training."""
    grouped_params = []
    tmp = []
    for para in net.trainable_params():
        name = para.name
        # ensure 'bn' after 'conv' is not split
        if 'bn' in name or 'bias' in name:
            tmp.append(para)
        elif len(tmp) >= 3:
            grouped_params.append(tmp)
            tmp = [para]
        else:
            tmp.append(para)
    if tmp:
        grouped_params.append(tmp)
    stride = len(grouped_params) // freeze_para_groups_number
    freeze_grouped_params = [sum(grouped_params[i * stride:], []) for i in range(freeze_para_groups_number)]
    return freeze_grouped_params


def generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, steps_per_epoch, max_epoch):
    """Generate index sequence for gradient freezing training."""
    total_step = steps_per_epoch * max_epoch * 1.01
    # local continuous freezing training strategy, as '00001234'
    if freeze_strategy == CONTINUOUS_STRATEGY:
        zero_cnt = int(freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
        sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
        freeze_idxes = []
        while len(freeze_idxes) < total_step:
            freeze_idxes += sub_idx
        return freeze_idxes
    # interval freezing training strategy, as '01020304'
    if freeze_strategy == INTERVAL_STRATEGY:
        index_all = list(range(1, parameter_groups_number))
        prob = [x / sum(index_all) for x in index_all]
        freeze_idxes = [0]
        zero_cnt = 1
        freeze_cnt = 0
        while len(freeze_idxes) < total_step:
            freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
            if freeze_p_cur < 1 - freeze_p:
                freeze_idxes.append(int(np.random.choice(index_all[::-1], p=prob)))
                freeze_cnt += 1
            else:
                freeze_idxes.append(0)
                zero_cnt += 1
        return freeze_idxes
    raise ValueError(f"Unsupported freezing training strategy '{freeze_strategy}'")


class FreezeOpt(Cell):
    """
    Optimizer that supports gradients freezing training.

    Args:
        opt (Optimizer): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
        train_parameter_groups (Union[Tuple, List]): Groups of parameters for gradients freezing training.
        train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
        super(FreezeOpt, self).__init__()
        if not isinstance(opt, Optimizer):
            raise TypeError(f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
        if train_strategy is not None and train_parameter_groups is None:
            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
                             "must also be specified")
        opt_class = type(opt)
        opt_init_args = opt.init_args
        self.opts = []

        if train_parameter_groups is None:
            groups_num = 10
            step = 6
            parameters = opt.parameters
            para_groups = (parameters[(i * step):] for i in range(groups_num))
            self.opts = [opt_class(params=params, **opt_init_args) for params in para_groups]
        else:
            if not isinstance(train_parameter_groups, (tuple, list)):
                raise TypeError("The specified 'train_parameter_groups' should be tuple or list")
            for params in train_parameter_groups:
                if not isinstance(params, (tuple, list)):
                    raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
                                    "to store the Parameter")
                for para in params:
                    if not isinstance(para, Parameter):
                        raise TypeError("The element of each group should be the Parameter")

                # generate one-to-one opt corresponding to the parameter group
                self.opts.append(opt_class(params=params, **opt_init_args))

        if isinstance(train_strategy, (tuple, list)):
            for ele in train_strategy:
                if not isinstance(ele, int):
                    raise ValueError("The element in train_strategy should be int number")
            self.train_strategy = Tensor(train_strategy, mstype.int32)
        elif isinstance(train_strategy, Tensor):
            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
                raise ValueError("When train_strategy is a Tensor, the dimension should be 1 and "
                                 "the dtype should be int32")
            self.train_strategy = train_strategy
        elif train_strategy is None:
            self.train_strategy = None
        else:
            raise TypeError("The specified 'train_strategy' should be None, tuple, list or Tensor")

mindspore/nn/acc/less_batch_normalization.py  (+5 -5)

@@ -14,12 +14,12 @@
 # ============================================================================
 """less batch normalization"""
 import numpy as np
-from mindspore import nn
+from mindspore.nn.cell import Cell
+from mindspore.nn.layer import Dense
 from mindspore.ops import operations as P
-from mindspore import Tensor, Parameter
-from mindspore import dtype as mstype
+from mindspore.common import Tensor, Parameter
+from mindspore.common import dtype as mstype
 from mindspore.common.initializer import initializer
-from ..cell import Cell


 __all__ = ["LessBN"]
@@ -126,7 +126,7 @@ class LessBN(Cell):
             subcell = cells[name]
             if subcell == net:
                 continue
-            elif isinstance(subcell, (nn.Dense)):
+            elif isinstance(subcell, (Dense)):
                 dense_name.append(name)
                 dense_list.append(subcell)
             else:


mindspore/nn/optim/momentum.py  (+2 -0)

@@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator
 from .optimizer import Optimizer
+from .optimizer import opt_init_args_register

 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")

@@ -147,6 +148,7 @@ class Momentum(Optimizer):
     >>> loss = nn.SoftmaxCrossEntropyWithLogits()
     >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     """
+    @opt_init_args_register
     def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
         super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
         Validator.check_value_type("momentum", momentum, [float], self.cls_name)


mindspore/nn/optim/optimizer.py  (+13 -1)

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """optimizer"""
+import inspect
 from typing import Iterable

 import numpy as np
@@ -33,7 +34,18 @@ from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule

-__all__ = ['Optimizer']
+__all__ = ['Optimizer', 'opt_init_args_register']
+
+
+def opt_init_args_register(fn):
+    def deco(self, *args, **kwargs):
+        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
+        bound_args.apply_defaults()
+        arguments = bound_args.arguments
+        arguments.pop('self')
+        arguments.pop('params')
+        setattr(self, 'init_args', arguments)
+        fn(self, *args, **kwargs)
+    return deco


 class Optimizer(Cell):
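
A rough sketch of what the decorator records (illustrative; `net` and `some_group` are hypothetical, and the printed mapping is an example rather than captured output):

    import mindspore.nn as nn

    opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    print(opt.init_args)
    # e.g. {'learning_rate': 0.1, 'momentum': 0.9, 'weight_decay': 0.0,
    #       'loss_scale': 1.0, 'use_nesterov': False}
    group_opt = type(opt)(params=some_group, **opt.init_args)  # how FreezeOpt rebuilds per-group optimizers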


mindspore/nn/wrap/cell_wrapper.py  (+72 -20)

@@ -27,6 +27,7 @@ from ...ops import functional as F
 from ...ops import operations as P
 from ...ops.operations.comm_ops import _VirtualDataset
 from ..cell import Cell
+from ...nn import acc
 from .grad_reducer import DistributedGradReducer

 _get_datatype = C.MultitypeFuncGraph("_get_datatype")
@@ -292,6 +293,32 @@ class ForwardValueAndGrad(Cell):
         return loss, grads


+class _TrainFreezeCell(Cell):
+    """Gradient freezing training network."""
+    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, max_accumulation_step, optimizer):
+        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
+        self.net = net
+        self.grad = grad
+        self.grad_reducer = grad_reducer
+        self.opt = optimizer
+        self.parameters = optimizer.parameters
+        self.sens = sens
+        self.use_grad_accumulation = use_grad_accumulation
+        if use_grad_accumulation:
+            self.grad_accumulation = GradientAccumulation(max_accumulation_step, optimizer)
+
+    def construct(self, *inputs):
+        loss = self.net(*inputs)
+        sens = F.fill(loss.dtype, loss.shape, self.sens)
+        grads = self.grad(self.net, self.parameters)(*inputs, sens)
+        grads = self.grad_reducer(grads)
+        if self.use_grad_accumulation:
+            loss = self.grad_accumulation(loss, grads)
+        else:
+            loss = F.depend(loss, self.opt(grads))
+        return loss
+
+
 class TrainOneStepCell(Cell):
     r"""
     Network training package class.
@@ -302,7 +329,7 @@ class TrainOneStepCell(Cell):

     Args:
         network (Cell): The training network. The network only supports single output.
-        optimizer (Cell): Optimizer for updating the weights.
+        optimizer (Union[Cell]): Optimizer for updating the weights.
         sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

     Inputs:
@@ -348,41 +375,66 @@
         super(TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
-        self.weights = optimizer.parameters
+        self.freeze = isinstance(optimizer, acc.FreezeOpt)
         self.optimizer = optimizer
+        if not self.freeze:
+            self.weights = self.optimizer.parameters
+        self.train_strategy = getattr(self.optimizer, 'train_strategy', None)
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
         self.grad_reducer = F.identity
         self.parallel_mode = _get_parallel_mode()
-        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
-            self.reducer_flag = True
-        if self.reducer_flag:
-            self.mean = _get_gradients_mean()
-            self.degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
-        self.use_grad_accumulation = False
-        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
-            self.use_grad_accumulation = True
+        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
+        self.use_grad_accumulation = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE)
         if self.use_grad_accumulation:
             self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
             if self.max_accumulation_step <= 1:
                 self.max_accumulation_step = 1
                 self.use_grad_accumulation = False
+        self.grad_accumulation = None
         if self.use_grad_accumulation:
             self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
+        if self.reducer_flag:
+            self.mean = _get_gradients_mean()
+            self.degree = _get_device_num()
+            if self.freeze:
+                self.grad_reducers = (DistributedGradReducer(opt.parameters, self.mean, self.degree)
+                                      for opt in self.optimizer.opts)
+                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, reducer,
+                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
+                                         for reducer, opt in zip(self.grad_reducers, self.optimizer))
+            else:
+                self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, self.mean, self.degree)
+        else:
+            if self.freeze:
+                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, self.grad_reducer,
+                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
+                                         for opt in self.optimizer.opts)
+        self.step = Parameter(Tensor(0, dtype=mstype.int32))

     def construct(self, *inputs):
-        weights = self.weights
-        loss = self.network(*inputs)
-        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(*inputs, sens)
-        grads = self.grad_reducer(grads)
-
-        if self.use_grad_accumulation:
-            loss = self.grad_accumulation(loss, grads)
+        if self.freeze:
+            if self.train_strategy is None:
+                step = self.step
+                max_index = len(self.freeze_nets)
+            else:
+                step = self.train_strategy[self.step]
+                max_index = len(self.train_strategy)
+            loss = self.freeze_nets[step](*inputs)
+            if self.step + 1 >= max_index:
+                self.step = 0
+            else:
+                self.step += 1
         else:
-            loss = F.depend(loss, self.optimizer(grads))
+            loss = self.network(*inputs)
+            sens = F.fill(loss.dtype, loss.shape, self.sens)
+            grads = self.grad(self.network, self.weights)(*inputs, sens)
+            grads = self.grad_reducer(grads)
+            if self.use_grad_accumulation:
+                loss = self.grad_accumulation(loss, grads)
+            else:
+                loss = F.depend(loss, self.optimizer(grads))
         return loss
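
Sketch of driving the new freeze path (illustrative; `net`, `freeze_opt` and `dataset` are the assumed objects from the earlier FreezeOpt sketch). With a FreezeOpt, TrainOneStepCell builds one _TrainFreezeCell per group optimizer and each call runs the cell selected by train_strategy, or by a cycling internal step counter when no strategy is given:

    from mindspore import nn

    net_with_loss = nn.WithLossCell(net, nn.SoftmaxCrossEntropyWithLogits())
    train_net = nn.TrainOneStepCell(net_with_loss, freeze_opt)
    train_net.set_train()
    for data, label in dataset:  # hypothetical iterable of (data, label) Tensors
        loss = train_net(data, label)  # each step updates only the selected parameter group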






mindspore/train/amp.py  (+2 -1)

@@ -19,6 +19,7 @@ from .. import nn
 from .._checkparam import Validator as validator
 from .._checkparam import Rel
 from ..common import dtype as mstype
+from ..nn import acc
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..ops import functional as F
 from ..parallel._utils import _get_parallel_mode
@@ -139,7 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         scale the loss by `LossScaleManager`. If set, overwrite the level setting.
     """
     validator.check_value_type('network', network, nn.Cell)
-    validator.check_value_type('optimizer', optimizer, nn.Optimizer)
+    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, acc.FreezeOpt))
     validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN)

     if level == "auto":
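
With the relaxed check, a FreezeOpt should also pass through the amp entry point (sketch, reusing the assumed `net` and `freeze_opt` from the earlier sketches):

    from mindspore import nn
    from mindspore.train.amp import build_train_network

    train_net = build_train_network(net, freeze_opt, loss_fn=nn.SoftmaxCrossEntropyWithLogits(), level='O0')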

