From: @liu_xiao_93 Reviewed-by: @liangchenghui Signed-off-by: @liangchenghui tags/v1.2.0-rc1
| @@ -622,6 +622,14 @@ void TbeKernelJsonCreator::ParseAttrDefaultValue(const std::string &type, const | |||
| (*attr_obj)[kJValue] = attr_value; | |||
| } else if (type == kVTypeFloat) { | |||
| (*attr_obj)[kJValue] = std::stof(value); | |||
| } else if (type == kVTypeListInt) { | |||
| std::stringstream string_value(value); | |||
| std::string list_elem; | |||
| std::vector<int64_t> attr_value; | |||
| while (std::getline(string_value, list_elem, ',')) { | |||
| attr_value.push_back(std::stoi(list_elem)); | |||
| } | |||
| (*attr_obj)[kJValue] = attr_value; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Type: " << type << "not support"; | |||
| } | |||
| @@ -55,6 +55,7 @@ | |||
| #include "backend/optimizer/ascend/ir_fission/topk_split.h" | |||
| #include "backend/optimizer/ascend/ir_fission/lin_space_fission.h" | |||
| #include "backend/optimizer/ascend/ir_fission/space_to_depth_split.h" | |||
| #include "backend/optimizer/ascend/ir_fission/max_pool3d_grad_grad_fission.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h" | |||
| @@ -173,6 +174,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>()); | |||
| @@ -325,6 +327,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne | |||
| ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SpaceToDepthSplit>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>()); | |||
| @@ -0,0 +1,120 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fission/max_pool3d_grad_grad_fission.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "frontend/optimizer/opt.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore::opt { | |||
| constexpr size_t kInputNum = 3; | |||
| constexpr size_t kFloat16Len = 2; // size of float16; | |||
| namespace { | |||
| tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { | |||
| // 1 get attr ksize | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "kernel_size"); | |||
| auto data_format = AnfAlgo::GetNodeAttr<std::string>(cnode, "format"); | |||
| if (data_format != kOpFormat_NCDHW) { | |||
| MS_LOG(ERROR) << "MaxPool3DGradGrad only support NCDHW."; | |||
| } | |||
| MS_LOG(DEBUG) << "ksize of MaxPool3DGradGrad:" << ksize; | |||
| int64_t D = ksize[2]; | |||
| int64_t H = ksize[3]; | |||
| int64_t W = ksize[4]; | |||
| // 1 create tensor | |||
| std::vector<int64_t> assist_shape = {1, 1, D, H, W}; // shape:NCDHW | |||
| TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| tensor::DeviceInfo device_info{kOpFormat_NDC1HWC0, tensor_type}; | |||
| tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kFloat16->type_id(), assist_shape); | |||
| assist_tensor->set_device_info(device_info); | |||
| // 2 set value of tensor | |||
| auto data_ptr = assist_tensor->data_c(); | |||
| MS_EXCEPTION_IF_NULL(data_ptr); | |||
| std::vector<float16> half_data; | |||
| int64_t dims = 1 * 1 * D * H * W; | |||
| int64_t counter = dims; | |||
| for (int64_t i = 0; i < dims; i++) { | |||
| half_data.emplace_back(float16(static_cast<float>(counter))); | |||
| counter--; | |||
| } | |||
| auto elem_num = dims * kFloat16Len; | |||
| auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(assist_tensor->data().nbytes()), half_data.data(), elem_num); | |||
| if (ret_code != 0) { | |||
| MS_LOG(ERROR) << "Failed to copy data into Tensor."; | |||
| return nullptr; | |||
| } | |||
| return assist_tensor; | |||
| } | |||
| ValueNodePtr CreateValueNode(const AnfNodePtr &node) { | |||
| tensor::TensorPtr assist_tensor = CreateTensor(node); | |||
| MS_EXCEPTION_IF_NULL(assist_tensor); | |||
| auto assist_const = std::make_shared<ValueNode>(assist_tensor); | |||
| MS_EXCEPTION_IF_NULL(assist_const); | |||
| auto assist_abstract = assist_tensor->ToAbstract(); | |||
| assist_const->set_abstract(assist_abstract); | |||
| auto assist_kernel_info = std::make_shared<device::KernelInfo>(); | |||
| MS_EXCEPTION_IF_NULL(assist_kernel_info); | |||
| assist_const->set_kernel_info(assist_kernel_info); | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder; | |||
| op_builder.SetOutputsFormat({kOpFormat_NDC1HWC0}); | |||
| op_builder.SetOutputsDeviceType({kNumberTypeFloat16}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get()); | |||
| return assist_const; | |||
| } | |||
| } // namespace | |||
| const BaseRef MaxPool3DGradGradFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| auto max_pool3d_grad_grad_prim = std::make_shared<Primitive>(kMaxPool3DGradGradOpName); | |||
| return VectorRef({max_pool3d_grad_grad_prim, Xs}); | |||
| } | |||
| const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_graph = graph->cast<KernelGraphPtr>(); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->size() != kInputNum + 1) { | |||
| MS_LOG(INFO) << "The node " << cnode->DebugString() << " is not equal to " << kInputNum << " inputs"; | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradOpName))}; | |||
| auto assist_const = CreateValueNode(cnode); | |||
| new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| new_inputs.push_back(assist_const); | |||
| CNodePtr new_cnode = graph->NewCNode(new_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| new_cnode->set_abstract(cnode->abstract()); | |||
| new_cnode->set_scope(cnode->scope()); | |||
| AnfAlgo::CopyNodeAttrs(cnode, new_cnode); | |||
| if (kernel_graph != nullptr) { | |||
| kernel_graph->AddValueNodeToGraph(assist_const); | |||
| MS_LOG(INFO) << "Split MaxPool3DGradGrad op success."; | |||
| } | |||
| return new_cnode; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_MAX_POOL3D_GRAD_GRAD_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_MAX_POOL3D_GRAD_GRAD_FISSION_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class MaxPool3DGradGradFission : public PatternProcessPass { | |||
| public: | |||
| explicit MaxPool3DGradGradFission(bool multigraph = true) | |||
| : PatternProcessPass("max_pool3d_grad_grad_fission", multigraph) {} | |||
| ~MaxPool3DGradGradFission() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_MAX_POOL3D_GRAD_GRAD_FISSION_H_ | |||
| @@ -221,6 +221,7 @@ constexpr auto kmaxPoolGradOpName = "MaxPoolGrad"; | |||
| constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax"; | |||
| constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax"; | |||
| constexpr auto kTensorAddOpName = "Add"; | |||
| constexpr auto kMaxPool3DGradGradOpName = "MaxPool3DGradGrad"; | |||
| constexpr auto kCastOpName = "Cast"; | |||
| constexpr auto kGreaterEqualOpName = "GreaterEqual"; | |||
| constexpr auto kAbsOpName = "Abs"; | |||
| @@ -249,6 +249,55 @@ def get_bprop_max_pool_grad(self): | |||
| return bprop | |||
@bprop_getters.register(P.MaxPool3D)
def get_bprop_max_pool3d_grad(self):
    """Grad definition for `MaxPool3D` operation."""
    # The backward primitive is configured with the same attributes as the forward op.
    grad_op = G.MaxPool3DGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.data_format)

    def bprop(x, out, dout):
        return (grad_op(x, out, dout),)

    return bprop
@bprop_getters.register(G.MaxPool3DGrad)
def get_bprop_max_pool3d_grad_grad(self):
    """Grad definition for `MaxPool3DGrad` operation."""
    grad_grad_op = G.MaxPool3DGradGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.data_format)

    def bprop(x, y, grad, out, dout):
        # Only the incoming-gradient input receives a non-zero gradient; x and y get zeros.
        dx = zeros_like(x)
        dy = zeros_like(y)
        dgrad = grad_grad_op(x, y, dout)
        return dx, dy, dgrad

    return bprop
@bprop_getters.register(G.MaxPool3DGradGrad)
def get_bprop_max_pool3d_grad_grad_grad(self):
    """Grad definition for `MaxPool3DGradGrad` operation."""
    grad_op = G.MaxPool3DGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.data_format)

    def bprop(x, y, grad, out, dout):
        # Only the incoming-gradient input receives a non-zero gradient; x and y get zeros.
        dx = zeros_like(x)
        dy = zeros_like(y)
        dgrad = grad_op(x, y, dout)
        return dx, dy, dgrad

    return bprop
| @bprop_getters.register(P.AvgPool) | |||
| def get_bprop_avg_pool_grad(self): | |||
| """Grad definition for `AvgPool` operation.""" | |||
| @@ -65,6 +65,9 @@ from .max_pool import _max_pool_tbe | |||
| from .max_pool_grad import _max_pool_grad_tbe | |||
| from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_tbe | |||
| from .max_pool_with_argmax import _max_pool_with_argmax_tbe | |||
| from .max_pool3d import _max_pool_3d_tbe | |||
| from .max_pool3d_grad import _max_pool_3d_grad_tbe | |||
| from .max_pool3d_grad_grad import _max_pool_3d_grad_grad_tbe | |||
| from .mul import _mul_tbe | |||
| from .mul_ds import _mul_ds_tbe | |||
| from .real_div import _real_div_tbe | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """MaxPool3D op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# TBE registration info for MaxPool3D: one input, one output, format chosen dynamically.
max_pool3d_op_info = (
    TBERegOp("MaxPool3D")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("max_pool3d.so")
    .compute_cost(10)
    .kernel_name("max_pool3d")
    .partial_flag(True)
    .attr("kernel_size", "required", "listInt", "all")
    .attr("strides", "required", "listInt", "all")
    .attr("pad_mode", "required", "str", "all")
    .attr("pad_list", "optional", "listInt", "all", "0,0,0")
    .attr("dilation", "optional", "listInt", "all", "1,1,1")
    .attr("ceil_mode", "optional", "int", "all", "0")
    .attr("format", "optional", "str", "all")
    .input(0, "x", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .op_pattern("dynamicFormat")
    .dtype_format(DataType.None_None, DataType.None_None)
    .get_op_info()
)


@op_info_register(max_pool3d_op_info)
def _max_pool_3d_tbe():
    """MaxPool3D TBE register"""
    return
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """MaxPool3DGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# TBE registration info for MaxPool3DGrad: three float16 NDC1HWC0 inputs, one float32 NDC1HWC0 output.
max_pool3d_grad_op_info = (
    TBERegOp("MaxPool3DGrad")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("max_pool3d_grad.so")
    .compute_cost(10)
    .kernel_name("max_pool3d_grad")
    .partial_flag(True)
    .attr("kernel_size", "required", "listInt", "all")
    .attr("strides", "required", "listInt", "all")
    .attr("pad_list", "required", "listInt", "all")
    .attr("format", "optional", "str", "all")
    .input(0, "orig_x", False, "required", "all")
    .input(1, "orig_y", False, "required", "all")
    .input(2, "grads", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0)
    .get_op_info()
)


@op_info_register(max_pool3d_grad_op_info)
def _max_pool_3d_grad_tbe():
    """MaxPool3DGrad TBE register"""
    return
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """MaxPool3DGradGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# TBE registration info for MaxPool3DGradGrad: four float16 NDC1HWC0 inputs (the last one is "assist")
# and one float16 NDC1HWC0 output.
max_pool3d_grad_grad_op_info = (
    TBERegOp("MaxPool3DGradGrad")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("max_pool3d_grad_grad_d.so")
    .compute_cost(10)
    .kernel_name("max_pool3d_grad_grad_d")
    .partial_flag(True)
    .attr("kernel_size", "required", "listInt", "all")
    .attr("strides", "required", "listInt", "all")
    .attr("pad_list", "required", "listInt", "all")
    .attr("format", "optional", "str", "all")
    .input(0, "orig_in", False, "required", "all")
    .input(1, "orig_out", False, "required", "all")
    .input(2, "grads", False, "required", "all")
    .input(3, "assist", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0,
                  DataType.F16_NDC1HWC0)
    .get_op_info()
)


@op_info_register(max_pool3d_grad_grad_op_info)
def _max_pool_3d_grad_grad_tbe():
    """MaxPool3DGradGrad TBE register"""
    return
| @@ -69,7 +69,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam | |||
| GeLU, Gelu, FastGeLU, FastGelu, Elu, | |||
| GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder, | |||
| LogSoftmax, | |||
| LogSoftmax, MaxPool3D, | |||
| MaxPool, DataFormatDimMap, | |||
| AvgPool, Conv2DBackpropInput, ComputeAccidentalHits, | |||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | |||
| @@ -118,6 +118,7 @@ __all__ = [ | |||
| 'TensorAdd', | |||
| 'Argmax', | |||
| 'Argmin', | |||
| 'MaxPool3D', | |||
| 'ArgMaxWithValue', | |||
| 'ArgMinWithValue', | |||
| 'AddN', | |||
| @@ -984,6 +984,89 @@ class MaxPoolGradGrad(_PoolGrad): | |||
| return x1_dtype | |||
| def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode): | |||
| """ | |||
| helper for get max pool3d grad pads by pad_mode | |||
| """ | |||
| def get_pad(origin_shape, ksize, stride): | |||
| tail = origin_shape % stride | |||
| pad = (ksize - tail) if tail > 0 else (ksize - stride) | |||
| pad = max(pad, 0) | |||
| pad1 = int(pad / 2) | |||
| pad2 = int(pad / 2) + pad % 2 | |||
| return pad1, pad2 | |||
| _, _, d, h, w = input_shape | |||
| _, _, kd, kh, kw = kernel_size | |||
| _, _, strd, strh, strw = strides | |||
| pads = (0, 0, 0, 0, 0, 0) | |||
| if pad_mode == 'SAME': | |||
| pads_d = get_pad(d, kd, strd) | |||
| pads_h = get_pad(h, kh, strh) | |||
| pads_w = get_pad(w, kw, strw) | |||
| pads = pads_d + pads_h + pads_w | |||
| return pads | |||
class MaxPool3DGrad(PrimitiveWithInfer):
    """Gradients of the max pool3d operation."""

    @prim_attr_register
    def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"):
        """Validate the arguments and register `kernel_size` and `strides` as primitive attributes."""
        type_checks = (
            ('kernel_size', kernel_size, [int, tuple]),
            ('strides', strides, [int, tuple]),
            ('pad_mode', pad_mode, [str]),
        )
        for arg_name, arg_value, valid_types in type_checks:
            validator.check_value_type(arg_name, arg_value, valid_types, self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
                                                  allow_five=True, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name,
                                              allow_five=True, ret_five=True)
        self.add_prim_attr("strides", self.strides)

    def infer_shape(self, x_shape, y_shape, grad_shape):
        """Record the `pad_list` attribute derived from `x_shape` and return `x_shape` as the output shape."""
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        pads = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
        for pad_value in pads:
            validator.check_non_negative_int(pad_value, 'element of pad_list', self.name)
        self.add_prim_attr("pad_list", pads)
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
        """Require all inputs to share one of float16/float32; the output is always a float32 tensor."""
        dtypes = {'x_dtype': x_dtype, 'y_dtype': y_dtype, 'grad_dtype': grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(dtypes, [mstype.float16, mstype.float32], self.name)
        return mstype.tensor_type(mstype.float32)
class MaxPool3DGradGrad(PrimitiveWithInfer):
    """Gradients of the max pool3d grad operation."""

    @prim_attr_register
    def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"):
        """Validate the arguments and register `kernel_size` and `strides` as primitive attributes."""
        type_checks = (
            ('kernel_size', kernel_size, [int, tuple]),
            ('strides', strides, [int, tuple]),
            ('pad_mode', pad_mode, [str]),
        )
        for arg_name, arg_value, valid_types in type_checks:
            validator.check_value_type(arg_name, arg_value, valid_types, self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
                                                  allow_five=True, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name,
                                              allow_five=True, ret_five=True)
        self.add_prim_attr("strides", self.strides)

    def infer_shape(self, x_shape, y_shape, grad_shape):
        """Record the `pad_list` attribute derived from `x_shape` and return `y_shape` as the output shape."""
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        pads = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
        for pad_value in pads:
            validator.check_non_negative_int(pad_value, 'element of pad_list', self.name)
        self.add_prim_attr("pad_list", pads)
        return y_shape

    def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
        """Require x and y to share one of float16/float32 and grad to be float16/float32; return `x_dtype`."""
        allowed = [mstype.float16, mstype.float32]
        forward_dtypes = {'x_dtype': x_dtype, 'y_dtype': y_dtype}
        validator.check_tensors_dtypes_same_and_valid(forward_dtypes, allowed, self.name)
        validator.check_tensor_dtype_valid('grad_dtype', grad_dtype, allowed, self.name)
        return x_dtype
| class MaximumGrad(Primitive): | |||
| """Grad for maximum.""" | |||
| @@ -1833,6 +1833,105 @@ class MaxPoolWithArgmax(_Pool): | |||
| return x_dtype, argmax_dtype | |||
class MaxPool3D(PrimitiveWithInfer):
    r"""
    3D max pooling operation.

    Applies a 3D max pooling over an input Tensor which can be regarded as a composition of 3D planes.

    Typically the input is of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})`, MaxPool3D outputs
    regional maximum in the :math:`(D_{in}, H_{in}, W_{in})`-dimension. Given kernel size
    :math:`ks = (d_{ker}, h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1, s_2)`, the operation is as follows.

    .. math::
        \text{output}(N_i, C_j, d, h, w) =
        \max_{l=0, \ldots, d_{ker}-1} \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
        \text{input}(N_i, C_j, s_0 \times d + l, s_1 \times h + m, s_2 \times w + n)

    Args:
        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
            is an int number that represents depth, height and width are all kernel_size, or a tuple
            of three int numbers that represent depth, height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the depth, height and width of movement are all strides, or a tuple of three int numbers that
            represent depth, height and width of movement respectively. Default: 1.
        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Adopts the way of completion. Each of the output depth, height and width is the
              corresponding input extent divided by the stride, rounded up. The total number of padding
              will be calculated along each direction and evenly distributed to both sides if possible.
              Otherwise, the last extra padding will be done from the trailing side.

            - valid: Adopts the way of discarding. The possible largest depth, height and width of output
              will be returned without padding. Extra pixels will be discarded.
        data_format (str) : The optional value for data format. Currently only support 'NCDHW'. Default: 'NCDHW'.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
          Data type must be float16 or float32.

    Outputs:
        Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`. Has the same data type as `input`.

    Raises:
        TypeError: If `kernel_size` or `strides` is neither an int nor a tuple.
        TypeError: If `pad_mode` or `data_format` is not a string.
        ValueError: If numbers in `kernel_size` or `strides` are not positive.
        ValueError: If `pad_mode` is not one of 'same', 'valid'.
        ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3 or 5.
        ValueError: If `data_format` is not 'NCDHW'.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)), mindspore.float32)
        >>> max_pool3d = ops.MaxPool3D(kernel_size=2, strides=1, pad_mode="valid")
        >>> output = max_pool3d(input)
        >>> print(output)
        [[[[[10. 11.]]]
          [[[22. 23.]]]]]
    """

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCDHW"):
        """Validate the arguments and register `pad_mode`, `kernel_size` and `strides` as primitive attributes."""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("strides", self.strides)

    def infer_shape(self, x_shape):
        """Return the output shape [N, C, D_out, H_out, W_out] for a rank-5 `x_shape`."""
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        batch, channel, input_d, input_h, input_w = x_shape
        self.add_prim_attr("x_shape", x_shape)
        _, _, kernel_d, kernel_h, kernel_w = self.kernel_size
        _, _, stride_d, stride_h, stride_w = self.strides

        if self.pad_mode == "VALID":
            # Only windows that fit entirely inside the input are counted.
            out_d = math.ceil((input_d - (kernel_d - 1)) / stride_d)
            out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h)
            out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w)
        elif self.pad_mode == "SAME":
            out_d = math.ceil(input_d / stride_d)
            out_h = math.ceil(input_h / stride_h)
            out_w = math.ceil(input_w / stride_w)
        out_shape = [batch, channel, out_d, out_h, out_w]

        _check_shape('output', out_shape, self.name)
        return out_shape

    def infer_dtype(self, x_dtype):
        """Accept float16 or float32 input and return the same dtype."""
        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype
| class AvgPool(_Pool): | |||
| r""" | |||
| Average pooling operation. | |||
| @@ -2097,8 +2196,8 @@ class BiasAdd(PrimitiveWithCheck): | |||
| def check_shape(self, x_shape, b_shape): | |||
| validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) | |||
| if self.format == "NCDHW" and len(x_shape) != 5: | |||
| raise ValueError("NCDHW format only support 5-dims input.") | |||
| if self.format == "NCDHW" and (len(x_shape) != 5 or context.get_context("device_target") != "Ascend"): | |||
| raise ValueError("NCDHW format only support 5-dims input in Ascend target.") | |||
| validator.check_equal_int(len(b_shape), 1, "bias rank", self.name) | |||
| x_channel = x_shape[-1] if self.format == "NHWC" else x_shape[1] | |||
| if np.all(np.array(x_shape) != -1): | |||
| @@ -1696,6 +1696,14 @@ test_case_nn_ops = [ | |||
| 'desc_inputs': [[3, 4, 6, 6], [3, 4, 3, 3], [3, 4, 3, 3]], | |||
| 'desc_bprop': [[3, 4, 6, 6]], | |||
| 'skip': ['backward']}), | |||
| ('MaxPool3D', { | |||
| 'block': P.MaxPool3D(kernel_size=2, strides=2, pad_mode="VALID"), | |||
| 'desc_inputs': [[100, 3, 28, 28, 28]], | |||
| 'desc_bprop': [[100, 3, 14, 14, 14]]}), | |||
| ('MaxPool3DGrad', { | |||
| 'block': G.MaxPool3DGrad(kernel_size=2, strides=2, pad_mode="VALID"), | |||
| 'desc_inputs': [[3, 4, 6, 6, 6], [3, 4, 3, 3, 3], [3, 4, 3, 3, 3]], | |||
| 'desc_bprop': [[3, 4, 6, 6, 6]]}), | |||
| ('AvgPool', { | |||
| 'block': P.AvgPool(kernel_size=(2, 2), strides=(2, 2), pad_mode="VALID"), | |||
| 'desc_inputs': [[100, 3, 28, 28]], | |||