| @@ -37,9 +37,9 @@ class Rel(Enum): | |||
| GE = 6 # >= | |||
| # scalar range check | |||
| INC_NEITHER = 7 # (), include neither | |||
| INC_LEFT = 8 # [), include left | |||
| INC_RIGHT = 9 # (], include right | |||
| INC_BOTH = 10 # [], include both | |||
| INC_LEFT = 8 # [), include left | |||
| INC_RIGHT = 9 # (], include right | |||
| INC_BOTH = 10 # [], include both | |||
| # collection in, not in | |||
| IN = 11 | |||
| NOT_IN = 12 | |||
| @@ -92,6 +92,41 @@ rel_strs = { | |||
| } | |||
| def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, | |||
| ret_five=False, greater_zero=True): | |||
| """ | |||
| Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements. | |||
| """ | |||
| def _raise_message(): | |||
| raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three " | |||
| f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}") | |||
| def _get_return_value(): | |||
| if isinstance(arg_value, int): | |||
| ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value) | |||
| elif len(arg_value) == 3: | |||
| ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value | |||
| elif len(arg_value) == 5: | |||
| if not allow_five: | |||
| _raise_message() | |||
| ret = arg_value if ret_five else (arg_value[1], arg_value[2], arg_value[3]) | |||
| else: | |||
| _raise_message() | |||
| return ret | |||
| Validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) | |||
| ret_value = _get_return_value() | |||
| for item in ret_value: | |||
| if isinstance(item, int) and not isinstance(item, bool): | |||
| if greater_zero and item > 0: | |||
| continue | |||
| if not greater_zero and item >= 0: | |||
| continue | |||
| _raise_message() | |||
| return tuple(ret_value) | |||
| def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): | |||
| """ | |||
| Check argument integer. | |||
| @@ -428,6 +463,7 @@ class Validator: | |||
| @staticmethod | |||
| def check_types_same_and_valid(args, valid_values, prim_name): | |||
| """Checks whether the types of inputs are the same and valid.""" | |||
| def _check_type_valid(arg): | |||
| arg_key, arg_val = arg | |||
| elem_type = arg_val | |||
| @@ -494,6 +530,7 @@ class Validator: | |||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | |||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | |||
| return arg1 | |||
| reduce(_check_types_same, map(_check_argument_type, args.items())) | |||
| @staticmethod | |||
| @@ -625,6 +662,7 @@ def args_type_check(*type_args, **type_kwargs): | |||
| if value is not None and not isinstance(value, bound_types[name]): | |||
| raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) | |||
| return func(*args, **kwargs) | |||
| return wrapper | |||
| return type_check | |||
| @@ -29,6 +29,8 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||
| Register(prim::kPrimAvgPoolGradVm->name(), {0}); | |||
| Register(prim::kPrimConv2DBackpropInput->name(), {2}); | |||
| Register(prim::kPrimConv2DBackpropFilter->name(), {2}); | |||
| Register(prim::kPrimConv3DBackpropInput->name(), {2}); | |||
| Register(prim::kPrimConv3DBackpropFilter->name(), {2}); | |||
| Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); | |||
| Register(prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), {0}); | |||
| Register(prim::kPrimReshape->name(), {1}); | |||
| @@ -31,6 +31,7 @@ namespace session { | |||
| namespace { | |||
| constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; | |||
| constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; | |||
| constexpr size_t k5dDims = 5; | |||
| const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(), | |||
| prim::kPrimAssignSub->name()}; | |||
| void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | |||
| @@ -383,7 +384,8 @@ void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &forma | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); i++) { | |||
| auto in_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), i); | |||
| MS_EXCEPTION_IF_NULL(in_node); | |||
| if (in_node->isa<Parameter>() || in_node->isa<ValueNode>()) { | |||
| if ((in_node->isa<Parameter>() || in_node->isa<ValueNode>()) && | |||
| AnfAlgo::GetOutputInferShape(in_node, 0).size() == k5dDims) { | |||
| ReSetParameterValueNodeFormatAndType(in_node, format); | |||
| } | |||
| } | |||
| @@ -291,10 +291,7 @@ std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) { | |||
| std::vector<size_t> device_shape; | |||
| const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; | |||
| const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize; | |||
| device_shape.push_back(shape[2]); | |||
| device_shape.push_back(C1); | |||
| device_shape.push_back(shape[3]); | |||
| device_shape.push_back(shape[4]); | |||
| device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]); | |||
| device_shape.push_back(N1); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| @@ -149,6 +149,8 @@ inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad" | |||
| inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad"); | |||
| inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | |||
| inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | |||
| inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput"); | |||
| inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter"); | |||
| inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | |||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | |||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | |||
| @@ -18,6 +18,7 @@ import numpy as np | |||
| from mindspore.ops import _selected_grad_ops as SG | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops.operations import nn_ops as nps | |||
| from .grad_base import bprop_getters | |||
| from .. import functional as F | |||
| from .. import operations as P | |||
| @@ -60,6 +61,27 @@ def get_bprop_conv2d(self): | |||
| return bprop | |||
| @bprop_getters.register(nps.Conv3D) | |||
| def get_bprop_conv3d(self): | |||
| """Grad definition for `Conv3D` operation.""" | |||
| input_grad = nps.Conv3DBackpropInput( | |||
| self.out_channel, self.kernel_size, self.pad_mode, self.pad, mode=self.mode, | |||
| dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format | |||
| ) | |||
| filter_grad = G.Conv3DBackpropFilter( | |||
| self.out_channel, self.kernel_size, self.pad_mode, self.pad, mode=self.mode, | |||
| dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format | |||
| ) | |||
| get_shape = P.Shape() | |||
| def bprop(x, w, out, dout): | |||
| dx = input_grad(w, dout, get_shape(x)) | |||
| dw = filter_grad(x, dout, get_shape(w)) | |||
| return dx, dw | |||
| return bprop | |||
| @bprop_getters.register(inner.ExtractImagePatches) | |||
| def get_bprop_extract_image_patches(self): | |||
| """Grad definition for `ExtractImagePatches` operation.""" | |||
| @@ -347,3 +347,7 @@ from .fake_quant_with_min_max_vars import _fake_quant_with_min_max_vars_tbe | |||
| from .fake_quant_with_min_max_vars_gradient import _fake_quant_with_min_max_vars_gradient_tbe | |||
| from .fake_quant_with_min_max_vars_per_channel import _fake_quant_with_min_max_vars_per_channel_tbe | |||
| from .fake_quant_with_min_max_vars_per_channel_gradient import _fake_quant_with_min_max_vars_per_channel_gradient_tbe | |||
| from .conv3d import _conv3d_tbe | |||
| from .conv3d_backprop_input import _conv3d_backprop_input_tbe | |||
| from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe | |||
| from .conv3d_transpose import _conv3d_transpose_tbe | |||
| @@ -0,0 +1,45 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Conv3D op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| conv3d_op_info = TBERegOp("Conv3D") \ | |||
| .fusion_type("CONVLUTION") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("conv3d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("conv3d") \ | |||
| .partial_flag(True) \ | |||
| .attr("strides", "required", "listInt", "all") \ | |||
| .attr("pads", "required", "listInt", "all") \ | |||
| .attr("dilations", "required", "listInt", "all") \ | |||
| .attr("groups", "optional", "int", "all") \ | |||
| .attr("data_format", "optional", "str", "all") \ | |||
| .attr("offset_x", "optional", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "filter", False, "required", "all") \ | |||
| .input(2, "bias", False, "optional", "all") \ | |||
| .input(3, "offset_w", False, "optional", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, | |||
| DataType.I8_Default, DataType.F16_NDC1HWC0) \ | |||
| .get_op_info() | |||
| @op_info_register(conv3d_op_info) | |||
| def _conv3d_tbe(): | |||
| """Conv3D TBE register""" | |||
| return | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Conv3DBackpropFilter op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| conv3d_backprop_filter_op_info = TBERegOp("Conv3DBackpropFilter") \ | |||
| .fusion_type("CONVLUTION") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("conv3d_backprop_filter_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("conv3d_backprop_filter_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("filter_size", "required", "listInt", "all") \ | |||
| .attr("strides", "required", "listInt", "all") \ | |||
| .attr("pads", "required", "listInt", "all") \ | |||
| .attr("dilations", "required", "listInt", "all") \ | |||
| .attr("groups", "optional", "int", "all") \ | |||
| .attr("data_format", "optional", "str", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "out_backprop", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_FRACTAL_Z_3D) \ | |||
| .get_op_info() | |||
| @op_info_register(conv3d_backprop_filter_op_info) | |||
| def _conv3d_backprop_filter_tbe(): | |||
| """Conv3DBackpropFilter TBE register""" | |||
| return | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Conv3DBackpropInput op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| conv3d_backprop_input_op_info = TBERegOp("Conv3DBackpropInput") \ | |||
| .fusion_type("CONVLUTION") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("conv3d_backprop_input_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("conv3d_backprop_input_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("input_size", "required", "listInt", "all") \ | |||
| .attr("strides", "required", "listInt", "all") \ | |||
| .attr("pads", "required", "listInt", "all") \ | |||
| .attr("dilations", "required", "listInt", "all") \ | |||
| .attr("groups", "optional", "int", "all") \ | |||
| .attr("data_format", "optional", "str", "all") \ | |||
| .input(0, "filter", False, "required", "all") \ | |||
| .input(1, "out_backprop", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \ | |||
| .get_op_info() | |||
| @op_info_register(conv3d_backprop_input_op_info) | |||
| def _conv3d_backprop_input_tbe(): | |||
| """Conv3DBackpropInput TBE register""" | |||
| return | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Conv3DTranspose op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| conv3d_transpose_op_info = TBERegOp("Conv3DTranspose") \ | |||
| .fusion_type("CONVLUTION") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("conv3d_transpose_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("conv3d_transpose_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("input_size", "required", "listInt", "all") \ | |||
| .attr("strides", "required", "listInt", "all") \ | |||
| .attr("pads", "required", "listInt", "all") \ | |||
| .attr("dilations", "optional", "listInt", "all") \ | |||
| .attr("groups", "optional", "int", "all") \ | |||
| .attr("data_format", "optional", "str", "all") \ | |||
| .attr("output_padding", "optional", "listInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(0, "filter", False, "required", "all") \ | |||
| .input(0, "bias", False, "optional", "all") \ | |||
| .input(1, "offset_w", False, "optional", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, DataType.I8_Default, | |||
| DataType.F16_NDC1HWC0) \ | |||
| .get_op_info() | |||
| @op_info_register(conv3d_transpose_op_info) | |||
| def _conv3d_transpose_tbe(): | |||
| """Conv3DTranspose TBE register""" | |||
| return | |||
| @@ -133,6 +133,24 @@ trans_data_op_info = TBERegOp("TransData") \ | |||
| .dtype_format(DataType.F32_HWCN, DataType.F32_FracZNLSTM) \ | |||
| .dtype_format(DataType.F16_FracZNLSTM, DataType.F16_HWCN) \ | |||
| .dtype_format(DataType.F32_FracZNLSTM, DataType.F32_HWCN) \ | |||
| .dtype_format(DataType.F16_NDHWC, DataType.F16_NDC1HWC0) \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDHWC) \ | |||
| .dtype_format(DataType.F16_DHWCN, DataType.F16_FRACTAL_Z_3D) \ | |||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_DHWCN) \ | |||
| .dtype_format(DataType.F16_NCDHW, DataType.F16_NDC1HWC0) \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NCDHW) \ | |||
| .dtype_format(DataType.F16_NCDHW, DataType.F16_FRACTAL_Z_3D) \ | |||
| .dtype_format(DataType.F32_NCDHW, DataType.F32_FRACTAL_Z_3D) \ | |||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NCDHW) \ | |||
| .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NCDHW) \ | |||
| .dtype_format(DataType.F16_NDHWC, DataType.F16_FRACTAL_Z_3D) \ | |||
| .dtype_format(DataType.F32_NDHWC, DataType.F32_FRACTAL_Z_3D) \ | |||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NDHWC) \ | |||
| .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NDHWC) \ | |||
| .dtype_format(DataType.F32_DHWCN, DataType.F32_FRACTAL_Z_3D) \ | |||
| .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_DHWCN) \ | |||
| .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDHWC) \ | |||
| .dtype_format(DataType.F32_NDHWC, DataType.F32_NDC1HWC0) \ | |||
| .get_op_info() | |||
| @@ -110,7 +110,7 @@ def clip_by_global_norm(x, clip_norm=1.0, use_norm=None): | |||
| Clips tensor values by the ratio of the sum of their norms. | |||
| Note: | |||
| input 'x' should be a tuple or list of tensors. Otherwise, it will raise an error. | |||
| Input 'x' should be a tuple or list of tensors. Otherwise, it will raise an error. | |||
| Args: | |||
| x (Union(tuple[Tensor], list[Tensor])): Input data to clip. | |||
| @@ -631,6 +631,10 @@ class DataType: | |||
| F16_NHWC = ("float16", "NHWC") | |||
| F16_HWCN = ("float16", "HWCN") | |||
| F16_NDHWC = ("float16", "NDHWC") | |||
| F16_NCDHW = ("float16", "NCDHW") | |||
| F16_DHWCN = ("float16", "DHWCN") | |||
| F16_NDC1HWC0 = ("float16", "NDC1HWC0") | |||
| F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D") | |||
| F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM") | |||
| F32_None = ("float32", "") | |||
| @@ -643,6 +647,10 @@ class DataType: | |||
| F32_NHWC = ("float32", "NHWC") | |||
| F32_HWCN = ("float32", "HWCN") | |||
| F32_NDHWC = ("float32", "NDHWC") | |||
| F32_NCDHW = ("float32", "NCDHW") | |||
| F32_DHWCN = ("float32", "DHWCN") | |||
| F32_NDC1HWC0 = ("float32", "NDC1HWC0") | |||
| F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D") | |||
| F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM") | |||
| F64_None = ("float64", "") | |||
| @@ -14,8 +14,9 @@ | |||
| # ============================================================================ | |||
| """Operators for gradients.""" | |||
| import math | |||
| from functools import partial | |||
| from mindspore._checkparam import _check_3d_int_or_tuple | |||
| from .. import signature as sig | |||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||
| from ..._checkparam import Validator as validator, Rel | |||
| @@ -288,6 +289,140 @@ class ConcatOffset(PrimitiveWithInfer): | |||
| return out | |||
| class Conv3DBackpropFilter(PrimitiveWithInfer): | |||
| """ | |||
| Computes the gradients of convolution 3D with respect to the filter. | |||
| Args: | |||
| out_channel (int): The dimension of the output. | |||
| kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution. | |||
| mode (int): Modes for different convolutions. Not currently used. | |||
| pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". | |||
| pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of | |||
| head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four | |||
| integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], | |||
| pad[3], pad[4] and pad[5] correspondingly. | |||
| stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. | |||
| dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1. | |||
| group (int): Splits input into groups. Default: 1. | |||
| data_format (str): The optional value for data format. Currently only support 'NCDHW'. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of the convolution, then the shape is :math:`(C_{out}, C_{in}, D_{in}, K_1, K_2)`. | |||
| - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default | |||
| data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. | |||
| - **w_size** (Tensor) - A tuple describes the shape of the weight which conforms to the format | |||
| :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16) | |||
| >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16) | |||
| >>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16) | |||
| >>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=4, kernel_size=(4, 6, 2)) | |||
| >>> output = conv3d_backprop_input(x, dout, F.shape(w)) | |||
| >>> print(output.shape) | |||
| (32, 32, 4, 6, 2) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| out_channel, | |||
| kernel_size, | |||
| pad_mode="valid", | |||
| pad=0, | |||
| mode=1, | |||
| stride=(1, 1, 1, 1, 1), | |||
| dilation=(1, 1, 1, 1, 1), | |||
| group=1, | |||
| data_format="NCDHW"): | |||
| """Initialize Convolution""" | |||
| self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y']) | |||
| self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) | |||
| self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) | |||
| self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) | |||
| self.add_prim_attr('strides', self.stride) | |||
| self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) | |||
| self.add_prim_attr('dilations', self.dilation) | |||
| validator.check_value_type('pad', pad, (int, tuple), self.name) | |||
| if isinstance(pad, int): | |||
| pad = (pad,) * 6 | |||
| validator.check_equal_int(len(pad), 6, 'pad size', self.name) | |||
| self.pad_list = pad | |||
| self.add_prim_attr('pads', self.pad_list) | |||
| self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) | |||
| if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0): | |||
| raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.") | |||
| if self.pad_mode == 'pad': | |||
| for item in pad: | |||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||
| self.add_prim_attr('pad_mode', self.pad_mode) | |||
| self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) | |||
| self.group = validator.check_positive_int(group, 'group', self.name) | |||
| self.add_prim_attr('groups', self.group) | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) | |||
| self.add_prim_attr('data_format', self.format) | |||
| self.add_prim_attr('io_format', "NCDHW") | |||
| def __infer__(self, x, doutput, w_size): | |||
| w_size_v = w_size['value'] | |||
| validator.check_value_type('w_size', w_size_v, [tuple], self.name) | |||
| for i, dim_len in enumerate(w_size_v): | |||
| validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) | |||
| args = {"x": x['dtype'], "doutput": doutput['dtype']} | |||
| valid_dtypes = [mstype.float16, mstype.float32] | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||
| validator.check("filter's batch", w_size_v[0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name) | |||
| validator.check("filter's channel", w_size_v[1], "input_size's channel", x['shape'][1], Rel.EQ, self.name) | |||
| validator.check("input_size's batch", x['shape'][0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name) | |||
| # infer shape | |||
| x_shape = x['shape'] | |||
| dout_shape = doutput['shape'] | |||
| kernel_d = self.kernel_size[0] | |||
| kernel_h = self.kernel_size[1] | |||
| kernel_w = self.kernel_size[2] | |||
| stride_d = self.stride[2] | |||
| stride_h = self.stride[3] | |||
| stride_w = self.stride[4] | |||
| dilation_d = self.dilation[2] | |||
| dilation_h = self.dilation[3] | |||
| dilation_w = self.dilation[4] | |||
| # The pad_mode is valid by default. If pad_mode is not valid or same, then pad. | |||
| if self.pad_mode == "valid": | |||
| self.pad_list = (0, 0, 0, 0, 0, 0) | |||
| if self.pad_mode == "same": | |||
| pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_shape[2]) | |||
| pad_head = math.floor(pad_needed_d / 2) | |||
| pad_tail = pad_needed_d - pad_head | |||
| pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_shape[3]) | |||
| pad_top = math.floor(pad_needed_h / 2) | |||
| pad_bottom = pad_needed_h - pad_top | |||
| pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_shape[4]) | |||
| pad_left = math.floor(pad_needed_w / 2) | |||
| pad_right = pad_needed_w - pad_left | |||
| self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right) | |||
| self.add_prim_attr('pads', self.pad_list) | |||
| out = { | |||
| 'value': None, | |||
| 'shape': w_size_v, | |||
| 'dtype': mstype.float32, | |||
| } | |||
| return out | |||
| class Conv2DBackpropFilter(PrimitiveWithInfer): | |||
| """ | |||
| Computes the gradients of convolution with respect to the filter. | |||
| @@ -18,6 +18,7 @@ | |||
| import math | |||
| import operator | |||
| from functools import reduce, partial | |||
| from mindspore._checkparam import _check_3d_int_or_tuple | |||
| import numpy as np | |||
| from ... import context | |||
| from .. import signature as sig | |||
| @@ -3748,6 +3749,7 @@ class AdamNoUpdateParam(PrimitiveWithInfer): | |||
| [-0.00013441 -0.00013441 -0.00013441]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False, use_nesterov=False): | |||
| validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| @@ -6732,3 +6734,295 @@ class LRN(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name) | |||
| return x_shape | |||
| class Conv3D(PrimitiveWithInfer): | |||
| r""" | |||
| 3D convolution layer. | |||
| Applies a 3D convolution over an input tensor which is typically of shape | |||
| :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number. | |||
| For each batch of shape :math:`(C_{in}, D_{in}, H_{in}, W_{in})`. | |||
| If the 'pad_mode' is set to be "valid", the output height and width will be | |||
| :math:`\left \lfloor{1 + \frac{D_{in} + 2 \times \text{padding} - \text{ks_d} - | |||
| (\text{ks_d} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and | |||
| :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} - | |||
| (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and | |||
| :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} - | |||
| (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively. | |||
| Args: | |||
| out_channel (int): The dimension of the output. | |||
| kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution. | |||
| mode (int): Modes for different convolutions. Not currently used. | |||
| pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". | |||
| pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of | |||
| head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four | |||
| integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], | |||
| pad[3], pad[4] and pad[5] correspondingly. | |||
| stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. | |||
| dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1. | |||
| group (int): Splits input into groups. Default: 1. | |||
| data_format (str): The optional value for data format. Currently only support "NCDHW". | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. | |||
| - **weight** (Tensor) - Set size of kernel is :math:`(D_1, K_2, K_3)`, then the shape is | |||
| :math:`(C_{out}, C_{in}, D_{in}, K_1, K_2)`. | |||
| Outputs: | |||
| Tensor, the value that applied 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32) | |||
| >>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float32) | |||
| >>> conv3d = P.Conv3D(out_channel=32, kernel_size=(4, 3, 3)) | |||
| >>> output = conv3d(input, weight) | |||
| >>> print(output.shape) | |||
| (16, 32, 7, 30, 30) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| out_channel, | |||
| kernel_size, | |||
| mode=1, | |||
| pad_mode="valid", | |||
| pad=0, | |||
| stride=1, | |||
| dilation=1, | |||
| group=1, | |||
| data_format="NCDHW"): | |||
| """Initialize Conv3D""" | |||
| self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) | |||
| self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) | |||
| self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) | |||
| self.add_prim_attr('strides', self.stride) | |||
| self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) | |||
| self.add_prim_attr('dilations', self.dilation) | |||
| validator.check_value_type('pad', pad, (int, tuple), self.name) | |||
| if isinstance(pad, int): | |||
| pad = (pad,) * 6 | |||
| validator.check_equal_int(len(pad), 6, 'pad size', self.name) | |||
| self.padding = pad | |||
| self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) | |||
| self.add_prim_attr('pad_mode', self.pad_mode) | |||
| if self.pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0): | |||
| raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.") | |||
| if self.pad_mode == 'pad': | |||
| for item in pad: | |||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||
| self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) | |||
| self.add_prim_attr('data_format', self.format) | |||
| self.add_prim_attr('io_format', "NCDHW") | |||
| self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) | |||
| self.group = validator.check_positive_int(group, 'group', self.name) | |||
| self.add_prim_attr('groups', self.group) | |||
| self.add_prim_attr('offset_x', 0) | |||
| def infer_shape(self, x_shape, w_shape, b_shape=None): | |||
| validator.check_equal_int(len(w_shape), 5, "weight rank", self.name) | |||
| validator.check_equal_int(len(x_shape), 5, "x rank", self.name) | |||
| if b_shape is not None: | |||
| raise ValueError("Bias currently only support None.") | |||
| validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) | |||
| validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) | |||
| validator.check('kernel_size', self.kernel_size, 'w_shape[1:4]', tuple(w_shape[2:]), Rel.EQ, self.name) | |||
| kernel_size_d = w_shape[2] | |||
| kernel_size_h = w_shape[3] | |||
| kernel_size_w = w_shape[4] | |||
| stride_d = self.stride[2] | |||
| stride_h = self.stride[3] | |||
| stride_w = self.stride[4] | |||
| dilation_d = self.dilation[2] | |||
| dilation_h = self.dilation[3] | |||
| dilation_w = self.dilation[4] | |||
| if self.pad_mode == "valid": | |||
| d_out = math.ceil((x_shape[2] - dilation_d * (kernel_size_d - 1)) / stride_d) | |||
| h_out = math.ceil((x_shape[3] - dilation_h * (kernel_size_h - 1)) / stride_h) | |||
| w_out = math.ceil((x_shape[4] - dilation_w * (kernel_size_w - 1)) / stride_w) | |||
| pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0, 0, 0 | |||
| elif self.pad_mode == "same": | |||
| d_out = math.ceil(x_shape[2] / stride_d) | |||
| h_out = math.ceil(x_shape[3] / stride_h) | |||
| w_out = math.ceil(x_shape[4] / stride_w) | |||
| pad_needed_d = max(0, (d_out - 1) * stride_d + dilation_d * (kernel_size_d - 1) + 1 - x_shape[2]) | |||
| pad_head = math.floor(pad_needed_d / 2) | |||
| pad_tail = pad_needed_d - pad_head | |||
| pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[3]) | |||
| pad_top = math.floor(pad_needed_h / 2) | |||
| pad_bottom = pad_needed_h - pad_top | |||
| pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[4]) | |||
| pad_left = math.floor(pad_needed_w / 2) | |||
| pad_right = pad_needed_w - pad_left | |||
| elif self.pad_mode == 'pad': | |||
| pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.padding | |||
| d_out = 1 + (x_shape[2] + pad_head + pad_tail - kernel_size_d - (kernel_size_d - 1) | |||
| * (dilation_d - 1)) / stride_d | |||
| h_out = 1 + (x_shape[3] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) | |||
| * (dilation_h - 1)) / stride_h | |||
| w_out = 1 + (x_shape[4] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) | |||
| * (dilation_w - 1)) / stride_w | |||
| d_out = math.floor(d_out) | |||
| h_out = math.floor(h_out) | |||
| w_out = math.floor(w_out) | |||
| self.pad_list = [pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right] | |||
| self.add_prim_attr('pads', (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)) | |||
| out_channel = self.out_channel | |||
| out_shape = [x_shape[0], out_channel, d_out, h_out, w_out] | |||
| _check_shape('output', out_shape, self.name) | |||
| return out_shape | |||
| def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): | |||
| args = {'x': x_dtype, 'w': w_dtype} | |||
| valid_dtypes = [mstype.float16, mstype.float32] | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||
| return x_dtype | |||
| class Conv3DBackpropInput(PrimitiveWithInfer): | |||
| """ | |||
| Computes the gradients of convolution 3D with respect to the input. | |||
| Args: | |||
| out_channel (int): The dimension of the output. | |||
| kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution. | |||
| mode (int): Modes for different convolutions. Not currently used. | |||
| pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". | |||
| pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of | |||
| head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four | |||
| integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], | |||
| pad[3], pad[4] and pad[5] correspondingly. | |||
| stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. | |||
| dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1. | |||
| group (int): Splits input into groups. Default: 1. | |||
| data_format (str): The optional value for data format. Currently only support 'NCDHW'. | |||
| Inputs: | |||
| - **weight** (Tensor) - Set size of kernel is :math:`(K_1, K_2, K_3)`, then the shape is | |||
| :math:`(C_{out}, C_{in}, D_{in}, K_1, K_2)`. | |||
| - **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default | |||
| data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. | |||
| - **input_size** (Tensor) - A tuple describes the shape of the input which conforms to the format | |||
| :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. | |||
| Outputs: | |||
| Tensor, the gradients w.r.t the input of convolution 3D. It has the same shape as the input. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float32) | |||
| >>> weight = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float32) | |||
| >>> x = Tensor(np.ones([16, 32, 13, 37, 33])) | |||
| >>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=4, kernel_size=(4, 6, 2)) | |||
| >>> output = conv3d_backprop_input(dout, weight, F.shape(x)) | |||
| >>> print(output.shape) | |||
| (16, 32, 13, 37, 33) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| out_channel, | |||
| kernel_size, | |||
| pad_mode="valid", | |||
| pad=0, | |||
| mode=1, | |||
| stride=1, | |||
| dilation=1, | |||
| group=1, | |||
| data_format="NCDHW"): | |||
| """Initialize Conv3DBackpropInput""" | |||
| self.init_prim_io_names(inputs=['filter', 'out_backprop', 'input_size'], outputs=['y']) | |||
| self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) | |||
| self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) | |||
| self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) | |||
| self.add_prim_attr('strides', self.stride) | |||
| self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) | |||
| self.add_prim_attr('dilations', self.dilation) | |||
| validator.check_value_type('pad', pad, (int, tuple), self.name) | |||
| if isinstance(pad, int): | |||
| pad = (pad,) * 6 | |||
| validator.check_equal_int(len(pad), 6, 'pad size', self.name) | |||
| self.pad_list = pad | |||
| self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) | |||
| if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0): | |||
| raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.") | |||
| if self.pad_mode == 'pad': | |||
| for item in pad: | |||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||
| self.add_prim_attr('pad_mode', self.pad_mode) | |||
| self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) | |||
| self.group = validator.check_positive_int(group, 'group', self.name) | |||
| self.add_prim_attr('groups', self.group) | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) | |||
| self.add_prim_attr('data_format', self.format) | |||
| self.add_prim_attr('io_format', "NCDHW") | |||
| def __infer__(self, w, doutput, x_size): | |||
| x_size_v = x_size['value'] | |||
| validator.check_value_type('x_size', x_size_v, [tuple], self.name) | |||
| for i, dim_len in enumerate(x_size_v): | |||
| validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) | |||
| args = {'doutput': doutput['dtype'], 'w': w['dtype']} | |||
| valid_dtypes = [mstype.float16, mstype.float32] | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||
| validator.check("filter's batch", w['shape'][0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name) | |||
| validator.check("filter's channel", w['shape'][1], "input_size's channel", x_size_v[1], Rel.EQ, self.name) | |||
| validator.check("input_size's batch", x_size_v[0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name) | |||
| # infer shape | |||
| dout_shape = doutput['shape'] | |||
| kernel_d = self.kernel_size[0] | |||
| kernel_h = self.kernel_size[1] | |||
| kernel_w = self.kernel_size[2] | |||
| stride_d = self.stride[2] | |||
| stride_h = self.stride[3] | |||
| stride_w = self.stride[4] | |||
| dilation_d = self.dilation[2] | |||
| dilation_h = self.dilation[3] | |||
| dilation_w = self.dilation[4] | |||
| # The pad_mode is valid by default. If pad_mode is not valid or same, then pad. | |||
| if self.pad_mode == "valid": | |||
| self.pad_list = (0, 0, 0, 0, 0, 0) | |||
| if self.pad_mode == "same": | |||
| pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_size_v[2]) | |||
| pad_head = math.floor(pad_needed_d / 2) | |||
| pad_tail = pad_needed_d - pad_head | |||
| pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[3]) | |||
| pad_top = math.floor(pad_needed_h / 2) | |||
| pad_bottom = pad_needed_h - pad_top | |||
| pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[4]) | |||
| pad_left = math.floor(pad_needed_w / 2) | |||
| pad_right = pad_needed_w - pad_left | |||
| self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right) | |||
| self.add_prim_attr('pads', self.pad_list) | |||
| out = { | |||
| 'value': None, | |||
| 'shape': x_size_v, | |||
| 'dtype': doutput['dtype'], | |||
| } | |||
| return out | |||
| @@ -27,6 +27,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.ops.operations import _quant_ops as Q | |||
| from mindspore.ops.operations import nn_ops as nps | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| @@ -288,6 +289,7 @@ class CountNonZero(nn.Cell): | |||
| self.axis = axis | |||
| self.keep_dims = keep_dims | |||
| self.dtype = dtype | |||
| def construct(self, input_x): | |||
| nonzero_num = C.count_nonzero(input_x, self.axis, self.keep_dims, self.dtype) | |||
| return nonzero_num | |||
| @@ -423,6 +425,50 @@ class ScatterDiv(nn.Cell): | |||
| return out | |||
| class Conv3D(nn.Cell): | |||
| """Conv3D net definition""" | |||
| def __init__(self, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format): | |||
| super(Conv3D, self).__init__() | |||
| self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, pad_mode=pad_mode, | |||
| pad=pad, stride=stride, dilation=dilation, group=group, data_format=data_format) | |||
| def construct(self, x, w): | |||
| out = self.conv(x, w) | |||
| return out | |||
| class Conv3DBackpropInput(nn.Cell): | |||
| """Conv3DBackpropInput net definition""" | |||
| def __init__(self, input_shape, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, | |||
| data_format): | |||
| super(Conv3DBackpropInput, self).__init__() | |||
| self.conv = nps.Conv3DBackpropInput(out_channel, kernel_size, pad_mode=pad_mode, | |||
| pad=pad, mode=mode, stride=stride, dilation=dilation, | |||
| group=group, data_format=data_format) | |||
| self.x_size = input_shape | |||
| def construct(self, w, doutput): | |||
| ms_out = self.conv(w, doutput, self.x_size) | |||
| return ms_out | |||
| class Conv3DBackpropFilter(nn.Cell): | |||
| """Conv3DBackpropFilter net definition""" | |||
| def __init__(self, w_shape, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format): | |||
| super(Conv3DBackpropFilter, self).__init__() | |||
| self.conv = G.Conv3DBackpropFilter(out_channel, kernel_size, pad_mode=pad_mode, | |||
| pad=pad, mode=mode, stride=stride, dilation=dilation, | |||
| group=group, data_format=data_format) | |||
| self.w_size = w_shape | |||
| def construct(self, x, doutput): | |||
| ms_out = self.conv(x, doutput, self.w_size) | |||
| return ms_out | |||
| class ApplyFtrlNet(nn.Cell): | |||
| def __init__(self): | |||
| super(ApplyFtrlNet, self).__init__() | |||
| @@ -1180,6 +1226,24 @@ test_case_math_ops = [ | |||
| 'block': Moments(axis=(), keep_dims=False), | |||
| 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ('Conv3D', { | |||
| 'block': Conv3D(out_channel=32, kernel_size=(4, 3, 3), mode=1, pad_mode='valid', pad=0, | |||
| stride=1, dilation=1, group=1, data_format="NCDHW"), | |||
| 'desc_inputs': [Tensor(np.random.random((16, 3, 10, 32, 32)).astype(np.float16)), | |||
| Tensor(np.random.random((32, 3, 4, 3, 3)).astype(np.float16))], | |||
| 'skip': ['backward']}), | |||
| ('Conv3DBackpropInput', { | |||
| 'block': Conv3DBackpropInput(input_shape=(16, 32, 13, 37, 33), out_channel=32, kernel_size=(4, 6, 2), mode=1, | |||
| pad_mode='valid', pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"), | |||
| 'desc_inputs': [Tensor(np.random.random((32, 32, 4, 6, 2)).astype(np.float16)), | |||
| Tensor(np.random.random((16, 32, 10, 32, 32)).astype(np.float16))], | |||
| 'skip': ['backward']}), | |||
| ('Conv3DBackpropFilter', { | |||
| 'block': Conv3DBackpropFilter(w_shape=(32, 32, 4, 6, 2), out_channel=32, kernel_size=(4, 6, 2), mode=1, | |||
| pad_mode='valid', pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"), | |||
| 'desc_inputs': [Tensor(np.random.random((16, 32, 13, 37, 33)).astype(np.float16)), | |||
| Tensor(np.random.random((16, 32, 10, 32, 32)).astype(np.float16))], | |||
| 'skip': ['backward']}), | |||
| ('CountNonZero', { | |||
| 'block': CountNonZero(axis=(), keep_dims=False, dtype=mstype.int32), | |||
| 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))], | |||