From c45eee9b94d9e908248fbb5b9d077b04ec802436 Mon Sep 17 00:00:00 2001 From: jiangzhenguang Date: Sat, 28 Nov 2020 14:02:11 +0800 Subject: [PATCH] add conv3d object --- mindspore/_checkparam.py | 44 ++- .../pass/const_input_to_attr_registry.cc | 2 + .../ccsrc/backend/session/kernel_graph.cc | 4 +- mindspore/ccsrc/common/trans.cc | 5 +- mindspore/core/base/core_ops.h | 2 + mindspore/ops/_grad/grad_nn_ops.py | 22 ++ mindspore/ops/_op_impl/tbe/__init__.py | 4 + mindspore/ops/_op_impl/tbe/conv3d.py | 45 +++ .../_op_impl/tbe/conv3d_backprop_filter.py | 42 +++ .../ops/_op_impl/tbe/conv3d_backprop_input.py | 42 +++ .../ops/_op_impl/tbe/conv3d_transpose.py | 46 +++ mindspore/ops/_op_impl/tbe/trans_data.py | 18 ++ mindspore/ops/composite/clip_ops.py | 2 +- mindspore/ops/op_info_register.py | 8 + mindspore/ops/operations/_grad_ops.py | 137 +++++++- mindspore/ops/operations/nn_ops.py | 294 ++++++++++++++++++ tests/ut/python/ops/test_ops.py | 64 ++++ 17 files changed, 771 insertions(+), 10 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/conv3d.py create mode 100644 mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py create mode 100644 mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py create mode 100644 mindspore/ops/_op_impl/tbe/conv3d_transpose.py diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index ed71c859ff..6f6a50c360 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -37,9 +37,9 @@ class Rel(Enum): GE = 6 # >= # scalar range check INC_NEITHER = 7 # (), include neither - INC_LEFT = 8 # [), include left - INC_RIGHT = 9 # (], include right - INC_BOTH = 10 # [], include both + INC_LEFT = 8 # [), include left + INC_RIGHT = 9 # (], include right + INC_BOTH = 10 # [], include both # collection in, not in IN = 11 NOT_IN = 12 @@ -92,6 +92,41 @@ rel_strs = { } +def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, + ret_five=False, greater_zero=True): + """ + Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements. + """ + + def _raise_message(): + raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three " + f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}") + + def _get_return_value(): + if isinstance(arg_value, int): + ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value) + elif len(arg_value) == 3: + ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value + elif len(arg_value) == 5: + if not allow_five: + _raise_message() + ret = arg_value if ret_five else (arg_value[1], arg_value[2], arg_value[3]) + else: + _raise_message() + return ret + + Validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) + ret_value = _get_return_value() + for item in ret_value: + if isinstance(item, int) and not isinstance(item, bool): + if greater_zero and item > 0: + continue + if not greater_zero and item >= 0: + continue + _raise_message() + return tuple(ret_value) + + def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): """ Check argument integer. 
@@ -428,6 +463,7 @@ class Validator: @staticmethod def check_types_same_and_valid(args, valid_values, prim_name): """Checks whether the types of inputs are the same and valid.""" + def _check_type_valid(arg): arg_key, arg_val = arg elem_type = arg_val @@ -494,6 +530,7 @@ class Validator: raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') return arg1 + reduce(_check_types_same, map(_check_argument_type, args.items())) @staticmethod @@ -625,6 +662,7 @@ def args_type_check(*type_args, **type_kwargs): if value is not None and not isinstance(value, bound_types[name]): raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) return func(*args, **kwargs) + return wrapper return type_check diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc index c0fb214752..52fc6e507b 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc @@ -29,6 +29,8 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimAvgPoolGradVm->name(), {0}); Register(prim::kPrimConv2DBackpropInput->name(), {2}); Register(prim::kPrimConv2DBackpropFilter->name(), {2}); + Register(prim::kPrimConv3DBackpropInput->name(), {2}); + Register(prim::kPrimConv3DBackpropFilter->name(), {2}); Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); Register(prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), {0}); Register(prim::kPrimReshape->name(), {1}); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 03bdfaafcc..2d08a7049c 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -31,6 +31,7 @@ namespace session { namespace { constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; +constexpr size_t k5dDims = 5; const std::set kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(), prim::kPrimAssignSub->name()}; void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, @@ -383,7 +384,8 @@ void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &forma for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); i++) { auto in_node = AnfAlgo::GetInputNode(node->cast(), i); MS_EXCEPTION_IF_NULL(in_node); - if (in_node->isa() || in_node->isa()) { + if ((in_node->isa() || in_node->isa()) && + AnfAlgo::GetOutputInferShape(in_node, 0).size() == k5dDims) { ReSetParameterValueNodeFormatAndType(in_node, format); } } diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index f274009537..61f6de0ea3 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -291,10 +291,7 @@ std::vector Fracz3DDeviceShape(const std::vector &shape) { std::vector device_shape; const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize; - device_shape.push_back(shape[2]); - device_shape.push_back(C1); - device_shape.push_back(shape[3]); - device_shape.push_back(shape[4]); + device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]); device_shape.push_back(N1); device_shape.push_back(kCubeSize); device_shape.push_back(kCubeSize); diff --git 
a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index ba24ff97b9..e47e7f9dff 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -148,6 +148,8 @@ inline const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad" inline const PrimitivePtr kPrimRelu6Grad = std::make_shared("ReLU6Grad"); inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared("Conv2DBackpropInput"); inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv2DBackpropFilter"); +inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared("Conv3DBackpropInput"); +inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared("Conv3DBackpropFilter"); inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = std::make_shared("DepthwiseConv2dNativeBackpropFilter"); diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 5aec089dfc..239f51ff8c 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -18,6 +18,7 @@ import numpy as np from mindspore.ops import _selected_grad_ops as SG from mindspore.ops.primitive import constexpr from mindspore.common.tensor import Tensor +from mindspore.ops.operations import nn_ops as nps from .grad_base import bprop_getters from .. import functional as F from .. import operations as P @@ -60,6 +61,27 @@ def get_bprop_conv2d(self): return bprop +@bprop_getters.register(nps.Conv3D) +def get_bprop_conv3d(self): + """Grad definition for `Conv3D` operation.""" + input_grad = nps.Conv3DBackpropInput( + self.out_channel, self.kernel_size, self.pad_mode, self.pad, mode=self.mode, + dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format + ) + filter_grad = G.Conv3DBackpropFilter( + self.out_channel, self.kernel_size, self.pad_mode, self.pad, mode=self.mode, + dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format + ) + get_shape = P.Shape() + + def bprop(x, w, out, dout): + dx = input_grad(w, dout, get_shape(x)) + dw = filter_grad(x, dout, get_shape(w)) + return dx, dw + + return bprop + + @bprop_getters.register(inner.ExtractImagePatches) def get_bprop_extract_image_patches(self): """Grad definition for `ExtractImagePatches` operation.""" diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index aa9ecc6269..d704687320 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -347,3 +347,7 @@ from .fake_quant_with_min_max_vars import _fake_quant_with_min_max_vars_tbe from .fake_quant_with_min_max_vars_gradient import _fake_quant_with_min_max_vars_gradient_tbe from .fake_quant_with_min_max_vars_per_channel import _fake_quant_with_min_max_vars_per_channel_tbe from .fake_quant_with_min_max_vars_per_channel_gradient import _fake_quant_with_min_max_vars_per_channel_gradient_tbe +from .conv3d import _conv3d_tbe +from .conv3d_backprop_input import _conv3d_backprop_input_tbe +from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe +from .conv3d_transpose import _conv3d_transpose_tbe diff --git a/mindspore/ops/_op_impl/tbe/conv3d.py b/mindspore/ops/_op_impl/tbe/conv3d.py new file mode 100644 index 0000000000..b7f2a0ec87 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/conv3d.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Conv3D op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +conv3d_op_info = TBERegOp("Conv3D") \ + .fusion_type("CONVLUTION") \ + .async_flag(False) \ + .binfile_name("conv3d.so") \ + .compute_cost(10) \ + .kernel_name("conv3d") \ + .partial_flag(True) \ + .attr("strides", "required", "listInt", "all") \ + .attr("pads", "required", "listInt", "all") \ + .attr("dilations", "required", "listInt", "all") \ + .attr("groups", "optional", "int", "all") \ + .attr("data_format", "optional", "str", "all") \ + .attr("offset_x", "optional", "int", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "filter", False, "required", "all") \ + .input(2, "bias", False, "optional", "all") \ + .input(3, "offset_w", False, "optional", "all") \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, + DataType.I8_Default, DataType.F16_NDC1HWC0) \ + .get_op_info() + + +@op_info_register(conv3d_op_info) +def _conv3d_tbe(): + """Conv3D TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py b/mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py new file mode 100644 index 0000000000..fd0b267740 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Conv3DBackpropFilter op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +conv3d_backprop_filter_op_info = TBERegOp("Conv3DBackpropFilter") \ + .fusion_type("CONVLUTION") \ + .async_flag(False) \ + .binfile_name("conv3d_backprop_filter_d.so") \ + .compute_cost(10) \ + .kernel_name("conv3d_backprop_filter_d") \ + .partial_flag(True) \ + .attr("filter_size", "required", "listInt", "all") \ + .attr("strides", "required", "listInt", "all") \ + .attr("pads", "required", "listInt", "all") \ + .attr("dilations", "required", "listInt", "all") \ + .attr("groups", "optional", "int", "all") \ + .attr("data_format", "optional", "str", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "out_backprop", False, "required", "all") \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_FRACTAL_Z_3D) \ + .get_op_info() + + +@op_info_register(conv3d_backprop_filter_op_info) +def _conv3d_backprop_filter_tbe(): + """Conv3DBackpropFilter TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py b/mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py new file mode 100644 index 0000000000..b5957c54f8 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Conv3DBackpropInput op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +conv3d_backprop_input_op_info = TBERegOp("Conv3DBackpropInput") \ + .fusion_type("CONVLUTION") \ + .async_flag(False) \ + .binfile_name("conv3d_backprop_input_d.so") \ + .compute_cost(10) \ + .kernel_name("conv3d_backprop_input_d") \ + .partial_flag(True) \ + .attr("input_size", "required", "listInt", "all") \ + .attr("strides", "required", "listInt", "all") \ + .attr("pads", "required", "listInt", "all") \ + .attr("dilations", "required", "listInt", "all") \ + .attr("groups", "optional", "int", "all") \ + .attr("data_format", "optional", "str", "all") \ + .input(0, "filter", False, "required", "all") \ + .input(1, "out_backprop", False, "required", "all") \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \ + .get_op_info() + + +@op_info_register(conv3d_backprop_input_op_info) +def _conv3d_backprop_input_tbe(): + """Conv3DBackpropInput TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/conv3d_transpose.py b/mindspore/ops/_op_impl/tbe/conv3d_transpose.py new file mode 100644 index 0000000000..fb1763eae2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/conv3d_transpose.py @@ -0,0 +1,46 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Conv3DTranspose op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +conv3d_transpose_op_info = TBERegOp("Conv3DTranspose") \ + .fusion_type("CONVLUTION") \ + .async_flag(False) \ + .binfile_name("conv3d_transpose_d.so") \ + .compute_cost(10) \ + .kernel_name("conv3d_transpose_d") \ + .partial_flag(True) \ + .attr("input_size", "required", "listInt", "all") \ + .attr("strides", "required", "listInt", "all") \ + .attr("pads", "required", "listInt", "all") \ + .attr("dilations", "optional", "listInt", "all") \ + .attr("groups", "optional", "int", "all") \ + .attr("data_format", "optional", "str", "all") \ + .attr("output_padding", "optional", "listInt", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "filter", False, "required", "all") \ + .input(2, "bias", False, "optional", "all") \ + .input(3, "offset_w", False, "optional", "all") \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, DataType.I8_Default, + DataType.F16_NDC1HWC0) \ + .get_op_info() + + +@op_info_register(conv3d_transpose_op_info) +def _conv3d_transpose_tbe(): + """Conv3DTranspose TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/trans_data.py b/mindspore/ops/_op_impl/tbe/trans_data.py index 666902172c..44f2af1c42 100644 --- a/mindspore/ops/_op_impl/tbe/trans_data.py +++ b/mindspore/ops/_op_impl/tbe/trans_data.py @@ -133,6 +133,24 @@ trans_data_op_info = TBERegOp("TransData") \ .dtype_format(DataType.F32_HWCN, DataType.F32_FracZNLSTM) \ .dtype_format(DataType.F16_FracZNLSTM, DataType.F16_HWCN) \ .dtype_format(DataType.F32_FracZNLSTM, DataType.F32_HWCN) \ + .dtype_format(DataType.F16_NDHWC, DataType.F16_NDC1HWC0) \ + .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDHWC) \ + .dtype_format(DataType.F16_DHWCN, DataType.F16_FRACTAL_Z_3D) \ + .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_DHWCN) \ + .dtype_format(DataType.F16_NCDHW, DataType.F16_NDC1HWC0) \ + .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NCDHW) \ + .dtype_format(DataType.F16_NCDHW, DataType.F16_FRACTAL_Z_3D) \ + .dtype_format(DataType.F32_NCDHW, DataType.F32_FRACTAL_Z_3D) \ + .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NCDHW) \ + .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NCDHW) \ + .dtype_format(DataType.F16_NDHWC, DataType.F16_FRACTAL_Z_3D) \ + .dtype_format(DataType.F32_NDHWC, DataType.F32_FRACTAL_Z_3D) \ + .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NDHWC) \ + .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NDHWC) \ + .dtype_format(DataType.F32_DHWCN, DataType.F32_FRACTAL_Z_3D) \ + .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_DHWCN) \ + .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDHWC) \ + .dtype_format(DataType.F32_NDHWC, DataType.F32_NDC1HWC0) \ .get_op_info() diff --git a/mindspore/ops/composite/clip_ops.py b/mindspore/ops/composite/clip_ops.py index 92eeeeafb3..768c15bbe4 100644 --- a/mindspore/ops/composite/clip_ops.py +++ b/mindspore/ops/composite/clip_ops.py @@ -110,7 +110,7 @@ def clip_by_global_norm(x, clip_norm=1.0, use_norm=None): Clips tensor values by the ratio of the sum of their norms. Note: - input 'x' should be a tuple or list of tensors. Otherwise, it will raise an error. + Input 'x' should be a tuple or list of tensors. Otherwise, it will raise an error. Args: x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 003d883cf2..907a74523d 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -631,6 +631,10 @@ class DataType: F16_NHWC = ("float16", "NHWC") F16_HWCN = ("float16", "HWCN") F16_NDHWC = ("float16", "NDHWC") + F16_NCDHW = ("float16", "NCDHW") + F16_DHWCN = ("float16", "DHWCN") + F16_NDC1HWC0 = ("float16", "NDC1HWC0") + F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D") F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM") F32_None = ("float32", "") @@ -643,6 +647,10 @@ class DataType: F32_NHWC = ("float32", "NHWC") F32_HWCN = ("float32", "HWCN") F32_NDHWC = ("float32", "NDHWC") + F32_NCDHW = ("float32", "NCDHW") + F32_DHWCN = ("float32", "DHWCN") + F32_NDC1HWC0 = ("float32", "NDC1HWC0") + F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D") F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM") F64_None = ("float64", "") diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index e9147f9288..423d3f12d3 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -14,8 +14,9 @@ # ============================================================================ """Operators for gradients.""" +import math from functools import partial - +from mindspore._checkparam import _check_3d_int_or_tuple from .. import signature as sig from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..._checkparam import Validator as validator, Rel @@ -288,6 +289,140 @@ class ConcatOffset(PrimitiveWithInfer): return out + +class Conv3DBackpropFilter(PrimitiveWithInfer): + """ + Computes the gradients of convolution 3D with respect to the filter. + + Args: + out_channel (int): The dimension of the output. + kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution. + mode (int): Modes for different convolutions. Not currently used. + pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". + pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of + head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six + integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], + pad[3], pad[4] and pad[5] correspondingly. + stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. + dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1. + group (int): Splits input into groups. Default: 1. + data_format (str): The optional value for data format. Currently only support 'NCDHW'. + + Inputs: + - **x** (Tensor) - The input of the convolution, whose shape is :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. + - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default + data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. + - **w_size** (Tensor) - A tuple that describes the shape of the weight, which conforms to the format + :math:`(C_{out}, C_{in}, K_1, K_2, K_3)`. + + Outputs: + Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight.
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16) + >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16) + >>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16) + >>> conv3d_backprop_filter = G.Conv3DBackpropFilter(out_channel=32, kernel_size=(4, 6, 2)) + >>> output = conv3d_backprop_filter(x, dout, F.shape(w)) + >>> print(output.shape) + (32, 32, 4, 6, 2) + """ + + + @prim_attr_register + def __init__(self, + out_channel, + kernel_size, + pad_mode="valid", + pad=0, + mode=1, + stride=(1, 1, 1, 1, 1), + dilation=(1, 1, 1, 1, 1), + group=1, + data_format="NCDHW"): + """Initialize Conv3DBackpropFilter""" + self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y']) + self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) + self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) + self.add_prim_attr('strides', self.stride) + self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) + self.add_prim_attr('dilations', self.dilation) + validator.check_value_type('pad', pad, (int, tuple), self.name) + if isinstance(pad, int): + pad = (pad,) * 6 + validator.check_equal_int(len(pad), 6, 'pad size', self.name) + self.pad_list = pad + self.add_prim_attr('pads', self.pad_list) + + self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) + if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0): + raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.") + if self.pad_mode == 'pad': + for item in pad: + validator.check_non_negative_int(item, 'pad item', self.name) + self.add_prim_attr('pad_mode', self.pad_mode) + + self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) + self.group = validator.check_positive_int(group, 'group', self.name) + self.add_prim_attr('groups', self.group) + self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) + self.add_prim_attr('data_format', self.format) + self.add_prim_attr('io_format', "NCDHW") + + def __infer__(self, x, doutput, w_size): + w_size_v = w_size['value'] + validator.check_value_type('w_size', w_size_v, [tuple], self.name) + for i, dim_len in enumerate(w_size_v): + validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) + args = {"x": x['dtype'], "doutput": doutput['dtype']} + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + + validator.check("filter's batch", w_size_v[0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name) + validator.check("filter's channel", w_size_v[1], "input_size's channel", x['shape'][1], Rel.EQ, self.name) + validator.check("input_size's batch", x['shape'][0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name) + + # infer shape + x_shape = x['shape'] + dout_shape = doutput['shape'] + kernel_d = self.kernel_size[0] + kernel_h = self.kernel_size[1] + kernel_w = self.kernel_size[2] + stride_d = self.stride[2] + stride_h = self.stride[3] + stride_w = self.stride[4] + dilation_d = self.dilation[2] + dilation_h = self.dilation[3] + dilation_w = self.dilation[4] + # The pad_mode is valid by default. If pad_mode is not valid or same, then pad.
+ if self.pad_mode == "valid": + self.pad_list = (0, 0, 0, 0, 0, 0) + if self.pad_mode == "same": + pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_shape[2]) + pad_head = math.floor(pad_needed_d / 2) + pad_tail = pad_needed_d - pad_head + + pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_shape[3]) + pad_top = math.floor(pad_needed_h / 2) + pad_bottom = pad_needed_h - pad_top + + pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_shape[4]) + pad_left = math.floor(pad_needed_w / 2) + pad_right = pad_needed_w - pad_left + self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right) + + self.add_prim_attr('pads', self.pad_list) + out = { + 'value': None, + 'shape': w_size_v, + 'dtype': mstype.float32, + } + return out + + + class Conv2DBackpropFilter(PrimitiveWithInfer): """ Computes the gradients of convolution with respect to the filter. diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index f59cec67a2..c13035f1cd 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -18,6 +18,7 @@ import math import operator from functools import reduce, partial +from mindspore._checkparam import _check_3d_int_or_tuple import numpy as np from ... import context from .. import signature as sig @@ -3683,6 +3684,7 @@ class AdamNoUpdateParam(PrimitiveWithInfer): [-0.00013441 -0.00013441 -0.00013441]] """ + @prim_attr_register def __init__(self, use_locking=False, use_nesterov=False): validator.check_value_type("use_locking", use_locking, [bool], self.name) @@ -6526,3 +6528,295 @@ class LRN(PrimitiveWithInfer): def infer_shape(self, x_shape): validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name) return x_shape + + +class Conv3D(PrimitiveWithInfer): + r""" + 3D convolution layer. + + Applies a 3D convolution over an input tensor which is typically of shape + :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number. + Each sample of the batch has shape :math:`(C_{in}, D_{in}, H_{in}, W_{in})`. + + If the 'pad_mode' is set to be "valid", the output depth, height and width will be + :math:`\left \lfloor{1 + \frac{D_{in} + 2 \times \text{padding} - \text{ks_d} - + (\text{ks_d} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and + :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} - + (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and + :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} - + (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively. + + Args: + out_channel (int): The dimension of the output. + kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution. + mode (int): Modes for different convolutions. Not currently used. + pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". + pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of + head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six + integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], + pad[3], pad[4] and pad[5] correspondingly. + stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1. + group (int): Splits input into groups. Default: 1. + data_format (str): The optional value for data format. Currently only support "NCDHW". + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. + - **weight** (Tensor) - If the size of kernel is :math:`(K_1, K_2, K_3)`, then the shape is + :math:`(C_{out}, C_{in}, K_1, K_2, K_3)`. + + Outputs: + Tensor, the output of the 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32) + >>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float32) + >>> conv3d = P.Conv3D(out_channel=32, kernel_size=(4, 3, 3)) + >>> output = conv3d(input, weight) + >>> print(output.shape) + (16, 32, 7, 30, 30) + """ + + @prim_attr_register + def __init__(self, + out_channel, + kernel_size, + mode=1, + pad_mode="valid", + pad=0, + stride=1, + dilation=1, + group=1, + data_format="NCDHW"): + """Initialize Conv3D""" + self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) + self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) + self.add_prim_attr('strides', self.stride) + self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) + self.add_prim_attr('dilations', self.dilation) + validator.check_value_type('pad', pad, (int, tuple), self.name) + if isinstance(pad, int): + pad = (pad,) * 6 + validator.check_equal_int(len(pad), 6, 'pad size', self.name) + self.padding = pad + self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) + self.add_prim_attr('pad_mode', self.pad_mode) + + if self.pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0): + raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.") + if self.pad_mode == 'pad': + for item in pad: + validator.check_non_negative_int(item, 'pad item', self.name) + + self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) + self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) + self.add_prim_attr('data_format', self.format) + self.add_prim_attr('io_format', "NCDHW") + self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) + self.group = validator.check_positive_int(group, 'group', self.name) + self.add_prim_attr('groups', self.group) + self.add_prim_attr('offset_x', 0) + + def infer_shape(self, x_shape, w_shape, b_shape=None): + validator.check_equal_int(len(w_shape), 5, "weight rank", self.name) + validator.check_equal_int(len(x_shape), 5, "x rank", self.name) + if b_shape is not None: + raise ValueError("Bias currently only support None.") + validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) + validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:5]', tuple(w_shape[2:]), Rel.EQ, self.name) + + kernel_size_d = w_shape[2] + kernel_size_h = w_shape[3] + kernel_size_w = w_shape[4] + + stride_d = self.stride[2] + stride_h = self.stride[3] + stride_w = self.stride[4] + + dilation_d = self.dilation[2] + dilation_h = self.dilation[3] + dilation_w =
self.dilation[4] + + if self.pad_mode == "valid": + d_out = math.ceil((x_shape[2] - dilation_d * (kernel_size_d - 1)) / stride_d) + h_out = math.ceil((x_shape[3] - dilation_h * (kernel_size_h - 1)) / stride_h) + w_out = math.ceil((x_shape[4] - dilation_w * (kernel_size_w - 1)) / stride_w) + pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0, 0, 0 + + elif self.pad_mode == "same": + d_out = math.ceil(x_shape[2] / stride_d) + h_out = math.ceil(x_shape[3] / stride_h) + w_out = math.ceil(x_shape[4] / stride_w) + + pad_needed_d = max(0, (d_out - 1) * stride_d + dilation_d * (kernel_size_d - 1) + 1 - x_shape[2]) + pad_head = math.floor(pad_needed_d / 2) + pad_tail = pad_needed_d - pad_head + + pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[3]) + pad_top = math.floor(pad_needed_h / 2) + pad_bottom = pad_needed_h - pad_top + + pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[4]) + pad_left = math.floor(pad_needed_w / 2) + pad_right = pad_needed_w - pad_left + + elif self.pad_mode == 'pad': + pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.padding + d_out = 1 + (x_shape[2] + pad_head + pad_tail - kernel_size_d - (kernel_size_d - 1) + * (dilation_d - 1)) / stride_d + h_out = 1 + (x_shape[3] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) + * (dilation_h - 1)) / stride_h + w_out = 1 + (x_shape[4] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) + * (dilation_w - 1)) / stride_w + d_out = math.floor(d_out) + h_out = math.floor(h_out) + w_out = math.floor(w_out) + + self.pad_list = [pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right] + self.add_prim_attr('pads', (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)) + out_channel = self.out_channel + out_shape = [x_shape[0], out_channel, d_out, h_out, w_out] + _check_shape('output', out_shape, self.name) + return out_shape + + def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): + args = {'x': x_dtype, 'w': w_dtype} + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + return x_dtype + + +class Conv3DBackpropInput(PrimitiveWithInfer): + """ + Computes the gradients of convolution 3D with respect to the input. + + Args: + out_channel (int): The dimension of the output. + kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution. + mode (int): Modes for different convolutions. Not currently used. + pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". + pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of + head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six + integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], + pad[3], pad[4] and pad[5] correspondingly. + stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. + dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1. + group (int): Splits input into groups. Default: 1. + data_format (str): The optional value for data format. Currently only support 'NCDHW'. + + Inputs: + - **weight** (Tensor) - If the size of kernel is :math:`(K_1, K_2, K_3)`, then the shape is + :math:`(C_{out}, C_{in}, K_1, K_2, K_3)`.
+ - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default + data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. + - **input_size** (Tensor) - A tuple that describes the shape of the input, which conforms to the format + :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor, the gradients w.r.t the input of convolution 3D. It has the same shape as the input. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float32) + >>> weight = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float32) + >>> x = Tensor(np.ones([16, 32, 13, 37, 33])) + >>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=32, kernel_size=(4, 6, 2)) + >>> output = conv3d_backprop_input(dout, weight, F.shape(x)) + >>> print(output.shape) + (16, 32, 13, 37, 33) + """ + + @prim_attr_register + def __init__(self, + out_channel, + kernel_size, + pad_mode="valid", + pad=0, + mode=1, + stride=1, + dilation=1, + group=1, + data_format="NCDHW"): + """Initialize Conv3DBackpropInput""" + self.init_prim_io_names(inputs=['filter', 'out_backprop', 'input_size'], outputs=['y']) + self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) + self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) + self.add_prim_attr('strides', self.stride) + self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) + self.add_prim_attr('dilations', self.dilation) + validator.check_value_type('pad', pad, (int, tuple), self.name) + if isinstance(pad, int): + pad = (pad,) * 6 + validator.check_equal_int(len(pad), 6, 'pad size', self.name) + self.pad_list = pad + + self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) + if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0): + raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.") + if self.pad_mode == 'pad': + for item in pad: + validator.check_non_negative_int(item, 'pad item', self.name) + self.add_prim_attr('pad_mode', self.pad_mode) + + self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) + self.group = validator.check_positive_int(group, 'group', self.name) + self.add_prim_attr('groups', self.group) + self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) + self.add_prim_attr('data_format', self.format) + self.add_prim_attr('io_format', "NCDHW") + + def __infer__(self, w, doutput, x_size): + x_size_v = x_size['value'] + validator.check_value_type('x_size', x_size_v, [tuple], self.name) + for i, dim_len in enumerate(x_size_v): + validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) + args = {'doutput': doutput['dtype'], 'w': w['dtype']} + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check("filter's batch", w['shape'][0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name) + validator.check("filter's channel", w['shape'][1], "input_size's channel", x_size_v[1], Rel.EQ, self.name) + validator.check("input_size's batch", x_size_v[0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name) + + # infer shape + dout_shape = doutput['shape'] + kernel_d = self.kernel_size[0] + kernel_h = self.kernel_size[1] + kernel_w =
self.kernel_size[2] + stride_d = self.stride[2] + stride_h = self.stride[3] + stride_w = self.stride[4] + dilation_d = self.dilation[2] + dilation_h = self.dilation[3] + dilation_w = self.dilation[4] + # The pad_mode is valid by default. If pad_mode is not valid or same, then pad. + if self.pad_mode == "valid": + self.pad_list = (0, 0, 0, 0, 0, 0) + if self.pad_mode == "same": + pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_size_v[2]) + pad_head = math.floor(pad_needed_d / 2) + pad_tail = pad_needed_d - pad_head + + pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[3]) + pad_top = math.floor(pad_needed_h / 2) + pad_bottom = pad_needed_h - pad_top + + pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[4]) + pad_left = math.floor(pad_needed_w / 2) + pad_right = pad_needed_w - pad_left + self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right) + + self.add_prim_attr('pads', self.pad_list) + out = { + 'value': None, + 'shape': x_size_v, + 'dtype': doutput['dtype'], + } + return out diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 96e321efee..51c29f7eec 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -27,6 +27,7 @@ from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _quant_ops as Q +from mindspore.ops.operations import nn_ops as nps from ..ut_filter import non_graph_engine from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.pipeline.forward.compile_forward \ @@ -288,6 +289,7 @@ class CountNonZero(nn.Cell): self.axis = axis self.keep_dims = keep_dims self.dtype = dtype + def construct(self, input_x): nonzero_num = C.count_nonzero(input_x, self.axis, self.keep_dims, self.dtype) return nonzero_num @@ -423,6 +425,50 @@ class ScatterDiv(nn.Cell): return out +class Conv3D(nn.Cell): + """Conv3D net definition""" + + def __init__(self, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format): + super(Conv3D, self).__init__() + self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, pad_mode=pad_mode, + pad=pad, stride=stride, dilation=dilation, group=group, data_format=data_format) + + def construct(self, x, w): + out = self.conv(x, w) + return out + + +class Conv3DBackpropInput(nn.Cell): + """Conv3DBackpropInput net definition""" + + def __init__(self, input_shape, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, + data_format): + super(Conv3DBackpropInput, self).__init__() + self.conv = nps.Conv3DBackpropInput(out_channel, kernel_size, pad_mode=pad_mode, + pad=pad, mode=mode, stride=stride, dilation=dilation, + group=group, data_format=data_format) + self.x_size = input_shape + + def construct(self, w, doutput): + ms_out = self.conv(w, doutput, self.x_size) + return ms_out + + +class Conv3DBackpropFilter(nn.Cell): + """Conv3DBackpropFilter net definition""" + + def __init__(self, w_shape, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format): + super(Conv3DBackpropFilter, self).__init__() + self.conv = G.Conv3DBackpropFilter(out_channel, kernel_size, pad_mode=pad_mode, + pad=pad, mode=mode, stride=stride, dilation=dilation, + group=group, data_format=data_format) + 
self.w_size = w_shape + + def construct(self, x, doutput): + ms_out = self.conv(x, doutput, self.w_size) + return ms_out + + class ApplyFtrlNet(nn.Cell): def __init__(self): super(ApplyFtrlNet, self).__init__() @@ -1180,6 +1226,24 @@ test_case_math_ops = [ 'block': Moments(axis=(), keep_dims=False), 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))], 'skip': ['backward']}), + ('Conv3D', { + 'block': Conv3D(out_channel=32, kernel_size=(4, 3, 3), mode=1, pad_mode='valid', pad=0, + stride=1, dilation=1, group=1, data_format="NCDHW"), + 'desc_inputs': [Tensor(np.random.random((16, 3, 10, 32, 32)).astype(np.float16)), + Tensor(np.random.random((32, 3, 4, 3, 3)).astype(np.float16))], + 'skip': ['backward']}), + ('Conv3DBackpropInput', { + 'block': Conv3DBackpropInput(input_shape=(16, 32, 13, 37, 33), out_channel=32, kernel_size=(4, 6, 2), mode=1, + pad_mode='valid', pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"), + 'desc_inputs': [Tensor(np.random.random((32, 32, 4, 6, 2)).astype(np.float16)), + Tensor(np.random.random((16, 32, 10, 32, 32)).astype(np.float16))], + 'skip': ['backward']}), + ('Conv3DBackpropFilter', { + 'block': Conv3DBackpropFilter(w_shape=(32, 32, 4, 6, 2), out_channel=32, kernel_size=(4, 6, 2), mode=1, + pad_mode='valid', pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"), + 'desc_inputs': [Tensor(np.random.random((16, 32, 13, 37, 33)).astype(np.float16)), + Tensor(np.random.random((16, 32, 10, 32, 32)).astype(np.float16))], + 'skip': ['backward']}), ('CountNonZero', { 'block': CountNonZero(axis=(), keep_dims=False, dtype=mstype.int32), 'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],