@@ -478,10 +478,8 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
  // Update Graph Dynamic Shape Attr
  UpdateAllGraphDynamicShapeAttr(all_graphs);
  for (const auto &graph : all_graphs) {
    UnifyMindIR(graph);
  }
  BackendOptimization(all_graphs);
  UnifyMindIR(root_graph);
  opt::BackendCommonOptimization(root_graph);
  // empty graph dont entry to backend
  if (root_graph->execution_order().empty()) {
    MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
@@ -19,7 +19,6 @@
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
void CPUE2eDump::DumpCNodeData(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto &dump_json_parser = DumpJsonParser::GetInstance();
@@ -92,7 +92,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
  for (size_t i = 0; i < branches.size(); i++) {
    MS_EXCEPTION_IF_NULL(branches[i]);
    if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
    if (!branches[i]->isa<FuncGraphAbstractClosure>() && !branches[i]->isa<PartialAbstractClosure>()) {
      MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
                               << branches[i]->ToString() << " as the " << i << "th element.";
    }
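Review note: the relaxed check above also accepts `PartialAbstractClosure` branches, i.e. branch functions that carry bound arguments instead of being plain function graphs. As a rough, hypothetical illustration of the kind of user code that reaches `InferImplSwitchLayer` (the usual CellList-indexed-by-Tensor pattern; this snippet is an assumption for context, not code from this PR):

```python
import mindspore.nn as nn

class BranchNet(nn.Cell):
    """Picks one sub-cell with a runtime Tensor index; in graph mode this lowers to switch_layer."""
    def __init__(self):
        super(BranchNet, self).__init__()
        self.branches = nn.CellList([nn.ReLU(), nn.Tanh()])

    def construct(self, x, index):
        # When a selected branch closes over extra inputs or parameters, its abstract
        # value can be a partial closure, which the widened check now tolerates.
        return self.branches[index](x)
```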
@@ -17,6 +17,8 @@ Accelerating.
Provide auto accelerating for network, such as Less BN.
"""
from .less_batch_normalization import LessBN
from .less_batch_normalization import *
from .grad_freeze import *

__all__ = ['LessBN']
__all__ = ['LessBN', 'FreezeOpt', 'CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence']
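With the widened `__all__`, the gradient-freeze helpers become importable next to `LessBN`. A minimal sketch of the intended surface, assuming the package path `mindspore.nn.acc` used elsewhere in this diff:

```python
from mindspore.nn.acc import LessBN, FreezeOpt, CONTINUOUS_STRATEGY, INTERVAL_STRATEGY
from mindspore.nn.acc import split_parameters_groups, generate_freeze_index_sequence
```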
@@ -0,0 +1,139 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad freeze"""

import numpy as np

from mindspore.nn.cell import Cell
from mindspore.nn.optim import Optimizer
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype

__all__ = ['CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence',
           'FreezeOpt']

CONTINUOUS_STRATEGY = 0
INTERVAL_STRATEGY = 1


def split_parameters_groups(net, freeze_para_groups_number):
    """Split parameter groups for gradients freezing training."""
    grouped_params = []
    tmp = []
    for para in net.trainable_params():
        name = para.name
        # ensure 'bn' after 'conv' is not split
        if 'bn' in name or 'bias' in name:
            tmp.append(para)
        elif len(tmp) >= 3:
            grouped_params.append(tmp)
            tmp = [para]
        else:
            tmp.append(para)
    if tmp:
        grouped_params.append(tmp)
    stride = len(grouped_params) // freeze_para_groups_number
    freeze_grouped_params = [sum(grouped_params[i * stride:], []) for i in range(freeze_para_groups_number)]
    return freeze_grouped_params


def generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, steps_per_epoch, max_epoch):
    """Generate index sequence for gradient freezing training."""
    total_step = steps_per_epoch * max_epoch * 1.01
    # local continuous freezing training strategy, as '00001234'
    if freeze_strategy == CONTINUOUS_STRATEGY:
        zero_cnt = int(freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
        sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
        freeze_idxes = []
        while len(freeze_idxes) < total_step:
            freeze_idxes += sub_idx
        return freeze_idxes
    # interval freezing training strategy, as '01020304'
    if freeze_strategy == INTERVAL_STRATEGY:
        index_all = list(range(1, parameter_groups_number))
        prob = [x / sum(index_all) for x in index_all]
        freeze_idxes = [0]
        zero_cnt = 1
        freeze_cnt = 0
        while len(freeze_idxes) < total_step:
            freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
            if freeze_p_cur < 1 - freeze_p:
                freeze_idxes.append(int(np.random.choice(index_all[::-1], p=prob)))
                freeze_cnt += 1
            else:
                freeze_idxes.append(0)
                zero_cnt += 1
        return freeze_idxes
    raise ValueError(f"Unsupported freezing training strategy '{freeze_strategy}'")


class FreezeOpt(Cell):
    """
    Optimizer that supports gradients freezing training.

    Args:
        opt (Optimizer): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
        train_parameter_groups (Union[Tuple, List]): Groups of parameters for gradients freezing training.
        train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.

    Supported Platforms:
        ``Ascend``
    """

    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
        super(FreezeOpt, self).__init__()
        if not isinstance(opt, Optimizer):
            raise TypeError(f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
        if train_strategy is not None and train_parameter_groups is None:
            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
                             "must also be specified")
        opt_class = type(opt)
        opt_init_args = opt.init_args
        self.opts = []

        if train_parameter_groups is None:
            groups_num = 10
            step = 6
            parameters = opt.parameters
            para_groups = (parameters[(i * step):] for i in range(groups_num))
            self.opts = [opt_class(params=params, **opt_init_args) for params in para_groups]
        else:
            if not isinstance(train_parameter_groups, (tuple, list)):
                raise TypeError("The specified 'train_parameter_groups' should be tuple or list")
            for params in train_parameter_groups:
                if not isinstance(params, (tuple, list)):
                    raise TypeError("Each element of 'train_parameter_groups' should be a tuple or list "
                                    "of Parameters")
                for para in params:
                    if not isinstance(para, Parameter):
                        raise TypeError("Each element of the group should be a Parameter")
                # generate one-to-one opt corresponding to the parameter group
                self.opts.append(opt_class(params=params, **opt_init_args))

        if isinstance(train_strategy, (tuple, list)):
            for ele in train_strategy:
                if not isinstance(ele, int):
                    raise ValueError("Each element of 'train_strategy' should be an int")
            self.train_strategy = Tensor(train_strategy, mstype.int32)
        elif isinstance(train_strategy, Tensor):
            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
                raise ValueError("When train_strategy is a Tensor, the dimension should be 1 and "
                                 "the dtype should be int32")
            self.train_strategy = train_strategy
        elif train_strategy is None:
            self.train_strategy = None
        else:
            raise TypeError("The specified 'train_strategy' should be None, tuple, list or Tensor")
@@ -14,12 +14,12 @@
# ============================================================================
"""less batch normalization"""
import numpy as np
from mindspore import nn
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.ops import operations as P
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from ..cell import Cell

__all__ = ["LessBN"]

@@ -126,7 +126,7 @@ class LessBN(Cell):
            subcell = cells[name]
            if subcell == net:
                continue
            elif isinstance(subcell, (nn.Dense)):
            elif isinstance(subcell, (Dense)):
                dense_name.append(name)
                dense_list.append(subcell)
            else:
@@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register

_momentum_opt = C.MultitypeFuncGraph("momentum_opt")

@@ -147,6 +148,7 @@ class Momentum(Optimizer):
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    @opt_init_args_register
    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
        Validator.check_value_type("momentum", momentum, [float], self.cls_name)
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""optimizer"""
import inspect
from typing import Iterable

import numpy as np

@@ -33,7 +34,18 @@ from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

__all__ = ['Optimizer']
__all__ = ['Optimizer', 'opt_init_args_register']


def opt_init_args_register(fn):
    def deco(self, *args, **kwargs):
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        arguments.pop('self')
        arguments.pop('params')
        setattr(self, 'init_args', arguments)
        fn(self, *args, **kwargs)
    return deco


class Optimizer(Cell):
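The decorator exists so that `FreezeOpt` can re-instantiate the wrapped optimizer once per parameter group: it captures every constructor argument except `self` and `params` on the instance as `init_args`. A small sketch of the mechanics, assuming the module path `mindspore.nn.optim.optimizer` implied by this diff (the `FakeOptimizer` class is purely illustrative):

```python
from mindspore.nn.optim.optimizer import opt_init_args_register

class FakeOptimizer:
    @opt_init_args_register
    def __init__(self, params, learning_rate, momentum=0.9):
        self.params = params

opt = FakeOptimizer(params=['w'], learning_rate=0.01)
print(opt.init_args)  # learning_rate=0.01 and momentum=0.9; 'self' and 'params' are dropped
# FreezeOpt then rebuilds one optimizer per group via: type(opt)(params=group, **opt.init_args)
```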
@@ -27,6 +27,7 @@ from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from ...nn import acc
from .grad_reducer import DistributedGradReducer

_get_datatype = C.MultitypeFuncGraph("_get_datatype")

@@ -292,6 +293,32 @@ class ForwardValueAndGrad(Cell):
        return loss, grads


class _TrainFreezeCell(Cell):
    """Gradient freezing training network."""
    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, max_accumulation_step, optimizer):
        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
        self.net = net
        self.grad = grad
        self.grad_reducer = grad_reducer
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.sens = sens
        self.use_grad_accumulation = use_grad_accumulation
        if use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(max_accumulation_step, optimizer)

    def construct(self, *inputs):
        loss = self.net(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.net, self.parameters)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        else:
            loss = F.depend(loss, self.opt(grads))
        return loss


class TrainOneStepCell(Cell):
    r"""
    Network training package class.
@@ -302,7 +329,7 @@ class TrainOneStepCell(Cell):
    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Inputs:
@@ -348,41 +375,66 @@ class TrainOneStepCell(Cell):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.freeze = isinstance(optimizer, acc.FreezeOpt)
        self.optimizer = optimizer
        if not self.freeze:
            self.weights = self.optimizer.parameters
        self.train_strategy = getattr(self.optimizer, 'train_strategy', None)
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
        self.use_grad_accumulation = False
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
            self.use_grad_accumulation = True
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
        self.use_grad_accumulation = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE)
        if self.use_grad_accumulation:
            self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
            if self.max_accumulation_step <= 1:
                self.max_accumulation_step = 1
                self.use_grad_accumulation = False
        self.grad_accumulation = None
        if self.use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            if self.freeze:
                self.grad_reducers = (DistributedGradReducer(opt.parameters, self.mean, self.degree)
                                      for opt in self.optimizer.opts)
                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, reducer,
                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
                                         for reducer, opt in zip(self.grad_reducers, self.optimizer.opts))
            else:
                self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, self.mean, self.degree)
        else:
            if self.freeze:
                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, self.grad_reducer,
                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
                                         for opt in self.optimizer.opts)
        self.step = Parameter(Tensor(0, dtype=mstype.int32))

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        if self.freeze:
            if self.train_strategy is None:
                step = self.step
                max_index = len(self.freeze_nets)
            else:
                step = self.train_strategy[self.step]
                max_index = len(self.train_strategy)
            loss = self.freeze_nets[step](*inputs)
            if self.step + 1 >= max_index:
                self.step = 0
            else:
                self.step += 1
        else:
            loss = F.depend(loss, self.optimizer(grads))
            loss = self.network(*inputs)
            sens = F.fill(loss.dtype, loss.shape, self.sens)
            grads = self.grad(self.network, self.weights)(*inputs, sens)
            grads = self.grad_reducer(grads)
            if self.use_grad_accumulation:
                loss = self.grad_accumulation(loss, grads)
            else:
                loss = F.depend(loss, self.optimizer(grads))
        return loss
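The freeze branch of `construct` is just an index cycle: with no `train_strategy` it round-robins over the per-group sub-networks, otherwise it replays the precomputed index sequence and wraps around. A plain-Python sketch of that selection logic (a simulation for clarity, not MindSpore graph code):

```python
def simulate_freeze_schedule(num_freeze_nets, train_strategy=None, steps=8):
    """Replays which frozen sub-network TrainOneStepCell would pick at each step."""
    picked, step = [], 0
    for _ in range(steps):
        if train_strategy is None:
            index, max_index = step, num_freeze_nets
        else:
            index, max_index = train_strategy[step], len(train_strategy)
        picked.append(index)
        step = 0 if step + 1 >= max_index else step + 1
    return picked

print(simulate_freeze_schedule(3))                # [0, 1, 2, 0, 1, 2, 0, 1]
print(simulate_freeze_schedule(3, [0, 1, 0, 2]))  # [0, 1, 0, 2, 0, 1, 0, 2]
```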
@@ -19,6 +19,7 @@ from .. import nn
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn import acc
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode

@@ -139,7 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
        scale the loss by `LossScaleManager`. If set, overwrite the level setting.
    """
    validator.check_value_type('network', network, nn.Cell)
    validator.check_value_type('optimizer', optimizer, nn.Optimizer)
    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, acc.FreezeOpt))
    validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN)

    if level == "auto":