From: @liu_xiao_93 Reviewed-by: @liangchenghui Signed-off-by: @liangchenghuipull/13818/MERGE
| @@ -64,8 +64,6 @@ | |||
| #include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_trans_op.h" | |||
| @@ -278,8 +276,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| } | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BatchNorm3D>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BatchNorm3DGRAD>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SyncBnSplit>()); | |||
| @@ -325,8 +321,6 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne | |||
| auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); | |||
| ir_fusion_pm->AddPass(std::make_shared<SplitFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SplitVFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BatchNorm3D>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BatchNorm3DGRAD>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>()); | |||
| @@ -1,85 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "base/core_ops.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kBN3DGradInputXIndex = 2; | |||
| CNodePtr CreateBatchNorm3DGrad(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(batchnorm_grad); | |||
| auto prim = std::make_shared<Primitive>(kBatchNorm3DGradOpName); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim)}; | |||
| for (size_t i = 1; i < batchnorm_grad->size() - 1; ++i) { | |||
| inputs.push_back(batchnorm_grad->input(i)); | |||
| } | |||
| auto new_node = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| new_node->set_scope(batchnorm_grad->scope()); | |||
| new_node->set_abstract(batchnorm_grad->abstract()); | |||
| AnfAlgo::CopyNodeAttrs(batchnorm_grad, new_node); | |||
| return new_node; | |||
| } | |||
| bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) { | |||
| MS_EXCEPTION_IF_NULL(batchnorm_grad); | |||
| if (AnfAlgo::GetInputTensorNum(batchnorm_grad) < kBNGradInputTensorNum) { | |||
| MS_LOG(INFO) << "BatchNormGrad's input less than " << kBNGradInputTensorNum; | |||
| return false; | |||
| } | |||
| auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm_grad, kAttrFormat); | |||
| const auto &ori_inputs = batchnorm_grad->inputs(); | |||
| auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3DGradInputXIndex], 0); | |||
| if (format != kOpFormat_NCDHW || x_shape.size() != 5) { | |||
| MS_LOG(INFO) << "Only format is NCDHW and the input dim of BatchNormGrad is 5, then do fusion. But format is: " | |||
| << format << ", size of x_shape is: " << x_shape.size(); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace | |||
| const BaseRef BatchNormGrad2BatchNorm3DGRAD::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| VectorRef pattern({prim::kPrimBatchNormGrad, Xs}); | |||
| return pattern; | |||
| } | |||
| const AnfNodePtr BatchNormGrad2BatchNorm3DGRAD::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode_bn_grad = node->cast<CNodePtr>(); | |||
| if (!NeedFusion(graph, cnode_bn_grad)) { | |||
| return nullptr; | |||
| } | |||
| auto bn_3d_grad = CreateBatchNorm3DGrad(graph, cnode_bn_grad); | |||
| TransferDepend(cnode_bn_grad, graph, bn_3d_grad); | |||
| return bn_3d_grad; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,34 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class BatchNormGrad2BatchNorm3DGRAD : public PatternProcessPass { | |||
| public: | |||
| explicit BatchNormGrad2BatchNorm3DGRAD(bool multigraph = true) | |||
| : PatternProcessPass("batchnorm_grad_to_batchnorm3d_grad", multigraph) {} | |||
| ~BatchNormGrad2BatchNorm3DGRAD() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_ | |||
| @@ -1,104 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "base/core_ops.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kBN3InputXIndex = 1; | |||
| constexpr size_t kBn3DTrainInputTensorNum = 3; | |||
| CNodePtr CreateBatchNorm3D(const FuncGraphPtr &graph, const CNodePtr &batchnorm) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(batchnorm); | |||
| auto prim = std::make_shared<Primitive>(kBatchNorm3DOpName); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim)}; | |||
| auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining); | |||
| for (size_t i = 1; i < batchnorm->size(); ++i) { | |||
| if (is_training && i > kBn3DTrainInputTensorNum) { | |||
| continue; | |||
| } else { | |||
| inputs.push_back(batchnorm->input(i)); | |||
| } | |||
| } | |||
| auto new_node = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| new_node->set_scope(batchnorm->scope()); | |||
| new_node->set_abstract(batchnorm->abstract()); | |||
| AnfAlgo::CopyNodeAttrs(batchnorm, new_node); | |||
| return new_node; | |||
| } | |||
| bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm) { | |||
| MS_EXCEPTION_IF_NULL(batchnorm); | |||
| if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) { | |||
| MS_LOG(INFO) << "BatchNorm has no is_training attr."; | |||
| return false; | |||
| } | |||
| auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining); | |||
| auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm, kAttrFormat); | |||
| if (is_training && format == kOpFormat_NCDHW) { | |||
| if (AnfAlgo::GetInputTensorNum(batchnorm) < kBn3DTrainInputTensorNum) { | |||
| MS_LOG(INFO) << "When data format is NCDHW and is_training is true, BatchNorm's input less than " | |||
| << kBn3DTrainInputTensorNum; | |||
| return false; | |||
| } | |||
| } else { | |||
| if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) { | |||
| MS_LOG(INFO) << "BatchNorm's input less than " << kBnInputTensorNum; | |||
| return false; | |||
| } | |||
| } | |||
| const auto &ori_inputs = batchnorm->inputs(); | |||
| auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3InputXIndex], 0); | |||
| if (format != kOpFormat_NCDHW || x_shape.size() != 5) { | |||
| MS_LOG(INFO) << "Only format is NCDHW and the input dim of BatchNorm is 5, then do fusion. But format is: " | |||
| << format << ", size of x_shape is: " << x_shape.size(); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace | |||
| const BaseRef BatchNorm2BatchNorm3D::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| VectorRef pattern({prim::kPrimBatchNorm, Xs}); | |||
| return pattern; | |||
| } | |||
| const AnfNodePtr BatchNorm2BatchNorm3D::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode_bn = node->cast<CNodePtr>(); | |||
| if (!NeedFusion(graph, cnode_bn)) { | |||
| return nullptr; | |||
| } | |||
| auto bn_3d = CreateBatchNorm3D(graph, cnode_bn); | |||
| TransferDepend(cnode_bn, graph, bn_3d); | |||
| return bn_3d; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,33 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class BatchNorm2BatchNorm3D : public PatternProcessPass { | |||
| public: | |||
| explicit BatchNorm2BatchNorm3D(bool multigraph = true) : PatternProcessPass("batchnorm_to_batchnorm3d", multigraph) {} | |||
| ~BatchNorm2BatchNorm3D() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_ | |||
| @@ -141,8 +141,6 @@ constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay"; | |||
| constexpr auto kAdamApplyOneWithDecayAssignOpName = "AdamApplyOneWithDecayAssign"; | |||
| constexpr auto kBatchNormGradOpName = "BatchNormGrad"; | |||
| constexpr auto kBNInferOpName = "BNInfer"; | |||
| constexpr auto kBatchNorm3DOpName = "BatchNorm3D"; | |||
| constexpr auto kBatchNorm3DGradOpName = "BatchNorm3DGrad"; | |||
| constexpr auto kAdamApplyOneOpName = "AdamApplyOne"; | |||
| constexpr auto kAdamApplyOneAssignOpName = "AdamApplyOneAssign"; | |||
| constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad"; | |||
| @@ -65,12 +65,7 @@ class _BatchNorm(Cell): | |||
| if momentum < 0 or momentum > 1: | |||
| raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) | |||
| self.input_dims = input_dims | |||
| if self.input_dims == "3d": | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name) | |||
| else: | |||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) | |||
| if context.get_context("device_target") != "Ascend" and self.format == "NCDHW": | |||
| raise ValueError("NCDHW format only support in Ascend target.") | |||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) | |||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | |||
| raise ValueError("NHWC format only support in GPU target.") | |||
| self.use_batch_statistics = use_batch_statistics | |||
| @@ -441,7 +436,7 @@ def _check_3d_shape(input_shape): | |||
| raise ValueError("For BatchNorm3d, input data must be 5-dimensional.") | |||
| class BatchNorm3d(_BatchNorm): | |||
| class BatchNorm3d(Cell): | |||
| r""" | |||
| Batch normalization layer over a 5D input. | |||
| @@ -493,7 +488,7 @@ class BatchNorm3d(_BatchNorm): | |||
| ValueError: If `data_format` is not 'NCDHW'. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> net = nn.BatchNorm3d(num_features=3) | |||
| @@ -515,21 +510,27 @@ class BatchNorm3d(_BatchNorm): | |||
| moving_var_init='ones', | |||
| use_batch_statistics=None, | |||
| data_format='NCDHW'): | |||
| super(BatchNorm3d, self).__init__(num_features, | |||
| eps, | |||
| momentum, | |||
| affine, | |||
| gamma_init, | |||
| beta_init, | |||
| moving_mean_init, | |||
| moving_var_init, | |||
| use_batch_statistics, | |||
| input_dims='3d', | |||
| data_format=data_format) | |||
| super(BatchNorm3d, self).__init__() | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name) | |||
| self.reshape = P.Reshape() | |||
| self.bn2d = BatchNorm2d(num_features=num_features, | |||
| eps=eps, | |||
| momentum=momentum, | |||
| affine=affine, | |||
| gamma_init=gamma_init, | |||
| beta_init=beta_init, | |||
| moving_mean_init=moving_mean_init, | |||
| moving_var_init=moving_var_init, | |||
| use_batch_statistics=use_batch_statistics, | |||
| data_format="NCHW") | |||
| def _check_data_dim(self, x): | |||
| if x.ndim != 5: | |||
| pass | |||
| def construct(self, input_x): | |||
| x_shape = F.shape(input_x) | |||
| _check_3d_shape(x_shape) | |||
| input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4])) | |||
| bn2d_out = self.bn2d(input_x) | |||
| bn3d_out = self.reshape(bn2d_out, x_shape) | |||
| return bn3d_out | |||
| class GlobalBatchNorm(_BatchNorm): | |||
| @@ -48,8 +48,6 @@ from .assign_sub import _assign_sub_tbe | |||
| from .batch_matmul import _batch_matmul_tbe | |||
| from .batchnorm import _batch_norm_tbe | |||
| from .batchnorm_grad import _batch_norm_grad_tbe | |||
| from .batchnorm3d import _batch_norm3d_tbe | |||
| from .batchnorm3d_grad import _batch_norm3d_grad_tbe | |||
| from .bias_add import _bias_add_tbe | |||
| from .bias_add_grad import _bias_add_grad_tbe | |||
| from .cast import _cast_tbe | |||
| @@ -1,51 +0,0 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BatchNorm3D op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| batch_norm3d_op_info = TBERegOp("BatchNorm3D") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batch_norm3d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batch_norm3d") \ | |||
| .partial_flag(True) \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .attr("format", "optional", "str", "all") \ | |||
| .attr("is_training", "optional", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "scale", False, "required", "all", reshape_type="C") \ | |||
| .input(2, "offset", False, "required", "all", reshape_type="C") \ | |||
| .input(3, "mean", False, "optional", "all", reshape_type="C") \ | |||
| .input(4, "variance", False, "optional", "all", reshape_type="C") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "batch_mean", False, "required", "all") \ | |||
| .output(2, "batch_variance", False, "required", "all") \ | |||
| .output(3, "reserve_space_1", False, "optional", "all") \ | |||
| .output(4, "reserve_space_2", False, "optional", "all") \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, | |||
| DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, | |||
| DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ | |||
| .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, | |||
| DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, | |||
| DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ | |||
| .get_op_info() | |||
| @op_info_register(batch_norm3d_op_info) | |||
| def _batch_norm3d_tbe(): | |||
| """BatchNorm3D TBE register""" | |||
| return | |||
| @@ -1,51 +0,0 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BatchNorm3DGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| batch_norm3d_grad_op_info = TBERegOp("BatchNorm3DGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("batch_norm3d_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("batch_norm3d_grad") \ | |||
| .partial_flag(True) \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .attr("format", "optional", "str", "all") \ | |||
| .attr("is_training", "optional", "bool", "all") \ | |||
| .input(0, "y_backprop", False, "required", "all") \ | |||
| .input(1, "x", False, "required", "all") \ | |||
| .input(2, "scale", False, "required", "all", reshape_type="C") \ | |||
| .input(3, "reserve_space_1", False, "optional", "all") \ | |||
| .input(4, "reserve_space_2", False, "optional", "all") \ | |||
| .output(0, "x_backprop", False, "required", "all") \ | |||
| .output(1, "scale_backprop", False, "required", "all") \ | |||
| .output(2, "offset_backprop", False, "required", "all") \ | |||
| .output(3, "reserve_space_4", False, "optional", "all") \ | |||
| .output(4, "reserve_space_5", False, "optional", "all") \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, | |||
| DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, | |||
| DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ | |||
| .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, | |||
| DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, | |||
| DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ | |||
| .get_op_info() | |||
| @op_info_register(batch_norm3d_grad_op_info) | |||
| def _batch_norm3d_grad_tbe(): | |||
| """BatchNorm3DGrad TBE register""" | |||
| return | |||
| @@ -195,7 +195,7 @@ class BatchNormGrad(PrimitiveWithInfer): | |||
| def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'): | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name) | |||
| self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) | |||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve): | |||
| validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) | |||
| @@ -1201,11 +1201,9 @@ class BatchNorm(PrimitiveWithInfer): | |||
| validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name) | |||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) | |||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | |||
| raise ValueError("NHWC format only support in GPU target.") | |||
| if context.get_context("device_target") != "Ascend" and self.format == "NCDHW": | |||
| raise ValueError("NCDHW format only support in Ascend target.") | |||
| self.add_prim_attr('data_format', self.format) | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], | |||
| outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) | |||