| @@ -276,6 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); | |||
| ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SyncBnSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SyncBnGradSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>()); | |||
| @@ -18,6 +18,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/optimizer/ascend/ir_fission/bn_split.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| @@ -104,6 +105,36 @@ CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| return make_tuple; | |||
| } | |||
| CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> bn_update_grad_outputs; | |||
| CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs); | |||
| if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { | |||
| MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size" | |||
| << " trace: " << trace::DumpSourceLines(cnode); | |||
| } | |||
| std::vector<AnfNodePtr> allreduce_mul_outputs; | |||
| for (size_t i = 0; i < bn_update_grad_outputs.size(); ++i) { | |||
| auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_update_grad_outputs[i], cnode); | |||
| allreduce_mul_outputs.emplace_back(allreduce_mul_output); | |||
| } | |||
| std::vector<AnfNodePtr> bn_reduce_grad_outputs; | |||
| CreateOutputsOfReduceGrad(func_graph, cnode, allreduce_mul_outputs, &bn_reduce_grad_outputs); | |||
| if (bn_reduce_grad_outputs.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size" | |||
| << " trace: " << trace::DumpSourceLines(cnode); | |||
| } | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0], | |||
| allreduce_mul_outputs[0], allreduce_mul_outputs[1]}; | |||
| auto make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| return make_tuple; | |||
| } | |||
| } // namespace | |||
| const BaseRef BnGradSplit::DefinePattern() const { | |||
| @@ -120,5 +151,17 @@ const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfN | |||
| } | |||
| return BNGradSplitForTBE(func_graph, cnode); | |||
| } | |||
| const BaseRef SyncBnGradSplit::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimSyncBatchNormGrad, Xs}); | |||
| } | |||
| const AnfNodePtr SyncBnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| return SyncBNGradSplitForTBE(func_graph, cnode); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -28,6 +28,14 @@ class BnGradSplit : public PatternProcessPass { | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| class SyncBnGradSplit : public PatternProcessPass { | |||
| public: | |||
| explicit SyncBnGradSplit(bool multigraph = true) : PatternProcessPass("sync_bn_grad_split", multigraph) {} | |||
| ~SyncBnGradSplit() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ | |||
| @@ -17,6 +17,8 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <limits> | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| @@ -28,6 +30,9 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr auto kReduceOpSum = "sum"; | |||
| constexpr auto kDeviceNum = "device_num"; | |||
| bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, | |||
| std::vector<AnfNodePtr> *bn_training_reduce_outputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -117,8 +122,105 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| // Create BNTrainingUpdate node | |||
| return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs); | |||
| } | |||
| AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) { | |||
| MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has less input than " << kBnInputTensorNum << " inputs."; | |||
| return nullptr; | |||
| } | |||
| // Create BNTrainingReduce node and get outputs of BNTrainingReduce | |||
| std::vector<AnfNodePtr> bn_training_reduce_outputs; | |||
| if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) { | |||
| MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split"; | |||
| return nullptr; | |||
| } | |||
| if (bn_training_reduce_outputs.size() != kBN1OutputNum) { | |||
| MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail" | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| std::vector<AnfNodePtr> allreduce_mul_outputs; | |||
| for (size_t i = 0; i < bn_training_reduce_outputs.size(); ++i) { | |||
| auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode); | |||
| allreduce_mul_outputs.emplace_back(allreduce_mul_output); | |||
| } | |||
| // Create BNTrainingUpdate node | |||
| return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, allreduce_mul_outputs); | |||
| } | |||
| } // namespace | |||
| AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(sync_bn_cnode); | |||
| if (!AnfAlgo::HasNodeAttr(kDeviceNum, sync_bn_cnode)) { | |||
| MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] does not have attr device_num."; | |||
| } | |||
| auto device_num = AnfAlgo::GetNodeAttr<int64_t>(sync_bn_cnode, kDeviceNum); | |||
| MS_LOG(INFO) << "device_num value: " << device_num; | |||
| float device_num_reciprocal = 1.0 / device_num; | |||
| std::vector<int64_t> device_num_shape = {}; | |||
| auto device_num_reciprocal_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, device_num_shape); | |||
| MS_EXCEPTION_IF_NULL(device_num_reciprocal_tensor); | |||
| auto data_ptr = device_num_reciprocal_tensor->data_c(); | |||
| MS_EXCEPTION_IF_NULL(data_ptr); | |||
| auto *val = reinterpret_cast<float *>(data_ptr); | |||
| *val = device_num_reciprocal; | |||
| auto kernel_graph = graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, device_num_shape); | |||
| auto device_num_reciprocal_value = kernel_graph->NewValueNode(abstract, device_num_reciprocal_tensor); | |||
| MS_EXCEPTION_IF_NULL(device_num_reciprocal_value); | |||
| kernel_graph->AddValueNodeToGraph(device_num_reciprocal_value); | |||
| return device_num_reciprocal_value; | |||
| } | |||
| AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input, | |||
| const CNodePtr &sync_bn_cnode) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(allreduce_input); | |||
| MS_EXCEPTION_IF_NULL(sync_bn_cnode); | |||
| // create AllReduce | |||
| std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)), | |||
| allreduce_input}; | |||
| auto allreduce = graph->NewCNode(allreduce_inputs); | |||
| MS_EXCEPTION_IF_NULL(allreduce); | |||
| allreduce->set_abstract(allreduce_input->abstract()); | |||
| allreduce->set_scope(allreduce_input->scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce); | |||
| AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce); | |||
| // use SyncBatchNorm's opid as AllReduce's fusion attr | |||
| auto sync_bn_opname = sync_bn_cnode->fullname_with_scope(); | |||
| auto opid_pos = sync_bn_opname.rfind("-op"); | |||
| if (opid_pos == std::string::npos) { | |||
| MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] has no opid."; | |||
| } | |||
| int64_t opid = std::stol(sync_bn_opname.substr(opid_pos + 3)); | |||
| // user defined fusion should be greater than 1 | |||
| if (opid < 2) { | |||
| opid = opid - 2 + std::numeric_limits<int64_t>::max(); | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(opid), allreduce); | |||
| // create Mul | |||
| auto device_num_reciprocal_vnode = CreateValueNodeOfDeviceNumReciprocal(graph, sync_bn_cnode); | |||
| std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), allreduce, | |||
| device_num_reciprocal_vnode}; | |||
| auto mul = graph->NewCNode(mul_inputs); | |||
| MS_EXCEPTION_IF_NULL(mul); | |||
| mul->set_abstract(allreduce_input->abstract()); | |||
| mul->set_scope(allreduce_input->scope()); | |||
| return mul; | |||
| } | |||
| const BaseRef BnSplit::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| @@ -132,5 +234,14 @@ const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodeP | |||
| } | |||
| return SplitBatchNormForTBE(func_graph, node); | |||
| } | |||
| const BaseRef SyncBnSplit::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimSyncBatchNorm, Xs}); | |||
| } | |||
| const AnfNodePtr SyncBnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | |||
| return SyncBNSplitForTBE(func_graph, node); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -28,6 +28,19 @@ class BnSplit : public PatternProcessPass { | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| class SyncBnSplit : public PatternProcessPass { | |||
| public: | |||
| explicit SyncBnSplit(bool multigraph = true) : PatternProcessPass("sync_bn_split", multigraph) {} | |||
| ~SyncBnSplit() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode); | |||
| AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input, | |||
| const CNodePtr &sync_bn_cnode); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_ | |||
| @@ -228,6 +228,8 @@ inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>( | |||
| inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx"); | |||
| inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | |||
| inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | |||
| inline const PrimitivePtr kPrimSyncBatchNorm = std::make_shared<Primitive>("SyncBatchNorm"); | |||
| inline const PrimitivePtr kPrimSyncBatchNormGrad = std::make_shared<Primitive>("SyncBatchNormGrad"); | |||
| inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | |||
| inline const PrimitivePtr kPrimReluGradV2 = std::make_shared<Primitive>("ReluGradV2"); | |||
| inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad"); | |||
| @@ -13,12 +13,17 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """normalization""" | |||
| import itertools | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common._decorator import deprecated | |||
| from mindspore.ops.primitive import constexpr | |||
| import mindspore.context as context | |||
| from mindspore._checkparam import Rel | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._extends import cell_attr_register | |||
| from mindspore.communication.management import get_group_size, get_rank | |||
| @@ -26,8 +31,9 @@ from mindspore.communication import management | |||
| from mindspore.ops import _selected_ops | |||
| from ..cell import Cell | |||
| __all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'InstanceNorm2d'] | |||
| __all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d'] | |||
| SYNC_BN_GROUP_NAME = "" | |||
| class _BatchNorm(Cell): | |||
| """Batch Normalization base class.""" | |||
| @@ -44,6 +50,7 @@ class _BatchNorm(Cell): | |||
| moving_var_init='ones', | |||
| use_batch_statistics=None, | |||
| device_num_each_group=1, | |||
| process_groups=0, | |||
| input_dims='2d', | |||
| data_format='NCHW'): | |||
| super(_BatchNorm, self).__init__() | |||
| @@ -68,19 +75,47 @@ class _BatchNorm(Cell): | |||
| gamma_init, num_features), name="gamma", requires_grad=affine) | |||
| self.beta = Parameter(initializer( | |||
| beta_init, num_features), name="beta", requires_grad=affine) | |||
| self.group = validator.check_positive_int(device_num_each_group) | |||
| self.group_device_num = validator.check_positive_int(device_num_each_group) | |||
| self.process_groups = process_groups | |||
| self.is_global = False | |||
| if self.group != 1: | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| global SYNC_BN_GROUP_NAME | |||
| # for GlobalBatchNorm | |||
| if self.group_device_num != 1 and self.parallel_mode != context.ParallelMode.STAND_ALONE: | |||
| self.rank_id = get_rank() | |||
| self.rank_size = get_group_size() | |||
| self.device_list = [i for i in range(0, self.rank_size)] | |||
| self.rank_list = self.list_group(self.device_list, self.group) | |||
| self.rank_list = self.list_group(self.device_list, self.group_device_num) | |||
| self.rank_list_idx = len(self.rank_list) | |||
| for i in range(self.rank_list_idx): | |||
| if self.rank_id in self.rank_list[i] and self.group != 1: | |||
| if self.rank_id in self.rank_list[i]: | |||
| self.is_global = True | |||
| management.create_group('group' + str(i), self.rank_list[i]) | |||
| self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1) | |||
| if SYNC_BN_GROUP_NAME == "": | |||
| SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i) | |||
| management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) | |||
| # for SyncBatchNorm | |||
| if self.process_groups != 0 and self.parallel_mode != context.ParallelMode.STAND_ALONE: | |||
| self.rank_id = get_rank() | |||
| self.rank_size = get_group_size() | |||
| if self.process_groups is not None: | |||
| validator.check_isinstance("process_groups", self.process_groups, list) | |||
| self._check_rank_ids(self.process_groups, self.rank_size) | |||
| for i in range(len(self.process_groups)): | |||
| validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list) | |||
| self.group_device_num = len(self.process_groups[i]) | |||
| if self.rank_id in self.process_groups[i] and self.group_device_num > 1: | |||
| self.is_global = True | |||
| if SYNC_BN_GROUP_NAME == "": | |||
| SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) | |||
| management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i]) | |||
| elif self.rank_size > 1: | |||
| self.is_global = True | |||
| self.group_device_num = self.rank_size | |||
| self.device_list = [i for i in range(0, self.rank_size)] | |||
| if SYNC_BN_GROUP_NAME == "": | |||
| SYNC_BN_GROUP_NAME = "sync_bn_group0" | |||
| management.create_group(SYNC_BN_GROUP_NAME, self.device_list) | |||
| self.shape = P.Shape() | |||
| self.reduce_mean = P.ReduceMean(keep_dims=True) | |||
| self.square = P.Square() | |||
| @@ -109,9 +144,12 @@ class _BatchNorm(Cell): | |||
| self.bn_train = P.FusedBatchNorm(mode=1, | |||
| epsilon=self.eps, | |||
| momentum=self.momentum) | |||
| if self.is_global: | |||
| self.bn_train = inner.SyncBatchNorm(epsilon=self.eps, | |||
| momentum=self.momentum, | |||
| group=SYNC_BN_GROUP_NAME, | |||
| device_num=self.group_device_num) | |||
| self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) | |||
| self.enable_global_sync = self.is_global and (self.is_ge_backend or\ | |||
| (self.is_graph_mode and self._target == "Ascend")) | |||
| data_parallel_strategy = ((1,), (1,)) | |||
| data_parallel_strategy_one = ((1,), ()) | |||
| @@ -135,26 +173,13 @@ class _BatchNorm(Cell): | |||
| group_list = [list(i) for i in world_rank_list] | |||
| return group_list | |||
| def _global_sync(self, x, axes, re_shape): | |||
| """calculate global batch normalization output""" | |||
| x_mean = self.reduce_mean(x, axes) | |||
| x_mean_square = self.reduce_mean(self.square(x), axes) | |||
| global_batch_mean = self.all_reduce(x_mean) / self.group | |||
| global_batch_mean_square = self.all_reduce(x_mean_square) / self.group | |||
| global_mean = global_batch_mean | |||
| global_var = global_batch_mean_square - self.square(global_mean) | |||
| var_sqrt = self.sqrt(global_var + self.eps) | |||
| mean_first = (x - global_mean) / var_sqrt | |||
| y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape) | |||
| mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean) | |||
| tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub))) | |||
| mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var) | |||
| tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2))) | |||
| y = F.depend(y, self.assign_sub_mean(self.moving_mean, self.reshape(tmp_mean, self.shape(self.moving_mean)))) | |||
| y = F.depend(y, self.assign_sub_var(self.moving_variance, | |||
| self.reshape(tmp_variance, self.shape(self.moving_variance)))) | |||
| return y | |||
| def _check_rank_ids(self, process_groups, rank_size): | |||
| seen = set() | |||
| for rid in itertools.chain(*process_groups): | |||
| validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups") | |||
| if rid in seen: | |||
| raise ValueError("rank id in process_groups should not be duplicated.") | |||
| seen.add(rid) | |||
| def construct(self, x): | |||
| _shape_check_bn(self.shape(x), self.input_dims) | |||
| @@ -164,10 +189,6 @@ class _BatchNorm(Cell): | |||
| flag = self.use_batch_statistics | |||
| if flag: | |||
| if self.enable_global_sync: | |||
| axes, re_shape = _shape_infer(F.shape(x), self.num_features) | |||
| return self._global_sync(x, axes, re_shape) | |||
| return self.bn_train(x, | |||
| self.gamma, | |||
| self.beta, | |||
| @@ -597,6 +618,7 @@ class GlobalBatchNorm(_BatchNorm): | |||
| [ 20.9999895 241.9988 ]]]] | |||
| """ | |||
| @deprecated("1.2", "SyncBatchNorm", True) | |||
| def __init__(self, | |||
| num_features, | |||
| eps=1e-5, | |||
| @@ -619,8 +641,8 @@ class GlobalBatchNorm(_BatchNorm): | |||
| use_batch_statistics, | |||
| device_num_each_group, | |||
| input_dims='both') | |||
| self.group = validator.check_positive_int(device_num_each_group) | |||
| if self.group <= 1: | |||
| self.group_device_num = validator.check_positive_int(device_num_each_group) | |||
| if self.group_device_num <= 1: | |||
| raise ValueError("the number of group must be greater than 1.") | |||
| def _check_data_dim(self, x): | |||
| @@ -628,6 +650,121 @@ class GlobalBatchNorm(_BatchNorm): | |||
| pass | |||
| class SyncBatchNorm(_BatchNorm): | |||
| r""" | |||
| Sync Batch normalization layer over a N-dimension input. | |||
| Sync Batch Normalization is cross device synchronized batch normalization. The implementation of Batch | |||
| Normalization only normalizes the data within each device. Sync Batch normalization will normalize the input | |||
| within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by | |||
| Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the | |||
| feature using a mini-batch of data and the learned parameters which can be described in the following formula. | |||
| .. math:: | |||
| y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta | |||
| Note: | |||
| Currently, SyncBatchNorm only supports 2D and 4D inputs. | |||
| Args: | |||
| num_features (int): `C` from an expected input of size (N, C, H, W). | |||
| eps (float): A value added to the denominator for numerical stability. Default: 1e-5. | |||
| momentum (float): A floating hyperparameter of the momentum for the | |||
| running_mean and running_var computation. Default: 0.9. | |||
| affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. | |||
| gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. | |||
| The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', | |||
| 'he_uniform', etc. Default: 'ones'. | |||
| beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. | |||
| The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', | |||
| 'he_uniform', etc. Default: 'zeros'. | |||
| moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. | |||
| The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', | |||
| 'he_uniform', etc. Default: 'zeros'. | |||
| moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. | |||
| The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', | |||
| 'he_uniform', etc. Default: 'ones'. | |||
| use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, | |||
| use the mean value and variance value of specified value. If None, training process will use the mean and | |||
| variance of current batch data and track the running mean and variance, eval process will use the running | |||
| mean and variance. Default: None. | |||
| process_groups (list): A list to divide devices into different sync groups, containing N subtraction lists. | |||
| Each subtraction list contains int numbers identifying rank ids which need to be synchronized in the same | |||
| group. All int values must be in [0, rank_size) and different from each other. Default: None, indicating | |||
| synchronization across all devices. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
| Raises: | |||
| TypeError: If `num_features` is not an int. | |||
| TypeError: If `eps` is not a float. | |||
| TypeError: If `process_groups` is not a list. | |||
| ValueError: If `num_features` is less than 1. | |||
| ValueError: If `momentum` is not in range [0, 1]. | |||
| ValueError: If `device_num_each_group` is less than 2. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> # This example should be run with multiple processes. | |||
| >>> # Please refer to the tutorial > Distributed Training on mindspore.cn. | |||
| >>> import numpy as np | |||
| >>> from mindspore.communication import init | |||
| >>> from mindspore import context | |||
| >>> from mindspore.context import ParallelMode | |||
| >>> from mindspore import nn, Tensor | |||
| >>> from mindspore.common import dtype as mstype | |||
| >>> | |||
| >>> context.set_context(mode=context.GRAPH_MODE) | |||
| >>> init() | |||
| >>> context.reset_auto_parallel_context() | |||
| >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) | |||
| >>> np.random.seed(0) | |||
| >>> sync_bn_op = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]]) | |||
| >>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mstype.float32) | |||
| >>> output = sync_bn_op(input) | |||
| >>> print(output) | |||
| [[[[171.99915 46.999763] | |||
| [116.99941 191.99904 ]] | |||
| [[ 66.999664 250.99875 ] | |||
| [194.99902 102.99948 ]] | |||
| [[ 8.999955 210.99895 ] | |||
| [ 20.9999895 241.9988 ]]]] | |||
| """ | |||
| def __init__(self, | |||
| num_features, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| gamma_init='ones', | |||
| beta_init='zeros', | |||
| moving_mean_init='zeros', | |||
| moving_var_init='ones', | |||
| use_batch_statistics=None, | |||
| process_groups=None): | |||
| super(SyncBatchNorm, self).__init__(num_features, | |||
| eps, | |||
| momentum, | |||
| affine, | |||
| gamma_init, | |||
| beta_init, | |||
| moving_mean_init, | |||
| moving_var_init, | |||
| use_batch_statistics, | |||
| process_groups=process_groups, | |||
| input_dims='both') | |||
| def _check_data_dim(self, x): | |||
| if x.dim == 0: | |||
| pass | |||
| class LayerNorm(Cell): | |||
| r""" | |||
| Applies Layer Normalization over a mini-batch of inputs. | |||
| @@ -17,6 +17,8 @@ | |||
| from .. import operations as P | |||
| from .. import composite as C | |||
| from ..operations import _grad_ops as G | |||
| from ..operations import _inner_ops as inner | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from .grad_base import bprop_getters | |||
| @@ -64,5 +66,20 @@ def bprop_pqc(self): | |||
| dx = t(dx, (1, 0)) | |||
| dy = C.tensor_dot(dout[0], out[2], ((0, 1), (0, 1))) | |||
| return dx, dy | |||
| return bprop | |||
| @bprop_getters.register(inner.SyncBatchNorm) | |||
| def get_bprop_sync_batch_norm(self): | |||
| """Grad definition for `SyncBatchNorm` operation.""" | |||
| input_grad = G.SyncBatchNormGrad(self.epsilon, self.group, self.device_num) | |||
| def bprop(x, scale, b, mean, variance, out, dout): | |||
| saved_mean = out[3] | |||
| saved_variance = out[4] | |||
| out = input_grad(dout[0], x, scale, saved_mean, saved_variance) | |||
| dx = out[0] | |||
| dscale = out[1] | |||
| dbias = out[2] | |||
| return dx, dscale, dbias, zeros_like(mean), zeros_like(variance) | |||
| return bprop | |||
| @@ -204,6 +204,24 @@ class BatchNormGrad(PrimitiveWithInfer): | |||
| return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) | |||
| class SyncBatchNormGrad(PrimitiveWithInfer): | |||
| """Performs grad of SyncBatchNorm operation.""" | |||
| @prim_attr_register | |||
| def __init__(self, epsilon=1e-5, group="group0", device_num=2): | |||
| validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| if not isinstance(group, str): | |||
| raise TypeError("The group attr of SyncBatchNormGrad should be str.") | |||
| validator.check_int(device_num, 2, Rel.GE, "device_num", self.name) | |||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape): | |||
| validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) | |||
| return (x_shape, scale_shape, scale_shape) | |||
| def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape): | |||
| return (x_type, scale_type, scale_type) | |||
| class BiasAddGrad(PrimitiveWithInfer): | |||
| """Computes gradients of BiasAdd.""" | |||
| @@ -630,6 +630,7 @@ class GpuConvertToDynamicShape(PrimitiveWithCheck): | |||
| def check_dtype(self, input_dtype): | |||
| validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name) | |||
| class ErrorOnDynamicShapeInput(PrimitiveWithInfer): | |||
| """ | |||
| This op is used for dynamic shape testing. The only purpose of this operator is | |||
| @@ -724,3 +725,93 @@ class SequenceMask(PrimitiveWithCheck): | |||
| def check_dtype(self, lengths_dtype, maxlen_dtype): | |||
| validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name) | |||
| validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name) | |||
| class SyncBatchNorm(PrimitiveWithInfer): | |||
| r""" | |||
| Sync Batch Normalization for input data and updated parameters. | |||
| Sync Batch Normalization is cross device synchronized batch normalization. Batch Normalization is | |||
| widely used in convolutional neural networks. This operation applies Batch Normalization over input | |||
| to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating | |||
| Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. | |||
| It rescales and recenters the features using a mini-batch of data and the learned parameters which | |||
| can be described in the following formula, | |||
| .. math:: | |||
| y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||
| where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. | |||
| Args: | |||
| epsilon (float): A small value added for numerical stability. Default: 1e-5. | |||
| momentum (float): The hyper parameter to compute moving average for running_mean and running_var | |||
| (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`). | |||
| Momentum value must be [0, 1]. Default: 0.1. | |||
| group (str): The communication group to work on. Default: "sync_bn_group0". | |||
| device_num (int): The number of devices in each group. Default: 2. | |||
| Inputs: | |||
| - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. | |||
| - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. | |||
| - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. | |||
| - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. | |||
| - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`. | |||
| Outputs: | |||
| Tuple of 5 Tensor, the normalized inputs and the updated parameters. | |||
| - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`. | |||
| - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> # This example should be run with multiple processes. | |||
| >>> # Please refer to nn.SyncBatchNorm for direct use. | |||
| >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32) | |||
| >>> scale = Tensor(np.ones([2]), mindspore.float32) | |||
| >>> bias = Tensor(np.ones([2]), mindspore.float32) | |||
| >>> mean = Tensor(np.ones([2]), mindspore.float32) | |||
| >>> variance = Tensor(np.ones([2]), mindspore.float32) | |||
| >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm() | |||
| >>> output = sync_batch_norm(input_x, scale, bias, mean, variance) | |||
| >>> print(output) | |||
| (Tensor(shape=[2, 2], dtype=Float32, value= | |||
| [[ 1.00000000e+00, 1.00000000e+00], | |||
| [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value= | |||
| [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value= | |||
| [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value= | |||
| [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value= | |||
| [ 1.00000000e+00, 1.00000000e+00])) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2): | |||
| validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| validator.check_isinstance("group", group, str) | |||
| validator.check_int(device_num, 2, Rel.GE, "device_num", self.name) | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], | |||
| outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) | |||
| def infer_shape(self, input_x, scale, bias, mean, variance): | |||
| validator.check_equal_int(len(scale), 1, "scale rank", self.name) | |||
| validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) | |||
| validator.check("scale shape[0]", scale[0], "input_x channel", input_x[1], Rel.EQ, self.name) | |||
| validator.check_equal_int(len(mean), 1, "mean rank", self.name) | |||
| validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) | |||
| validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) | |||
| return (input_x, scale, scale, scale, scale) | |||
| def infer_dtype(self, input_x, scale, bias, mean, variance): | |||
| validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) | |||
| args = {"scale": scale, "bias": bias} | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||
| args_moving = {"mean": mean, "variance": variance} | |||
| validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name) | |||
| return (input_x, scale, bias, input_x, input_x) | |||
| @@ -100,5 +100,67 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) { | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_bn_grad_split", "after2"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWBnGradSplit, test_sync_bn_grad_split_tbe) { | |||
| get_py_fun_.SetDoResolve(true); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_sync_bn_grad_split", "before"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int64_t> shp_x{1, 64, 112, 112}; | |||
| std::vector<int64_t> shp_b{64}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract}; | |||
| auto kernel_graph = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kernel_graph, nullptr); | |||
| // get SyncBNGrad | |||
| CNodePtr ret = kernel_graph->get_return(); | |||
| EXPECT_NE(ret, nullptr); | |||
| EXPECT_NE(ret->input(1), nullptr); | |||
| EXPECT_TRUE(ret->input(1)->isa<CNode>()); | |||
| auto make_tuple1 = ret->input(1)->cast<CNodePtr>(); | |||
| EXPECT_NE(make_tuple1->input(1), nullptr); | |||
| EXPECT_TRUE(make_tuple1->input(1)->isa<CNode>()); | |||
| auto make_tuple2 = make_tuple1->input(1)->cast<CNodePtr>(); | |||
| EXPECT_NE(make_tuple2->input(1), nullptr); | |||
| EXPECT_TRUE(make_tuple2->input(1)->isa<CNode>()); | |||
| auto tuple_getitem = make_tuple2->input(1)->cast<CNodePtr>(); | |||
| EXPECT_NE(tuple_getitem->input(1), nullptr); | |||
| EXPECT_TRUE(tuple_getitem->input(1)->isa<CNode>()); | |||
| auto bn_grad = tuple_getitem->input(1)->cast<CNodePtr>(); | |||
| // get param1 | |||
| EXPECT_NE(bn_grad->input(1), nullptr); | |||
| auto param1 = bn_grad->input(1); | |||
| // set kernel for param1 | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder2; | |||
| builder2.SetOutputsFormat({kOpFormat_NC1HWC0}); | |||
| builder2.SetOutputsDeviceType({kNumberTypeFloat32}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), param1.get()); | |||
| // set kernel for SyncBNGrad | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1; | |||
| builder1.SetInputsFormat( | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); | |||
| builder1.SetOutputsFormat( | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); | |||
| builder1.SetInputsDeviceType( | |||
| {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| builder1.SetOutputsDeviceType( | |||
| {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| builder1.SetKernelType(TBE_KERNEL); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), bn_grad.get()); | |||
| // do sync_bn_grad_split pass | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::SyncBnGradSplit>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(kernel_graph); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_sync_bn_grad_split", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -86,7 +86,7 @@ TEST_F(TestHWBnSplit, test_bn_split_tbe) { | |||
| builder.SetKernelType(KernelType::TBE_KERNEL); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get()); | |||
| // do bn_grad_split_pass | |||
| // do bn_split_pass | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::BnSplit>(); | |||
| @@ -97,5 +97,54 @@ TEST_F(TestHWBnSplit, test_bn_split_tbe) { | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_bn_split_tbe", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWBnSplit, test_sync_bn_split_tbe) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_sync_bn_split_tbe", "before"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int64_t> shp_x{1, 64, 112, 112}; | |||
| std::vector<int64_t> shp_b{64}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b); | |||
| AbstractBasePtrList args_spec_list{x_abstract, b_abstract, b_abstract, b_abstract, b_abstract}; | |||
| auto kernel_graph = GetKernelGraph(g, args_spec_list); | |||
| // get kernel | |||
| auto ret = kernel_graph->get_return(); | |||
| EXPECT_NE(ret, nullptr); | |||
| EXPECT_TRUE(ret->inputs().size() == 2); | |||
| auto make_tuple = ret->input(1)->cast<CNodePtr>(); | |||
| EXPECT_NE(make_tuple, nullptr); | |||
| EXPECT_TRUE(make_tuple->inputs().size() == 2); | |||
| auto item0 = make_tuple->input(1)->cast<CNodePtr>(); | |||
| EXPECT_NE(item0, nullptr); | |||
| EXPECT_TRUE(item0->inputs().size() == 3); | |||
| auto bn = item0->input(1); | |||
| EXPECT_NE(bn, nullptr); | |||
| EXPECT_TRUE(bn->isa<CNode>()); | |||
| // set kernel for SyncBN | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat( | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); | |||
| builder.SetOutputsFormat( | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); | |||
| builder.SetInputsDeviceType( | |||
| {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| builder.SetOutputsDeviceType( | |||
| {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| builder.SetKernelType(KernelType::TBE_KERNEL); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get()); | |||
| // do sync_bn_split_pass | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::SyncBnSplit>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(kernel_graph); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_sync_bn_split_tbe", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -16,15 +16,21 @@ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops import _constants as Constants | |||
| from mindspore.common.tensor import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive(Constants.kTupleGetItem) | |||
| bn_grad = G.BatchNormGrad(is_training=True) | |||
| sync_bn_grad = G.SyncBatchNormGrad() | |||
| bn_grad1 = Primitive('BNGrad1') | |||
| bn_grad2 = Primitive('BNGrad2') | |||
| bn_grad3 = Primitive('BNGrad3') | |||
| bn_training_update_grad = Primitive('BNTrainingUpdateGrad') | |||
| bn_training_reduce_grad = Primitive('BNTrainingReduceGrad') | |||
| allreduce = Primitive('AllReduce') | |||
| mul = Primitive('Mul') | |||
| mul_value = Tensor(0.5, mstype.float32) | |||
| class FnDict: | |||
| @@ -85,3 +91,36 @@ def test_bn_grad_split(tag): | |||
| return make_tuple(output) | |||
| return fns[tag] | |||
| def test_sync_bn_grad_split(tag): | |||
| """ test_sync_bn_grad_split """ | |||
| fns = FnDict() | |||
| @fns | |||
| def before(i0, i1, i2, i3, i4): | |||
| bn_grad_output = sync_bn_grad(i0, i1, i2, i3, i4) | |||
| item0 = tuple_getitem(bn_grad_output, 0) | |||
| item1 = tuple_getitem(bn_grad_output, 1) | |||
| item2 = tuple_getitem(bn_grad_output, 2) | |||
| output = make_tuple(item0, item1, item2) | |||
| return output | |||
| @fns | |||
| def after(i0, i1, i2, i3, i4): | |||
| bn_update_grad_output = bn_training_update_grad(i0, i1, i3, i4) | |||
| update_output0 = tuple_getitem(bn_update_grad_output, 0) | |||
| update_output1 = tuple_getitem(bn_update_grad_output, 1) | |||
| allreduce_output0 = allreduce(update_output0) | |||
| allreduce_output1 = allreduce(update_output1) | |||
| update_item0 = mul(allreduce_output0, mul_value) | |||
| update_item1 = mul(allreduce_output1, mul_value) | |||
| bn_reduce_grad_output = bn_training_reduce_grad(i0, i1, update_item0, update_item1, i2, i3, i4) | |||
| output = make_tuple(bn_reduce_grad_output, update_item0, update_item1) | |||
| item0 = tuple_getitem(output, 0) | |||
| item1 = tuple_getitem(output, 1) | |||
| item2 = tuple_getitem(output, 2) | |||
| output = make_tuple(item0, item1, item2) | |||
| return make_tuple(output) | |||
| return fns[tag] | |||
| @@ -15,16 +15,23 @@ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.ops import _constants as Constants | |||
| from mindspore.common.tensor import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive(Constants.kTupleGetItem) | |||
| bn = P.BatchNorm(is_training=True) | |||
| sync_bn = inner.SyncBatchNorm() | |||
| fused_bn1 = Primitive('FusedBN1') | |||
| fused_bn2 = Primitive('FusedBN2') | |||
| fused_bn3 = Primitive('FusedBN3') | |||
| bn_training_reduce = Primitive('BNTrainingReduce') | |||
| bn_training_update = Primitive('BNTrainingUpdate') | |||
| allreduce = Primitive('AllReduce') | |||
| mul = Primitive('Mul') | |||
| mul_value = Tensor(0.5, mstype.float32) | |||
| class FnDict: | |||
| @@ -89,3 +96,30 @@ def test_bn_split_tbe(tag): | |||
| return make_tuple(output) | |||
| return fns[tag] | |||
| def test_sync_bn_split_tbe(tag): | |||
| """ test_sync_split_bn_fusion """ | |||
| fns = FnDict() | |||
| @fns | |||
| def before(x, scale, b, mean, variance): | |||
| bn_output = sync_bn(x, scale, b, mean, variance) | |||
| output = tuple_getitem(bn_output, 0) | |||
| return output | |||
| @fns | |||
| def after(x, scale, b, mean, variance): | |||
| bn_training_reduce_output = bn_training_reduce(x) | |||
| bn_training_reduce_output0 = tuple_getitem(bn_training_reduce_output, 0) | |||
| bn_training_reduce_output1 = tuple_getitem(bn_training_reduce_output, 1) | |||
| allreduce_output0 = allreduce(bn_training_reduce_output0) | |||
| allreduce_output1 = allreduce(bn_training_reduce_output1) | |||
| bn_training_update_input1 = mul(allreduce_output0, mul_value) | |||
| bn_training_update_input2 = mul(allreduce_output1, mul_value) | |||
| bn_training_update_output = bn_training_update(x, bn_training_update_input1, bn_training_update_input2, | |||
| scale, b, mean, variance) | |||
| output = tuple_getitem(bn_training_update_output, 0) | |||
| return make_tuple(output) | |||
| return fns[tag] | |||
| @@ -1755,6 +1755,16 @@ test_case_nn_ops = [ | |||
| 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], | |||
| 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], | |||
| 'skip': ['backward']}), | |||
| ('SyncBatchNorm', { | |||
| 'block': inner.SyncBatchNorm(), | |||
| 'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]], | |||
| 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], | |||
| 'skip': []}), | |||
| ('SyncBatchNormGrad', { | |||
| 'block': G.SyncBatchNormGrad(), | |||
| 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], | |||
| 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], | |||
| 'skip': ['backward']}), | |||
| ('TopK', { | |||
| 'block': P.TopK(), | |||
| 'desc_const': [5], | |||