diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index aae9a3aa50..7ddcb3dcf5 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -64,6 +64,8 @@
 #include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h"
 #include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
+#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h"
+#include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h"
 #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
 #include "backend/optimizer/ascend/format_type/insert_trans_op.h"
@@ -276,6 +278,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
+  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BatchNorm3D>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BatchNorm3DGRAD>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
@@ -321,6 +325,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BatchNorm3D>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BatchNorm3DGRAD>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.cc
new file mode 100644
index 0000000000..a03cee87bb
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h"
+#include <memory>
+#include <vector>
+#include <string>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "base/core_ops.h"
+#include "abstract/abstract_value.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kBN3DGradInputXIndex = 2;
+CNodePtr CreateBatchNorm3DGrad(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(batchnorm_grad);
+  auto prim = std::make_shared<Primitive>(kBatchNorm3DGradOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
+  for (size_t i = 1; i < batchnorm_grad->size(); ++i) {
+    inputs.push_back(batchnorm_grad->input(i));
+  }
+  auto new_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(batchnorm_grad->scope());
+  new_node->set_abstract(batchnorm_grad->abstract());
+  AnfAlgo::CopyNodeAttrs(batchnorm_grad, new_node);
+  return new_node;
+}
+
+bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) {
+  MS_EXCEPTION_IF_NULL(batchnorm_grad);
+  if (AnfAlgo::GetInputTensorNum(batchnorm_grad) < kBNGradInputTensorNum) {
+    MS_LOG(INFO) << "BatchNormGrad's input less than " << kBNGradInputTensorNum;
+    return false;
+  }
+  auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm_grad, kAttrFormat);
+  const auto &ori_inputs = batchnorm_grad->inputs();
+  auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3DGradInputXIndex], 0);
+  if (format != kOpFormat_NCDHW || x_shape.size() != 5) {
+    MS_LOG(INFO) << "Only format is NCDHW and the input dim of BatchNormGrad is 5, then do fusion. But format is: "
+                 << format << ", size of x_shape is: " << x_shape.size();
+    return false;
+  }
+  return true;
+}
+}  // namespace
+
+const BaseRef BatchNormGrad2BatchNorm3DGRAD::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  VectorRef pattern({prim::kPrimBatchNormGrad, Xs});
+  return pattern;
+}
+
+const AnfNodePtr BatchNormGrad2BatchNorm3DGRAD::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                        const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode_bn_grad = node->cast<CNodePtr>();
+  if (!NeedFusion(graph, cnode_bn_grad)) {
+    return nullptr;
+  }
+  auto bn_3d_grad = CreateBatchNorm3DGrad(graph, cnode_bn_grad);
+  TransferDepend(cnode_bn_grad, graph, bn_3d_grad);
+  return bn_3d_grad;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h
new file mode 100644
index 0000000000..e2f7530fba
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class BatchNormGrad2BatchNorm3DGRAD : public PatternProcessPass {
+ public:
+  explicit BatchNormGrad2BatchNorm3DGRAD(bool multigraph = true)
+      : PatternProcessPass("batchnorm_grad_to_batchnorm3d_grad", multigraph) {}
+  ~BatchNormGrad2BatchNorm3DGRAD() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.cc
new file mode 100644
index 0000000000..a01f752424
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.cc
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h"
+#include <memory>
+#include <vector>
+#include <string>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "base/core_ops.h"
+#include "abstract/abstract_value.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kBN3InputXIndex = 1;
+constexpr size_t kBn3DTrainInputTensorNum = 3;
+CNodePtr CreateBatchNorm3D(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(batchnorm);
+  auto prim = std::make_shared<Primitive>(kBatchNorm3DOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
+  auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining);
+  for (size_t i = 1; i < batchnorm->size(); ++i) {
+    if (is_training && i > kBn3DTrainInputTensorNum) {
+      continue;
+    } else {
+      inputs.push_back(batchnorm->input(i));
+    }
+  }
+  auto new_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(batchnorm->scope());
+  new_node->set_abstract(batchnorm->abstract());
+  AnfAlgo::CopyNodeAttrs(batchnorm, new_node);
+  return new_node;
+}
+
+bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
+  MS_EXCEPTION_IF_NULL(batchnorm);
+  if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) {
+    MS_LOG(INFO) << "BatchNorm has no is_training attr.";
+    return false;
+  }
+  auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining);
+  auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm, kAttrFormat);
+  if (is_training && format == kOpFormat_NCDHW) {
+    if (AnfAlgo::GetInputTensorNum(batchnorm) < kBn3DTrainInputTensorNum) {
+      MS_LOG(INFO) << "When data format is NCDHW and is_training is true, BatchNorm's input less than "
+                   << kBn3DTrainInputTensorNum;
+      return false;
+    }
+  } else {
+    if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) {
+      MS_LOG(INFO) << "BatchNorm's input less than " << kBnInputTensorNum;
+      return false;
+    }
+  }
+  const auto &ori_inputs = batchnorm->inputs();
+  auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3InputXIndex], 0);
+  if (format != kOpFormat_NCDHW || x_shape.size() != 5) {
+    MS_LOG(INFO) << "Only format is NCDHW and the input dim of BatchNorm is 5, then do fusion. But format is: "
+                 << format << ", size of x_shape is: " << x_shape.size();
+    return false;
+  }
+  return true;
+}
+}  // namespace
+
+const BaseRef BatchNorm2BatchNorm3D::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  VectorRef pattern({prim::kPrimBatchNorm, Xs});
+  return pattern;
+}
+
+const AnfNodePtr BatchNorm2BatchNorm3D::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode_bn = node->cast<CNodePtr>();
+  if (!NeedFusion(graph, cnode_bn)) {
+    return nullptr;
+  }
+  auto bn_3d = CreateBatchNorm3D(graph, cnode_bn);
+  TransferDepend(cnode_bn, graph, bn_3d);
+  return bn_3d;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h
new file mode 100644
index 0000000000..19203ac484
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class BatchNorm2BatchNorm3D : public PatternProcessPass {
+ public:
+  explicit BatchNorm2BatchNorm3D(bool multigraph = true) : PatternProcessPass("batchnorm_to_batchnorm3d", multigraph) {}
+  ~BatchNorm2BatchNorm3D() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 9532c3ff3c..5aea2b1c2b 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -142,6 +142,8 @@ constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay";
 constexpr auto kAdamApplyOneWithDecayAssignOpName = "AdamApplyOneWithDecayAssign";
 constexpr auto kBatchNormGradOpName = "BatchNormGrad";
 constexpr auto kBNInferOpName = "BNInfer";
+constexpr auto kBatchNorm3DOpName = "BatchNorm3D";
+constexpr auto kBatchNorm3DGradOpName = "BatchNorm3DGrad";
 constexpr auto kAdamApplyOneOpName = "AdamApplyOne";
 constexpr auto kAdamApplyOneAssignOpName = "AdamApplyOneAssign";
 constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad";
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 954d8d29fa..340521e792 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -31,7 +31,8 @@
 from mindspore.communication import management
 from mindspore.ops import _selected_ops
 from ..cell import Cell
-__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d']
+__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
+           'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d']
 
 SYNC_BN_GROUP_NAME = ""
 
@@ -60,13 +61,16 @@ class _BatchNorm(Cell):
         if momentum < 0 or momentum > 1:
             raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
+        self.input_dims = input_dims
+        if self.input_dims == "3d":
+            self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
+        else:
+            self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.use_batch_statistics = use_batch_statistics
         self.num_features = num_features
         self.eps = eps
-        self.input_dims = input_dims
         self.moving_mean = Parameter(initializer(
             moving_mean_init, num_features), name="mean", requires_grad=False)
         self.moving_variance = Parameter(initializer(
@@ -134,7 +138,8 @@ class _BatchNorm(Cell):
         if self._target == "Ascend":
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps,
-                                        momentum=self.momentum)
+                                        momentum=self.momentum,
+                                        data_format=self.format)
         if self._target == "GPU":
             self.bn_train = P.FusedBatchNormEx(mode=1,
                                                epsilon=self.eps,
@@ -220,11 +225,14 @@ def _shape_check(in_shape):
 
 @constexpr
 def _shape_check_bn(in_shape, in_dims):
+    """check input dims of batch norm."""
     dim = len(in_shape)
     if in_dims == '1d' and dim != 2:
         raise ValueError("The input must has 2 dims.")
     if in_dims == '2d' and dim != 4:
         raise ValueError("The input must has 4 dims.")
+    if in_dims == '3d' and dim != 5:
+        raise ValueError("The input must has 5 dims.")
     if in_dims == 'both' and dim != 2 and dim != 4:
         raise ValueError("The input must has 2 dims or 4 dims.")
 
@@ -445,7 +453,7 @@ def _check_3d_shape(input_shape):
         raise ValueError("For BatchNorm3d, input data must be 5-dimensional.")
 
 
-class BatchNorm3d(Cell):
+class BatchNorm3d(_BatchNorm):
     r"""
     Batch normalization layer over a 5D input.
 
@@ -489,8 +497,15 @@
     Outputs:
         Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out},H_{out}, W_{out})`.
 
+    Raises:
+        TypeError: If `num_features` is not an int.
+        TypeError: If `eps` is not a float.
+        ValueError: If `num_features` is less than 1.
+        ValueError: If `momentum` is not in range [0, 1].
+        ValueError: If `data_format` is not 'NCDHW'.
+
     Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
+        ``Ascend``
 
     Examples:
         >>> net = nn.BatchNorm3d(num_features=3)
@@ -512,27 +527,21 @@
                  moving_var_init='ones',
                  use_batch_statistics=None,
                  data_format='NCDHW'):
-        super(BatchNorm3d, self).__init__()
-        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
-        self.reshape = P.Reshape()
-        self.bn2d = BatchNorm2d(num_features=num_features,
-                                eps=eps,
-                                momentum=momentum,
-                                affine=affine,
-                                gamma_init=gamma_init,
-                                beta_init=beta_init,
-                                moving_mean_init=moving_mean_init,
-                                moving_var_init=moving_var_init,
-                                use_batch_statistics=use_batch_statistics,
-                                data_format="NCHW")
+        super(BatchNorm3d, self).__init__(num_features,
+                                          eps,
+                                          momentum,
+                                          affine,
+                                          gamma_init,
+                                          beta_init,
+                                          moving_mean_init,
+                                          moving_var_init,
+                                          use_batch_statistics,
+                                          input_dims='3d',
+                                          data_format=data_format)
 
-    def construct(self, input_x):
-        x_shape = F.shape(input_x)
-        _check_3d_shape(x_shape)
-        input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
-        bn2d_out = self.bn2d(input_x)
-        bn3d_out = self.reshape(bn2d_out, x_shape)
-        return bn3d_out
+    def _check_data_dim(self, x):
+        if x.ndim != 5:
+            pass
 
 
 class GlobalBatchNorm(_BatchNorm):
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index b1a64ff114..001c27b4d7 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -712,7 +712,7 @@ def get_bprop_instance_norm(self):
 def get_bprop_batch_norm(self):
     """Grad definition for `BatchNorm` operation."""
     is_training = self.is_training
-    input_grad = G.BatchNormGrad(is_training, self.epsilon)
+    input_grad = G.BatchNormGrad(is_training, self.epsilon, self.data_format)
 
     def bprop(x, scale, b, mean, variance, out, dout):
         if is_training:
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 84458a4903..700903ce0a 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -48,6 +48,8 @@ from .assign_sub import _assign_sub_tbe
 from .batch_matmul import _batch_matmul_tbe
 from .batchnorm import _batch_norm_tbe
 from .batchnorm_grad import _batch_norm_grad_tbe
+from .batchnorm3d import _batch_norm3d_tbe
+from .batchnorm3d_grad import _batch_norm3d_grad_tbe
 from .bias_add import _bias_add_tbe
 from .bias_add_grad import _bias_add_grad_tbe
 from .cast import _cast_tbe
diff --git a/mindspore/ops/_op_impl/tbe/batchnorm3d.py b/mindspore/ops/_op_impl/tbe/batchnorm3d.py
new file mode 100644
index 0000000000..bd66a35352
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/batchnorm3d.py
@@ -0,0 +1,51 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +"""BatchNorm3D op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +batch_norm3d_op_info = TBERegOp("BatchNorm3D") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("batch_norm3d.so") \ + .compute_cost(10) \ + .kernel_name("batch_norm3d") \ + .partial_flag(True) \ + .attr("epsilon", "optional", "float", "all") \ + .attr("format", "optional", "str", "all") \ + .attr("is_training", "optional", "bool", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "scale", False, "required", "all", reshape_type="C") \ + .input(2, "offset", False, "required", "all", reshape_type="C") \ + .input(3, "mean", False, "optional", "all", reshape_type="C") \ + .input(4, "variance", False, "optional", "all", reshape_type="C") \ + .output(0, "y", False, "required", "all") \ + .output(1, "batch_mean", False, "required", "all") \ + .output(2, "batch_variance", False, "required", "all") \ + .output(3, "reserve_space_1", False, "optional", "all") \ + .output(4, "reserve_space_2", False, "optional", "all") \ + .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, + DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, + DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ + .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, + DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, + DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ + .get_op_info() + + +@op_info_register(batch_norm3d_op_info) +def _batch_norm3d_tbe(): + """BatchNorm3D TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/batchnorm3d_grad.py b/mindspore/ops/_op_impl/tbe/batchnorm3d_grad.py new file mode 100644 index 0000000000..57019856e2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/batchnorm3d_grad.py @@ -0,0 +1,51 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""BatchNorm3DGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +batch_norm3d_grad_op_info = TBERegOp("BatchNorm3DGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("batch_norm3d_grad.so") \ + .compute_cost(10) \ + .kernel_name("batch_norm3d_grad") \ + .partial_flag(True) \ + .attr("epsilon", "optional", "float", "all") \ + .attr("format", "optional", "str", "all") \ + .attr("is_training", "optional", "bool", "all") \ + .input(0, "y_backprop", False, "required", "all") \ + .input(1, "x", False, "required", "all") \ + .input(2, "scale", False, "required", "all", reshape_type="C") \ + .input(3, "reserve_space_1", False, "optional", "all") \ + .input(4, "reserve_space_2", False, "optional", "all") \ + .output(0, "x_backprop", False, "required", "all") \ + .output(1, "scale_backprop", False, "required", "all") \ + .output(2, "offset_backprop", False, "required", "all") \ + .output(3, "reserve_space_4", False, "optional", "all") \ + .output(4, "reserve_space_5", False, "optional", "all") \ + .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, + DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, + DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ + .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, + DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, + DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ + .get_op_info() + + +@op_info_register(batch_norm3d_grad_op_info) +def _batch_norm3d_grad_tbe(): + """BatchNorm3DGrad TBE register""" + return diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 5fa34591ef..e0253b82f0 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -192,10 +192,10 @@ class BatchNormGrad(PrimitiveWithInfer): """Performs grad of BatchNorm operation.""" @prim_attr_register - def __init__(self, is_training=False, epsilon=1e-5): + def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'): self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) - self.add_prim_attr('data_format', "NCHW") + self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name) def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index f3ec6f3c1c..f182bcc559 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1326,18 +1326,20 @@ class BatchNorm(PrimitiveWithInfer): validator.check_value_type('is_training', is_training, (bool,), self.name) validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) - self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + self.format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name) if context.get_context("device_target") != "GPU" and 
self.format == "NHWC": raise ValueError("NHWC format only support in GPU target.") + if context.get_context("device_target") != "Ascend" and self.format == "NCDHW": + raise ValueError("NCDHW format only support in Ascend target.") self.add_prim_attr('data_format', self.format) self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) def infer_shape(self, input_x, scale, bias, mean, variance): - input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2]) + input_x_channel = input_x[-1] if self.format == "NHWC" else input_x[1] validator.check_equal_int(len(scale), 1, "scale rank", self.name) validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) - validator.check("scale shape[0]", scale[0], "input_x channel", input_shape_norm[1], Rel.EQ, self.name) + validator.check("scale shape[0]", scale[0], "input_x channel", input_x_channel, Rel.EQ, self.name) if not self.is_training: validator.check_equal_int(len(mean), 1, "mean rank", self.name) validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
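
A minimal usage sketch of the layer this patch rewires, for reviewers trying it out locally. It is only a sketch: it assumes an Ascend device target is available (the patch restricts `BatchNorm3d` and the NCDHW format to Ascend), and the input values and shape are illustrative, mirroring the 5-D NCDHW requirement and the `num_features=3` example from the docstring above.

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context

# Assumption: running on an Ascend device; after this change BatchNorm3d
# dispatches through _BatchNorm with input_dims='3d' and data_format='NCDHW',
# so the backend can fuse it into the new BatchNorm3D / BatchNorm3DGrad kernels.
context.set_context(device_target="Ascend")

net = nn.BatchNorm3d(num_features=3)                       # C = 3 channels
x = Tensor(np.ones([1, 3, 2, 4, 4]), mindspore.float32)    # NCDHW, must be 5-D
out = net(x)
print(out.shape)  # (1, 3, 2, 4, 4)
```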