From: @liu_xiao_93
Reviewed-by: @liangchenghui, @wuxuejian
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
@@ -64,6 +64,8 @@
 #include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h"
 #include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
+#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h"
+#include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h"
 #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
 #include "backend/optimizer/ascend/format_type/insert_trans_op.h"
@@ -276,6 +278,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
+  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BatchNorm3D>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BatchNorm3DGRAD>());
   ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
   ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<SyncBnSplit>());
@@ -321,6 +325,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
   ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
   ir_fusion_pm->AddPass(std::make_shared<SplitVFission>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BatchNorm3D>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BatchNorm3DGRAD>());
   ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
   ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
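Note on ordering: in both pass managers the two new conversion passes are registered ahead of BnSplit/BnGradSplit, so a 5-D NCDHW BatchNorm is renamed to BatchNorm3D before the 2-D split passes can match it. A toy Python sketch of why registration order matters (illustrative only; the node dicts and pass functions here are made up, not MindSpore APIs):

    def convert_to_bn3d(node):
        # Fires only for a 5-D NCDHW BatchNorm, mirroring NeedFusion below.
        if node["op"] == "BatchNorm" and node["format"] == "NCDHW" and len(node["shape"]) == 5:
            return dict(node, op="BatchNorm3D")
        return node

    def bn_split(node):
        # Stand-in for the 2-D BnSplit pass; it no longer matches once the
        # node has been renamed to BatchNorm3D.
        if node["op"] == "BatchNorm":
            return dict(node, op="BNTrainingReduce+BNTrainingUpdate")
        return node

    node = {"op": "BatchNorm", "format": "NCDHW", "shape": (16, 3, 10, 32, 32)}
    for ir_pass in (convert_to_bn3d, bn_split):  # registration order
        node = ir_pass(node)
    print(node["op"])  # BatchNorm3D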
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h"
+#include <memory>
+#include <string>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "base/core_ops.h"
+#include "abstract/abstract_value.h"
+#include "backend/optimizer/common/helper.h"
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kBN3DGradInputXIndex = 2;
+CNodePtr CreateBatchNorm3DGrad(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(batchnorm_grad);
+  auto prim = std::make_shared<Primitive>(kBatchNorm3DGradOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
+  for (size_t i = 1; i < batchnorm_grad->size(); ++i) {
+    inputs.push_back(batchnorm_grad->input(i));
+  }
+  auto new_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(batchnorm_grad->scope());
+  new_node->set_abstract(batchnorm_grad->abstract());
+  AnfAlgo::CopyNodeAttrs(batchnorm_grad, new_node);
+  return new_node;
+}
+bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) {
+  MS_EXCEPTION_IF_NULL(batchnorm_grad);
+  if (AnfAlgo::GetInputTensorNum(batchnorm_grad) < kBNGradInputTensorNum) {
+    MS_LOG(INFO) << "BatchNormGrad's input number is less than " << kBNGradInputTensorNum;
+    return false;
+  }
+  auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm_grad, kAttrFormat);
+  const auto &ori_inputs = batchnorm_grad->inputs();
+  auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3DGradInputXIndex], 0);
+  if (format != kOpFormat_NCDHW || x_shape.size() != 5) {
+    MS_LOG(INFO) << "Fusion applies only when format is NCDHW and the input of BatchNormGrad is 5-D, but format is: "
+                 << format << ", size of x_shape is: " << x_shape.size();
+    return false;
+  }
+  return true;
+}
+}  // namespace
+const BaseRef BatchNormGrad2BatchNorm3DGRAD::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  VectorRef pattern({prim::kPrimBatchNormGrad, Xs});
+  return pattern;
+}
+const AnfNodePtr BatchNormGrad2BatchNorm3DGRAD::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                        const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode_bn_grad = node->cast<CNodePtr>();
+  if (!NeedFusion(graph, cnode_bn_grad)) {
+    return nullptr;
+  }
+  auto bn_3d_grad = CreateBatchNorm3DGrad(graph, cnode_bn_grad);
+  TransferDepend(cnode_bn_grad, graph, bn_3d_grad);
+  return bn_3d_grad;
+}
+}  // namespace opt
+}  // namespace mindspore
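The gate in NeedFusion above reduces to two checks: the node carries enough tensor inputs, and x (the second tensor input, indexed by kBN3DGradInputXIndex) is a 5-D NCDHW tensor. A plain-Python mirror of the predicate (illustrative only; taking kBNGradInputTensorNum as 5 is an assumption here):

    def need_bn3d_grad_fusion(fmt, x_shape, num_inputs, k_bn_grad_inputs=5):
        # Mirrors the C++ NeedFusion: enough inputs, NCDHW format, 5-D x.
        if num_inputs < k_bn_grad_inputs:
            return False
        return fmt == "NCDHW" and len(x_shape) == 5

    print(need_bn3d_grad_fusion("NCDHW", (16, 3, 10, 32, 32), 5))  # True
    print(need_bn3d_grad_fusion("NCHW", (16, 3, 32, 32), 5))       # False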
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+namespace mindspore {
+namespace opt {
+class BatchNormGrad2BatchNorm3DGRAD : public PatternProcessPass {
+ public:
+  explicit BatchNormGrad2BatchNorm3DGRAD(bool multigraph = true)
+      : PatternProcessPass("batchnorm_grad_to_batchnorm3d_grad", multigraph) {}
+  ~BatchNormGrad2BatchNorm3DGRAD() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h"
+#include <memory>
+#include <string>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "base/core_ops.h"
+#include "abstract/abstract_value.h"
+#include "backend/optimizer/common/helper.h"
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kBN3InputXIndex = 1;
+constexpr size_t kBn3DTrainInputTensorNum = 3;
+CNodePtr CreateBatchNorm3D(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(batchnorm);
+  auto prim = std::make_shared<Primitive>(kBatchNorm3DOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
+  auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining);
+  for (size_t i = 1; i < batchnorm->size(); ++i) {
+    // In training mode, BatchNorm3D consumes only x, scale and offset;
+    // drop BatchNorm's trailing mean/variance inputs.
+    if (is_training && i > kBn3DTrainInputTensorNum) {
+      continue;
+    } else {
+      inputs.push_back(batchnorm->input(i));
+    }
+  }
+  auto new_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(batchnorm->scope());
+  new_node->set_abstract(batchnorm->abstract());
+  AnfAlgo::CopyNodeAttrs(batchnorm, new_node);
+  return new_node;
+}
+bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
+  MS_EXCEPTION_IF_NULL(batchnorm);
+  if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) {
+    MS_LOG(INFO) << "BatchNorm has no is_training attr.";
+    return false;
+  }
+  auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining);
+  auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm, kAttrFormat);
+  if (is_training && format == kOpFormat_NCDHW) {
+    if (AnfAlgo::GetInputTensorNum(batchnorm) < kBn3DTrainInputTensorNum) {
+      MS_LOG(INFO) << "When data format is NCDHW and is_training is true, BatchNorm's input number is less than "
+                   << kBn3DTrainInputTensorNum;
+      return false;
+    }
+  } else {
+    if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) {
+      MS_LOG(INFO) << "BatchNorm's input number is less than " << kBnInputTensorNum;
+      return false;
+    }
+  }
+  const auto &ori_inputs = batchnorm->inputs();
+  auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3InputXIndex], 0);
+  if (format != kOpFormat_NCDHW || x_shape.size() != 5) {
+    MS_LOG(INFO) << "Fusion applies only when format is NCDHW and the input of BatchNorm is 5-D, but format is: "
+                 << format << ", size of x_shape is: " << x_shape.size();
+    return false;
+  }
+  return true;
+}
+}  // namespace
+const BaseRef BatchNorm2BatchNorm3D::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  VectorRef pattern({prim::kPrimBatchNorm, Xs});
+  return pattern;
+}
+const AnfNodePtr BatchNorm2BatchNorm3D::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode_bn = node->cast<CNodePtr>();
+  if (!NeedFusion(graph, cnode_bn)) {
+    return nullptr;
+  }
+  auto bn_3d = CreateBatchNorm3D(graph, cnode_bn);
+  TransferDepend(cnode_bn, graph, bn_3d);
+  return bn_3d;
+}
+}  // namespace opt
+}  // namespace mindspore
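The one subtlety in CreateBatchNorm3D above is the input truncation: in training mode the new BatchNorm3D node keeps only the first kBn3DTrainInputTensorNum (3) tensor inputs. A plain-Python mirror of that rule (illustrative only):

    def bn3d_inputs(orig_inputs, is_training, k_train_inputs=3):
        # Training keeps x/scale/offset only; inference keeps all five
        # inputs (x, scale, offset, mean, variance).
        return list(orig_inputs[:k_train_inputs]) if is_training else list(orig_inputs)

    print(bn3d_inputs(["x", "scale", "offset", "mean", "variance"], True))
    # ['x', 'scale', 'offset']
    print(bn3d_inputs(["x", "scale", "offset", "mean", "variance"], False))
    # ['x', 'scale', 'offset', 'mean', 'variance']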
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+namespace mindspore {
+namespace opt {
+class BatchNorm2BatchNorm3D : public PatternProcessPass {
+ public:
+  explicit BatchNorm2BatchNorm3D(bool multigraph = true) : PatternProcessPass("batchnorm_to_batchnorm3d", multigraph) {}
+  ~BatchNorm2BatchNorm3D() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
@@ -142,6 +142,8 @@ constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay";
 constexpr auto kAdamApplyOneWithDecayAssignOpName = "AdamApplyOneWithDecayAssign";
 constexpr auto kBatchNormGradOpName = "BatchNormGrad";
 constexpr auto kBNInferOpName = "BNInfer";
+constexpr auto kBatchNorm3DOpName = "BatchNorm3D";
+constexpr auto kBatchNorm3DGradOpName = "BatchNorm3DGrad";
 constexpr auto kAdamApplyOneOpName = "AdamApplyOne";
 constexpr auto kAdamApplyOneAssignOpName = "AdamApplyOneAssign";
 constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad";
@@ -31,7 +31,8 @@ from mindspore.communication import management
 from mindspore.ops import _selected_ops
 from ..cell import Cell
-__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d']
+__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
+           'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d']
 SYNC_BN_GROUP_NAME = ""
@@ -60,13 +61,16 @@ class _BatchNorm(Cell):
         if momentum < 0 or momentum > 1:
             raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
+        self.input_dims = input_dims
+        if self.input_dims == "3d":
+            self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
+        else:
+            self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.use_batch_statistics = use_batch_statistics
         self.num_features = num_features
         self.eps = eps
-        self.input_dims = input_dims
         self.moving_mean = Parameter(initializer(
             moving_mean_init, num_features), name="mean", requires_grad=False)
         self.moving_variance = Parameter(initializer(
@@ -134,7 +138,8 @@ class _BatchNorm(Cell):
         if self._target == "Ascend":
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps,
-                                        momentum=self.momentum)
+                                        momentum=self.momentum,
+                                        data_format=self.format)
         if self._target == "GPU":
             self.bn_train = P.FusedBatchNormEx(mode=1,
                                                epsilon=self.eps,
@@ -220,11 +225,14 @@ def _shape_check(in_shape):
 @constexpr
 def _shape_check_bn(in_shape, in_dims):
     """check input dims of batch norm."""
     dim = len(in_shape)
     if in_dims == '1d' and dim != 2:
         raise ValueError("The input must has 2 dims.")
     if in_dims == '2d' and dim != 4:
         raise ValueError("The input must has 4 dims.")
+    if in_dims == '3d' and dim != 5:
+        raise ValueError("The input must have 5 dims.")
     if in_dims == 'both' and dim != 2 and dim != 4:
         raise ValueError("The input must has 2 dims or 4 dims.")
@@ -445,7 +453,7 @@ def _check_3d_shape(input_shape):
         raise ValueError("For BatchNorm3d, input data must be 5-dimensional.")
-class BatchNorm3d(Cell):
+class BatchNorm3d(_BatchNorm):
     r"""
     Batch normalization layer over a 5D input.
@@ -489,8 +497,15 @@
     Outputs:
         Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
+    Raises:
+        TypeError: If `num_features` is not an int.
+        TypeError: If `eps` is not a float.
+        ValueError: If `num_features` is less than 1.
+        ValueError: If `momentum` is not in range [0, 1].
+        ValueError: If `data_format` is not 'NCDHW'.
     Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
+        ``Ascend``
     Examples:
         >>> net = nn.BatchNorm3d(num_features=3)
@@ -512,27 +527,21 @@
                  moving_var_init='ones',
                  use_batch_statistics=None,
                  data_format='NCDHW'):
-        super(BatchNorm3d, self).__init__()
-        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
-        self.reshape = P.Reshape()
-        self.bn2d = BatchNorm2d(num_features=num_features,
-                                eps=eps,
-                                momentum=momentum,
-                                affine=affine,
-                                gamma_init=gamma_init,
-                                beta_init=beta_init,
-                                moving_mean_init=moving_mean_init,
-                                moving_var_init=moving_var_init,
-                                use_batch_statistics=use_batch_statistics,
-                                data_format="NCHW")
+        super(BatchNorm3d, self).__init__(num_features,
+                                          eps,
+                                          momentum,
+                                          affine,
+                                          gamma_init,
+                                          beta_init,
+                                          moving_mean_init,
+                                          moving_var_init,
+                                          use_batch_statistics,
+                                          input_dims='3d',
+                                          data_format=data_format)
-    def construct(self, input_x):
-        x_shape = F.shape(input_x)
-        _check_3d_shape(x_shape)
-        input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
-        bn2d_out = self.bn2d(input_x)
-        bn3d_out = self.reshape(bn2d_out, x_shape)
-        return bn3d_out
+    def _check_data_dim(self, x):
+        if x.ndim != 5:
+            pass
 class GlobalBatchNorm(_BatchNorm):
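With this change BatchNorm3d runs a genuine 5-D kernel through _BatchNorm instead of reshaping into BatchNorm2d. A minimal usage sketch, assuming an Ascend device target (it mirrors the docstring example above):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    net = nn.BatchNorm3d(num_features=3)
    x = Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32))
    output = net(x)
    print(output.shape)
    # (16, 3, 10, 32, 32)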
@@ -712,7 +712,7 @@ def get_bprop_instance_norm(self):
 def get_bprop_batch_norm(self):
     """Grad definition for `BatchNorm` operation."""
     is_training = self.is_training
-    input_grad = G.BatchNormGrad(is_training, self.epsilon)
+    input_grad = G.BatchNormGrad(is_training, self.epsilon, self.data_format)
     def bprop(x, scale, b, mean, variance, out, dout):
         if is_training:
@@ -48,6 +48,8 @@ from .assign_sub import _assign_sub_tbe
 from .batch_matmul import _batch_matmul_tbe
 from .batchnorm import _batch_norm_tbe
 from .batchnorm_grad import _batch_norm_grad_tbe
+from .batchnorm3d import _batch_norm3d_tbe
+from .batchnorm3d_grad import _batch_norm3d_grad_tbe
 from .bias_add import _bias_add_tbe
 from .bias_add_grad import _bias_add_grad_tbe
 from .cast import _cast_tbe
@@ -0,0 +1,51 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""BatchNorm3D op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+batch_norm3d_op_info = TBERegOp("BatchNorm3D") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("batch_norm3d.so") \
+    .compute_cost(10) \
+    .kernel_name("batch_norm3d") \
+    .partial_flag(True) \
+    .attr("epsilon", "optional", "float", "all") \
+    .attr("format", "optional", "str", "all") \
+    .attr("is_training", "optional", "bool", "all") \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "scale", False, "required", "all", reshape_type="C") \
+    .input(2, "offset", False, "required", "all", reshape_type="C") \
+    .input(3, "mean", False, "optional", "all", reshape_type="C") \
+    .input(4, "variance", False, "optional", "all", reshape_type="C") \
+    .output(0, "y", False, "required", "all") \
+    .output(1, "batch_mean", False, "required", "all") \
+    .output(2, "batch_variance", False, "required", "all") \
+    .output(3, "reserve_space_1", False, "optional", "all") \
+    .output(4, "reserve_space_2", False, "optional", "all") \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
+                  DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
+                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
+                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
+                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
+    .get_op_info()
+@op_info_register(batch_norm3d_op_info)
+def _batch_norm3d_tbe():
+    """BatchNorm3D TBE register"""
+    return
@@ -0,0 +1,51 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""BatchNorm3DGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+batch_norm3d_grad_op_info = TBERegOp("BatchNorm3DGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("batch_norm3d_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("batch_norm3d_grad") \
+    .partial_flag(True) \
+    .attr("epsilon", "optional", "float", "all") \
+    .attr("format", "optional", "str", "all") \
+    .attr("is_training", "optional", "bool", "all") \
+    .input(0, "y_backprop", False, "required", "all") \
+    .input(1, "x", False, "required", "all") \
+    .input(2, "scale", False, "required", "all", reshape_type="C") \
+    .input(3, "reserve_space_1", False, "optional", "all") \
+    .input(4, "reserve_space_2", False, "optional", "all") \
+    .output(0, "x_backprop", False, "required", "all") \
+    .output(1, "scale_backprop", False, "required", "all") \
+    .output(2, "offset_backprop", False, "required", "all") \
+    .output(3, "reserve_space_4", False, "optional", "all") \
+    .output(4, "reserve_space_5", False, "optional", "all") \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
+                  DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
+                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
+                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
+                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
+    .get_op_info()
+@op_info_register(batch_norm3d_grad_op_info)
+def _batch_norm3d_grad_tbe():
+    """BatchNorm3DGrad TBE register"""
+    return
@@ -192,10 +192,10 @@ class BatchNormGrad(PrimitiveWithInfer):
     """Performs grad of BatchNorm operation."""
     @prim_attr_register
-    def __init__(self, is_training=False, epsilon=1e-5):
+    def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'):
         self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
         self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
-        self.add_prim_attr('data_format', "NCHW")
+        self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name)
     def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
         validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
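With data_format now a constructor argument, the bprop shown earlier can forward the primitive's layout into the grad op, and an NCDHW BatchNormGrad is then rewritten to BatchNorm3DGrad by the IR fusion pass. A hedged standalone sketch of the constructor call (G here stands for mindspore.ops.operations._grad_ops, as in the bprop file's imports):

    from mindspore.ops.operations import _grad_ops as G

    input_grad = G.BatchNormGrad(is_training=True, epsilon=1e-5, data_format="NCDHW")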
@@ -1326,18 +1326,20 @@ class BatchNorm(PrimitiveWithInfer):
         validator.check_value_type('is_training', is_training, (bool,), self.name)
         validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
         validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
+        if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
+            raise ValueError("NCDHW format is only supported on the Ascend target.")
         self.add_prim_attr('data_format', self.format)
         self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
                                 outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
     def infer_shape(self, input_x, scale, bias, mean, variance):
-        input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2])
+        input_x_channel = input_x[-1] if self.format == "NHWC" else input_x[1]
         validator.check_equal_int(len(scale), 1, "scale rank", self.name)
         validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
-        validator.check("scale shape[0]", scale[0], "input_x channel", input_shape_norm[1], Rel.EQ, self.name)
+        validator.check("scale shape[0]", scale[0], "input_x channel", input_x_channel, Rel.EQ, self.name)
         if not self.is_training:
             validator.check_equal_int(len(mean), 1, "mean rank", self.name)
             validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
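Putting the operator-level pieces together: P.BatchNorm now accepts NCDHW, but only on an Ascend target, per the validation added above. A minimal sketch under that assumption:

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    context.set_context(device_target="Ascend")  # NCDHW is rejected on other targets
    bn = P.BatchNorm(is_training=True, epsilon=1e-5, data_format="NCDHW")
    x = Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32))
    scale = Tensor(np.ones([3]).astype(np.float32))
    offset = Tensor(np.zeros([3]).astype(np.float32))
    mean = Tensor(np.zeros([3]).astype(np.float32))
    variance = Tensor(np.ones([3]).astype(np.float32))
    y, batch_mean, batch_variance, _, _ = bn(x, scale, offset, mean, variance)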