[Less BN] Add GC optimizer; remove GC in less_bn; update format; fix bug. (pull/14809/head)
@@ -20,6 +20,7 @@ namespace mindspore {
namespace opt {
namespace irpass {
namespace {
enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
const char kLessBatchNormalizationPassName[] = "less_bn";
constexpr auto kValidResidualStructureIndex = 1;
constexpr auto kBNParametersStartIndex = 2;
@@ -63,6 +64,11 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
  {kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
  ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
  prim::kPrimLoad,      prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
  prim::kPrimApplyFtrl, prim::kPrimSGD,        prim::kPrimApplyRMSProp,  prim::kPrimAdam};
static std::unordered_map<RemoveNodeType, std::unordered_set<size_t>> kRemoveIndex{
  {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};

bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
  if (a == nullptr) {
@@ -73,13 +79,56 @@ bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_
  });
}

bool IsNotRealUseNode(const AnfNodePtr &node) {
  for (const auto &prim : kNeedRemoveNodeSet) {
    if (IsPrimitiveCNode(node, prim)) {
      return true;
    }
  }
  return false;
}

CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> args;
  size_t index = 0;
  const auto &inputs = cnode->inputs();
  auto remove_index = kRemoveIndex[RemoveNodeType::kOptimizerNode];
  if (IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed)) {
    remove_index = kRemoveIndex[RemoveNodeType::kOtherNode];
  }
  (void)std::copy_if(
    inputs.begin(), inputs.end(), std::back_inserter(args),
    [&remove_index, &index](const AnfNodePtr &) { return remove_index.find(index++) != remove_index.end(); });
  (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
  const auto &fg = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto new_make_tuple = fg->NewCNode(args);
  return new_make_tuple;
}
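ConvertRemoveNodeToVirtualNode keeps only a whitelisted subset of the consumer's inputs (index 2 for Load/RefToEmbed, indices 3, 5, 6 for optimizer-update nodes) and rewraps them in a MakeTuple, so the parameter's user becomes a node with no real effect. A rough Python sketch of that index selection (illustrative only; the helper names below are mine, not MindSpore APIs):

# Illustrative sketch of the index selection done above. CNode input 0 is the
# primitive itself; the kept indices come from kRemoveIndex.
REMOVE_INDEX = {"other": {2}, "optimizer": {3, 5, 6}}

def virtualize(node_kind, inputs):
    keep = REMOVE_INDEX["other"] if node_kind in ("Load", "RefToEmbed") else REMOVE_INDEX["optimizer"]
    kept = [inp for i, inp in enumerate(inputs) if i in keep]
    # The C++ code prepends MakeTuple and builds a new CNode from these inputs.
    return ["MakeTuple"] + kept

print(virtualize("Load", ["Load", "param", "U"]))  # ['MakeTuple', 'U']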
bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
  auto param_output = manager->node_users().find(parameter);
  if (param_output == manager->node_users().end()) {
    return true;
  }

  bool need_remove = true;
  auto output_info_list = param_output->second;
  for (const auto &output_info : output_info_list) {
    const auto &node = output_info.first;
    if (IsNotRealUseNode(node)) {
      const auto &cnode = node->cast<CNodePtr>();
      const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
      manager->Replace(cnode, new_cnode);
      continue;
    }
    need_remove = false;
  }
  return need_remove;
}
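IsRealRemoveParameterNode decides removability: a parameter may be dropped only if every one of its users is one of the "not real use" primitives in kNeedRemoveNodeSet, each of which is virtualized first. A hedged plain-Python sketch of that predicate:

# Illustrative predicate: a parameter can be dropped only if every node that
# consumes it is a Load/RefToEmbed/optimizer-update node.
NOT_REAL_USE = {"Load", "RefToEmbed", "ApplyMomentum", "Momentum",
                "ApplyFtrl", "SGD", "ApplyRMSProp", "Adam"}

def is_real_remove_parameter(users):
    return all(kind in NOT_REAL_USE for kind, _ in users)

print(is_real_remove_parameter([("Load", 1), ("ApplyMomentum", 2)]))  # True
print(is_real_remove_parameter([("Conv2D", 1)]))                      # False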
void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
@@ -98,12 +147,20 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
                 [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });

  auto root_parameters = root_graph->parameters();
  size_t origin_param_count = root_parameters.size();
  root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
                                       [&real_remove_parameter_list](const AnfNodePtr &node) {
                                         return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
                                       }),
                        root_parameters.end());
  size_t remove_param_count = origin_param_count - root_parameters.size();
  size_t hyper_param_count = root_graph->hyper_param_count();
  if (remove_param_count > hyper_param_count) {
    MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
    return;
  }
  hyper_param_count = hyper_param_count - remove_param_count;
  root_graph->set_hyper_param_count(hyper_param_count);
  manager->SetParameters(root_graph, root_parameters);
}
}  // namespace
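RemoveBatchNormalizetionNotUseParameters erases the qualifying parameters from the root graph and shrinks hyper_param_count by the same amount; the pass logs an error and bails out if the count would go negative. A small bookkeeping sketch of that invariant (plain Python; it raises where the pass merely logs):

# Illustrative bookkeeping: after unused BN parameters are erased, the graph's
# hyper_param_count must shrink by the same amount and can never go below zero.
def update_hyper_param_count(origin_param_count, remaining_param_count, hyper_param_count):
    removed = origin_param_count - remaining_param_count
    if removed > hyper_param_count:
        raise ValueError("removed parameters exceed the graph's hyper parameters")
    return hyper_param_count - removed

print(update_hyper_param_count(10, 7, 5))  # 2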
@@ -231,6 +231,9 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
  std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
  std::vector<size_t> pos_arg_indexes;

  size_t arguments_count = args_spec_list.size();
  if (hyper_param_count_ > arguments_count) {
    MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
  }
  for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
    MS_EXCEPTION_IF_NULL(args_spec_list[i]);
    if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
@@ -13,8 +13,74 @@
# limitations under the License.
# ============================================================================
"""less batch normalization"""
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.common.initializer import initializer
from ..cell import Cell

__all__ = ["LessBN"]


class CommonHeadLastFN(Cell):
    r"""
    The last full normalization layer.

    This layer implements the operation as:

    .. math::
        \text{inputs} = \text{norm}(\text{inputs}) \\
        \text{kernel} = \text{norm}(\text{kernel}) \\
        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias})

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as the input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            the same as the input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = CommonHeadLastFN(3, 4)
        >>> output = net(input)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(CommonHeadLastFN, self).__init__()
        weight_shape = [out_channels, in_channels]
        self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
        self.x_norm = P.L2Normalize(axis=1)
        self.w_norm = P.L2Normalize(axis=1)
        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
        self.has_bias = has_bias
        if self.has_bias:
            bias_shape = [out_channels]
            self.bias_add = P.BiasAdd()
            self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')

    def construct(self, x):
        x = self.x_norm(x)
        w = self.w_norm(self.weight)
        x = self.fc(x, w)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        x = self.multiplier * x
        return x
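For reference, the forward pass of CommonHeadLastFN is an L2-normalized matmul scaled by a learnable multiplier. A rough NumPy sketch (illustrative only; it ignores the small epsilon used by P.L2Normalize):

# Rough NumPy sketch of CommonHeadLastFN's forward computation.
import numpy as np

def common_head_last_fn(x, weight, bias, multiplier):
    # Row-wise L2 normalization of the inputs and of each output-channel weight vector.
    x_n = x / np.linalg.norm(x, axis=1, keepdims=True)
    w_n = weight / np.linalg.norm(weight, axis=1, keepdims=True)
    return multiplier * (x_n @ w_n.T + bias)

x = np.array([[180, 234, 154], [244, 48, 247]], dtype=np.float32)
w = np.random.randn(4, 3).astype(np.float32)   # [out_channels, in_channels]
b = np.zeros(4, dtype=np.float32)
print(common_head_last_fn(x, w, b, multiplier=1.0).shape)  # (2, 4)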
class LessBN(Cell):
    """
    Reduce the number of BN layers automatically to improve the network performance
@@ -31,6 +97,44 @@ class LessBN(Cell):
        super(LessBN, self).__init__()
        self.network = network
        self.network.set_acc("less_bn")
        self.network.update_cell_prefix()
        self._convert_to_less_bn_net(self.network)
        self.network.add_flags(defer_inline=True)

    def _convert_dense(self, subcell):
        """
        Convert a dense cell to an FN cell.
        """
        prefix = subcell.param_prefix
        new_subcell = CommonHeadLastFN(subcell.in_channels,
                                       subcell.out_channels,
                                       subcell.weight,
                                       subcell.bias,
                                       subcell.has_bias)
        new_subcell.update_parameters_name(prefix + '.')
        return new_subcell

    def _convert_to_less_bn_net(self, net):
        """
        Convert the network to a less_bn network.
        """
        cells = net.name_cells()
        dense_name = []
        dense_list = []
        for name in cells:
            subcell = cells[name]
            if subcell == net:
                continue
            elif isinstance(subcell, nn.Dense):
                dense_name.append(name)
                dense_list.append(subcell)
            else:
                self._convert_to_less_bn_net(subcell)

        if dense_list:
            new_subcell = self._convert_dense(dense_list[-1])
            net.insert_child_to_cell(dense_name[-1], new_subcell)

    def construct(self, *inputs):
        return self.network(*inputs)
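A minimal usage sketch of LessBN as defined above, assuming this MindSpore version (where Cell.set_acc and update_cell_prefix exist); TinyNet is a made-up example network, not part of the PR:

# Wrap an existing network with LessBN so its last Dense layer is swapped for
# CommonHeadLastFN and the "less_bn" acceleration pass is enabled.
from mindspore import nn

class TinyNet(nn.Cell):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(32, 10)

    def construct(self, x):
        return self.fc(self.flatten(x))

net = LessBN(TinyNet())   # LessBN defined above in this file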
@@ -1048,7 +1048,7 @@ class Cell(Cell_):
        Some acceleration algorithms may affect the accuracy of the network, so please choose carefully.

        Args:
            acc_type (str): accelerate algorithm.

        Raises:
            ValueError: If acc_type is not in the algorithm library.
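A hedged sketch of the documented behavior of set_acc: a name outside the algorithm library raises ValueError, while "less_bn" is accepted (as used by LessBN above):

# Illustrative only; relies solely on the behavior documented in the docstring.
from mindspore import nn

net = nn.Dense(16, 10)
try:
    net.set_acc("not_a_registered_algorithm")
except ValueError as e:
    print("rejected:", e)
net.set_acc("less_bn")   # registered algorithm, marks the cell for the less_bn pass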
@@ -153,8 +153,8 @@ class Adagrad(Optimizer):
        params = self.parameters
        accum = self.accum
        grads = self.decay_weight(grads)
        grads = self.scale_grad(grads)
        grads = self.gradients_centralization(grads)
        grads = self.scale_grad(grads)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
@@ -338,9 +338,9 @@ class Adam(Optimizer):
        moment1 = self.moment1
        moment2 = self.moment2
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        gradients = self.gradients_centralization(gradients)
        lr = self.get_lr()
        beta1_power = self.beta1_power * self.beta1
@@ -218,9 +218,9 @@ class FTRL(Optimizer):
        moments = self.moments
        linear = self.linear
        grads = self.decay_weight(grads)
        grads = self.gradients_centralization(grads)
        grads = self.scale_grad(grads)
        grads = self._grad_sparse_indices_deduplicate(grads)
        grads = self.gradients_centralization(grads)
        lr = self.get_lr()
        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
@@ -255,9 +255,9 @@ class LazyAdam(Optimizer):
    def construct(self, gradients):
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        gradients = self.gradients_centralization(gradients)
        lr = self.get_lr()
        self.beta1_power = self.beta1_power * self.beta1
@@ -159,8 +159,8 @@ class Momentum(Optimizer):
        params = self.params
        moments = self.moments
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
@@ -439,10 +439,6 @@ class Optimizer(Cell):
                self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
                for param in group_param['params']:
                    validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                    if "conv" not in param.name and self.grad_centralization is True:
                        raise ValueError("Grad centralization can be perform only on the conv layer. If the parameter"
                                         "is not a convolution layer, this parameter cannot be set to True.")
                grad_centralization_ = self.grad_centralization
            else:
                grad_centralization_ = grad_centralization
@@ -630,9 +626,16 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
    """Get grad with grad_centralization."""
    if if_apply:
        indices = gradient.indices
        shape = gradient.dense_shape
        grad_shape = F.shape(gradient)
        axis = []
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            if grad_shape[1] % 16 != 0:
                return gradient
            values = op_gc(gradient.values, axis)
            return RowTensor(indices, values, shape)
    return gradient
@@ -640,7 +643,14 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
def _tensor_apply_grad_centralization(if_apply, gradient):
    """Get grad with grad_centralization."""
    if if_apply:
        axis = []
        grad_shape = F.shape(gradient)
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            if grad_shape[1] % 16 != 0:
                return gradient
            return op_gc(gradient, axis)
    return gradient
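The two helpers above apply gradient centralization over every axis except axis 0, and skip it when the second dimension is not 16-aligned. An illustrative NumPy sketch of that computation (my own code; the 16-alignment guard is assumed to mirror an Ascend fractal-format constraint):

# Subtract the per-output-channel mean taken over all axes except axis 0.
import numpy as np

def grad_centralization(grad):
    if grad.ndim < 2 or grad.shape[1] % 16 != 0:
        return grad
    axis = tuple(range(1, grad.ndim))
    return grad - grad.mean(axis=axis, keepdims=True)

g = np.random.randn(8, 16, 3, 3).astype(np.float32)   # e.g. a conv weight gradient
print(np.abs(grad_centralization(g).mean(axis=(1, 2, 3))).max() < 1e-6)  # True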
@@ -165,9 +165,9 @@ class ProximalAdagrad(Optimizer):
        params = self.parameters
        accum = self.accum
        grads = self.decay_weight(grads)
        grads = self.gradients_centralization(grads)
        grads = self.scale_grad(grads)
        grads = self._grad_sparse_indices_deduplicate(grads)
        grads = self.gradients_centralization(grads)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
@@ -200,8 +200,8 @@ class RMSProp(Optimizer):
    def construct(self, gradients):
        params = self.parameters
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        if self.centered:
            if self.is_group_lr:
@@ -173,8 +173,8 @@ class SGD(Optimizer):
        params = self.parameters
        accum = self.accum
        stat = self.stat
        gradients = self.scale_grad(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
@@ -26,9 +26,10 @@ centralization_op_info = TBERegOp("Centralization") \
    .attr("axis", "required", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("reduce") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .get_op_info()
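For context, a reg-info object like centralization_op_info is normally bound through the op_info_register decorator. A sketch of that surrounding boilerplate (assumed, not part of this diff):

# Assumed registration boilerplate around the block above.
from mindspore.ops.op_info_register import op_info_register

@op_info_register(centralization_op_info)
def _centralization_tbe():
    """Centralization TBE register"""
    return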