[Less BN] Add GC optimizer. Update format. Fix bug. (branch: pull/14809/head)
Commits: "add gc", "remove gc in less_bn"
@@ -20,6 +20,7 @@ namespace mindspore {
 namespace opt {
 namespace irpass {
 namespace {
+enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
 const char kLessBatchNormalizationPassName[] = "less_bn";
 constexpr auto kValidResidualStructureIndex = 1;
 constexpr auto kBNParametersStartIndex = 2;
@@ -63,6 +64,11 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
   {kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
 static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
   ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
+const std::set<PrimitivePtr> kNeedRemoveNodeSet{
+  prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
+  prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
+static std::unordered_map<RemoveNodeType, std::unordered_set<size_t>> kRemoveIndex{
+  {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};
 
 bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
   if (a == nullptr) {
@@ -73,13 +79,56 @@ bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_
   });
 }
+
+bool IsNotRealUseNode(const AnfNodePtr &node) {
+  for (const auto &prim : kNeedRemoveNodeSet) {
+    if (IsPrimitiveCNode(node, prim)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<AnfNodePtr> args;
+  size_t index = 0;
+  const auto &inputs = cnode->inputs();
+  auto remove_index = kRemoveIndex[RemoveNodeType::kOptimizerNode];
+  if (IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed)) {
+    remove_index = kRemoveIndex[RemoveNodeType::kOtherNode];
+  }
+  (void)std::copy_if(
+    inputs.begin(), inputs.end(), std::back_inserter(args),
+    [&remove_index, &index](const AnfNodePtr &) { return remove_index.find(index++) != remove_index.end(); });
+  (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
+  const auto &fg = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(fg);
+  auto new_make_tuple = fg->NewCNode(args);
+  return new_make_tuple;
+}
+
 bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
   auto param_output = manager->node_users().find(parameter);
   if (param_output == manager->node_users().end()) {
     return true;
   }
-  return false;
+
+  bool need_remove = true;
+  auto output_info_list = param_output->second;
+  for (const auto &output_info : output_info_list) {
+    const auto &node = output_info.first;
+    if (IsNotRealUseNode(node)) {
+      const auto &cnode = node->cast<CNodePtr>();
+      const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
+      manager->Replace(cnode, new_cnode);
+      continue;
+    }
+    need_remove = false;
+  }
+  return need_remove;
 }
 
 void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
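Note on the new pass logic above: a BN parameter is only dropped when every remaining user is a Load/RefToEmbed or an optimizer CNode, and those users are first rewritten into harmless MakeTuple nodes. A small Python model of the kRemoveIndex filtering (mine, not the C++ API; the node layout is illustrative):

```python
# Illustrative model of ConvertRemoveNodeToVirtualNode's input filtering; not MindSpore code.
KEEP_INDEX = {"other": {2}, "optimizer": {3, 5, 6}}

def to_virtual_node(inputs, kind):
    keep = KEEP_INDEX["other"] if kind in ("Load", "RefToEmbed") else KEEP_INDEX["optimizer"]
    kept = [x for i, x in enumerate(inputs) if i in keep]
    return ["MakeTuple"] + kept

# A Load CNode looks roughly like [Load, parameter, U-monad]; only input 2 survives,
# so the to-be-removed parameter no longer counts as "really used".
print(to_virtual_node(["Load", "bn_gamma", "U"], "Load"))  # ['MakeTuple', 'U']
```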
@@ -98,12 +147,20 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
                  [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
   auto root_parameters = root_graph->parameters();
+  size_t origin_param_count = root_parameters.size();
   root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
                                        [&real_remove_parameter_list](const AnfNodePtr &node) {
                                          return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
                                        }),
                         root_parameters.end());
+  size_t remove_param_count = origin_param_count - root_parameters.size();
+  size_t hyper_param_count = root_graph->hyper_param_count();
+  if (remove_param_count > hyper_param_count) {
+    MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
+    return;
+  }
+  hyper_param_count = hyper_param_count - remove_param_count;
+  root_graph->set_hyper_param_count(hyper_param_count);
   manager->SetParameters(root_graph, root_parameters);
 }
 }  // namespace
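The hunk above and the GenerateGraph hunk below are coupled: once unused parameters are erased from the root graph, hyper_param_count must shrink by the same amount, otherwise `arguments_count - hyper_param_count_` would be computed against a stale count. A toy arithmetic check with assumed numbers:

```python
# Assumed example numbers, only to illustrate the hyper_param_count bookkeeping above.
positional_args = 5                      # arguments the caller actually passes
hyper_param_count = 5                    # trainable weights appended after the positional args
origin_param_count = positional_args + hyper_param_count
removed = 3                              # unused BN/optimizer parameters dropped by the pass
assert removed <= hyper_param_count      # mirrors the MS_LOG(ERROR) guard
hyper_param_count -= removed
remaining = origin_param_count - removed
assert remaining - hyper_param_count == positional_args   # GenerateGraph's split stays valid
```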
@@ -231,6 +231,9 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
   std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
   std::vector<size_t> pos_arg_indexes;
   size_t arguments_count = args_spec_list.size();
+  if (hyper_param_count_ > arguments_count) {
+    MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
+  }
   for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
     if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
@@ -13,8 +13,74 @@
 # limitations under the License.
 # ============================================================================
 """less batch normalization"""
+import numpy as np
+from mindspore import nn
+from mindspore.ops import operations as P
+from mindspore import Tensor, Parameter
+from mindspore import dtype as mstype
+from mindspore.common.initializer import initializer
 from ..cell import Cell
 
+__all__ = ["LessBN"]
+
+
+class CommonHeadLastFN(Cell):
+    r"""
+    The last full normalization layer.
+
+    This layer implements the operation as:
+
+    .. math::
+        \text{inputs} = \text{norm}(\text{inputs})
+
+        \text{kernel} = \text{norm}(\text{kernel})
+
+        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+        >>> net = CommonHeadLastFN(3, 4)
+        >>> output = net(input)
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 weight_init='normal',
+                 bias_init='zeros',
+                 has_bias=True):
+        super(CommonHeadLastFN, self).__init__()
+        weight_shape = [out_channels, in_channels]
+        self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
+        self.x_norm = P.L2Normalize(axis=1)
+        self.w_norm = P.L2Normalize(axis=1)
+        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
+        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
+        self.has_bias = has_bias
+        if self.has_bias:
+            bias_shape = [out_channels]
+            self.bias_add = P.BiasAdd()
+            self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')
+
+    def construct(self, x):
+        x = self.x_norm(x)
+        w = self.w_norm(self.weight)
+        x = self.fc(x, w)
+        if self.has_bias:
+            x = self.bias_add(x, self.bias)
+        x = self.multiplier * x
+        return x
+
+
 class LessBN(Cell):
     """
     Reduce the number of BN automatically to improve the network performance
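A NumPy-only sketch (not the cell above) of the full-normalization math in CommonHeadLastFN: L2-normalize the inputs and the kernel along their feature axis, matmul them, add the bias, then scale by the learnable multiplier. The epsilon is an assumption.

```python
import numpy as np

def l2_normalize(x, axis=1, eps=1e-12):   # eps assumed; P.L2Normalize has its own default
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)

def common_head_last_fn(x, weight, bias, multiplier=1.0):
    # x: (batch, in_channels); weight: (out_channels, in_channels); bias: (out_channels,)
    x = l2_normalize(x, axis=1)
    w = l2_normalize(weight, axis=1)
    return multiplier * (x @ w.T + bias)

x = np.array([[180, 234, 154], [244, 48, 247]], dtype=np.float32)
w = np.random.randn(4, 3).astype(np.float32)
b = np.zeros(4, dtype=np.float32)
print(common_head_last_fn(x, w, b).shape)   # (2, 4)
```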
@@ -31,6 +97,44 @@ class LessBN(Cell):
         super(LessBN, self).__init__()
         self.network = network
         self.network.set_acc("less_bn")
+        self.network.update_cell_prefix()
+        self._convert_to_less_bn_net(self.network)
+        self.network.add_flags(defer_inline=True)
+
+    def _convert_dense(self, subcell):
+        """
+        convert dense cell to FN cell
+        """
+        prefix = subcell.param_prefix
+        new_subcell = CommonHeadLastFN(subcell.in_channels,
+                                       subcell.out_channels,
+                                       subcell.weight,
+                                       subcell.bias,
+                                       subcell.has_bias)
+        new_subcell.update_parameters_name(prefix + '.')
+        return new_subcell
+
+    def _convert_to_less_bn_net(self, net):
+        """
+        convert network to less_bn network
+        """
+        cells = net.name_cells()
+        dense_name = []
+        dense_list = []
+        for name in cells:
+            subcell = cells[name]
+            if subcell == net:
+                continue
+            elif isinstance(subcell, (nn.Dense)):
+                dense_name.append(name)
+                dense_list.append(subcell)
+            else:
+                self._convert_to_less_bn_net(subcell)
+
+        if dense_list:
+            new_subcell = self._convert_dense(dense_list[-1])
+            net.insert_child_to_cell(dense_name[-1], new_subcell)
 
     def construct(self, *inputs):
         return self.network(*inputs)
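A minimal usage sketch of LessBN, assuming a MindSpore build of this release is installed and the class is importable as `mindspore.nn.acc.LessBN` (import path assumed); the toy network is hypothetical:

```python
import mindspore.nn as nn
from mindspore.nn.acc import LessBN   # import path assumed

class ToyNet(nn.Cell):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(32, 10)

    def construct(self, x):
        return self.fc(self.flatten(x))

# Wrapping sets the "less_bn" acc flag and rewrites the last nn.Dense into
# CommonHeadLastFN; training then uses `net` like any other Cell.
net = LessBN(ToyNet())
```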
@@ -1048,7 +1048,7 @@ class Cell(Cell_):
         Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
 
         Args:
-            acc_type (:str:`less_bn`): accelerate algorithm.
+            acc_type (str): accelerate algorithm.
 
         Raises:
             ValueError: If acc_type is not in the algorithm library.
@@ -153,8 +153,8 @@ class Adagrad(Optimizer):
         params = self.parameters
         accum = self.accum
         grads = self.decay_weight(grads)
-        grads = self.scale_grad(grads)
         grads = self.gradients_centralization(grads)
+        grads = self.scale_grad(grads)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
@@ -338,9 +338,9 @@ class Adam(Optimizer):
         moment1 = self.moment1
         moment2 = self.moment2
         gradients = self.decay_weight(gradients)
+        gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         gradients = self._grad_sparse_indices_deduplicate(gradients)
-        gradients = self.gradients_centralization(gradients)
         lr = self.get_lr()
         beta1_power = self.beta1_power * self.beta1
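A quick NumPy check (mine, not part of the PR) that moving gradients_centralization ahead of scale_grad cannot change the result for a scalar loss scale, since mean-subtraction is linear:

```python
import numpy as np

g = np.random.randn(4, 32).astype(np.float32)      # a dense gradient
scale = 1.0 / 1024.0                                # reciprocal loss scale, value assumed

def centralize(x):
    return x - x.mean(axis=tuple(range(1, x.ndim)), keepdims=True)

print(np.allclose(centralize(scale * g), scale * centralize(g)))   # True
```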
@@ -218,9 +218,9 @@ class FTRL(Optimizer):
         moments = self.moments
         linear = self.linear
         grads = self.decay_weight(grads)
+        grads = self.gradients_centralization(grads)
         grads = self.scale_grad(grads)
         grads = self._grad_sparse_indices_deduplicate(grads)
-        grads = self.gradients_centralization(grads)
         lr = self.get_lr()
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
@@ -255,9 +255,9 @@ class LazyAdam(Optimizer):
     def construct(self, gradients):
         gradients = self.decay_weight(gradients)
+        gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         gradients = self._grad_sparse_indices_deduplicate(gradients)
-        gradients = self.gradients_centralization(gradients)
         lr = self.get_lr()
         self.beta1_power = self.beta1_power * self.beta1
@@ -159,8 +159,8 @@ class Momentum(Optimizer):
         params = self.params
         moments = self.moments
         gradients = self.decay_weight(gradients)
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
@@ -439,10 +439,6 @@ class Optimizer(Cell):
                 self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
                 for param in group_param['params']:
                     validator.check_value_type("parameter", param, [Parameter], self.cls_name)
-                    if "conv" not in param.name and self.grad_centralization is True:
-                        raise ValueError("Grad centralization can be perform only on the conv layer. If the parameter"
-                                         "is not a convolution layer, this parameter cannot be set to True.")
-
                 grad_centralization_ = self.grad_centralization
             else:
                 grad_centralization_ = grad_centralization
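With the conv-only restriction removed above, grad_centralization can be requested per parameter group. A hedged sketch of the group-parameter API (the toy network and grouping are illustrative):

```python
import mindspore.nn as nn

class ToyConvNet(nn.Cell):
    def __init__(self):
        super(ToyConvNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(8 * 32 * 32, 10)

    def construct(self, x):
        return self.fc(self.flatten(self.conv(x)))

net = ToyConvNet()
conv_params = [p for p in net.trainable_params() if 'conv' in p.name]
other_params = [p for p in net.trainable_params() if 'conv' not in p.name]
group_params = [{'params': conv_params, 'grad_centralization': True},
                {'params': other_params}]
opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)
```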
@@ -630,9 +626,16 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
     """Get grad with grad_centralization."""
     if if_apply:
         indices = gradient.indices
-        values = op_gc(gradient.values, -1)
         shape = gradient.dense_shape
-        return RowTensor(indices, values, shape)
+        grad_shape = F.shape(gradient)
+        axis = []
+        for i in range(1, len(grad_shape)):
+            axis.append(i)
+        if len(axis) >= 1:
+            if grad_shape[1] % 16 != 0:
+                return gradient
+            values = op_gc(gradient.values, axis)
+            return RowTensor(indices, values, shape)
     return gradient
@@ -640,7 +643,14 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
 def _tensor_apply_grad_centralization(if_apply, gradient):
     """Get grad with grad_centralization."""
     if if_apply:
-        return op_gc(gradient, -1)
+        axis = []
+        grad_shape = F.shape(gradient)
+        for i in range(1, len(grad_shape)):
+            axis.append(i)
+        if len(axis) >= 1:
+            if grad_shape[1] % 16 != 0:
+                return gradient
+            return op_gc(gradient, axis)
     return gradient
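A NumPy sketch (not the op_gc primitive) of what the dense branch now computes: subtract the per-slice mean over every axis except axis 0, and fall back to the raw gradient when the second dimension is not a multiple of 16, mirroring the guard added above:

```python
import numpy as np

def grad_centralize(grad):
    if grad.ndim < 2 or grad.shape[1] % 16 != 0:
        return grad
    axes = tuple(range(1, grad.ndim))
    return grad - grad.mean(axis=axes, keepdims=True)

g = np.random.randn(8, 16, 3, 3).astype(np.float32)   # e.g. a conv weight gradient
print(np.allclose(grad_centralize(g).mean(axis=(1, 2, 3)), 0.0, atol=1e-6))   # True
```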
@@ -165,9 +165,9 @@ class ProximalAdagrad(Optimizer):
         params = self.parameters
         accum = self.accum
         grads = self.decay_weight(grads)
+        grads = self.gradients_centralization(grads)
         grads = self.scale_grad(grads)
         grads = self._grad_sparse_indices_deduplicate(grads)
-        grads = self.gradients_centralization(grads)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
@@ -200,8 +200,8 @@ class RMSProp(Optimizer):
     def construct(self, gradients):
         params = self.parameters
         gradients = self.decay_weight(gradients)
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.centered:
             if self.is_group_lr:
@@ -173,8 +173,8 @@ class SGD(Optimizer):
         params = self.parameters
         accum = self.accum
         stat = self.stat
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
@@ -26,9 +26,10 @@ centralization_op_info = TBERegOp("Centralization") \
     .attr("axis", "required", "listInt", "all") \
     .input(0, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
-    .op_pattern("reduce") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
     .get_op_info()