From: @tom__chen Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -31,31 +31,7 @@ namespace kernel { | |||
| template <typename T> | |||
| class Conv2dGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| Conv2dGpuFwdKernel() | |||
| : cudnn_handle_(nullptr), | |||
| input_desc_(nullptr), | |||
| output_desc_(nullptr), | |||
| filter_desc_(nullptr), | |||
| conv_desc_(nullptr), | |||
| padded_desc_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| compute_format_(CUDNN_TENSOR_NCHW), | |||
| old_height_(0), | |||
| old_width_(0), | |||
| pad_height_(0), | |||
| pad_width_(0), | |||
| pad_top_(0), | |||
| pad_left_(0), | |||
| n_(0), | |||
| c_(0), | |||
| group_(1), | |||
| is_null_input_(false), | |||
| input_size_(0), | |||
| filter_size_(0), | |||
| output_size_(0), | |||
| padded_size_(0), | |||
| workspace_size_(0), | |||
| use_pad_(true) {} | |||
| Conv2dGpuFwdKernel() { ResetResource(); } | |||
| ~Conv2dGpuFwdKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| @@ -194,6 +170,38 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| cudnn_handle_ = nullptr; | |||
| input_desc_ = nullptr; | |||
| output_desc_ = nullptr; | |||
| filter_desc_ = nullptr; | |||
| conv_desc_ = nullptr; | |||
| padded_desc_ = nullptr; | |||
| cudnn_data_type_ = CUDNN_DATA_FLOAT; | |||
| compute_format_ = CUDNN_TENSOR_NCHW; | |||
| old_height_ = 0; | |||
| old_width_ = 0; | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| pad_top_ = 0; | |||
| pad_left_ = 0; | |||
| n_ = 0; | |||
| c_ = 0; | |||
| stride_.clear(); | |||
| dilation_.clear(); | |||
| group_ = 1; | |||
| is_null_input_ = false; | |||
| input_size_ = 0; | |||
| filter_size_ = 0; | |||
| output_size_ = 0; | |||
| padded_size_ = 0; | |||
| workspace_size_ = 0; | |||
| use_pad_ = true; | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| void DestroyResource() noexcept override { | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_), | |||
| "cudnnDestroyConvolutionDescriptor failed"); | |||
| @@ -57,6 +57,8 @@ AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const Primitiv | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include "abstract/infer_functions.h" | |||
| #include "abstract/utils.h" | |||
| #include "abstract/param_validator.h" | |||
| @@ -267,6 +268,194 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit | |||
| return std::make_shared<AbstractTuple>(rets); | |||
| } | |||
| AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||
| ShapeVector x_shape = input_x->shape()->shape(); | |||
| ShapeVector x_min_shape = input_x->shape()->min_shape(); | |||
| ShapeVector x_max_shape = input_x->shape()->max_shape(); | |||
| (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); | |||
| AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(input_w); | |||
| MS_EXCEPTION_IF_NULL(input_w->shape()); | |||
| ShapeVector w_shape = input_w->shape()->shape(); | |||
| ShapeVector w_min_shape = input_w->shape()->min_shape(); | |||
| ShapeVector w_max_shape = input_w->shape()->max_shape(); | |||
| (void)CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape); | |||
| std::set<std::string> available_data_format{"NCHW", "NHWC"}; | |||
| int64_t n_axis = 0; | |||
| int64_t c_axis = 1; | |||
| int64_t h_axis = 2; | |||
| int64_t w_axis = 3; | |||
| auto data_format_ptr = primitive->GetAttr("format"); | |||
| std::string data_format = "NCHW"; | |||
| if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) { | |||
| data_format = data_format_ptr->cast<StringImmPtr>()->value(); | |||
| if (available_data_format.find(data_format) == available_data_format.end()) { | |||
| MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ". use NCHW or NHWC"; | |||
| } | |||
| if (data_format == "NHWC") { | |||
| c_axis = 3; | |||
| h_axis = 1; | |||
| w_axis = 2; | |||
| } | |||
| } | |||
| int64_t group = primitive->GetAttr("group")->cast<Int64ImmPtr>()->value(); | |||
| if (group <= 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid group value: " << group << ", should be greater then 0"; | |||
| } | |||
| if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) { | |||
| MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis] | |||
| << " (channels) must be divisible by group = " << group; | |||
| } | |||
| int64_t out_channel = primitive->GetAttr("out_channel")->cast<Int64ImmPtr>()->value(); | |||
| if (out_channel <= 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid out_channel value: " << out_channel << ", should be greater then 0"; | |||
| } | |||
| if ((w_shape[0] != Shape::SHP_ANY) && (w_shape[0] != out_channel)) { | |||
| MS_LOG(EXCEPTION) << "w_shape[0] = " << w_shape[0] << " must equal to = " << out_channel; | |||
| } | |||
| int64_t kernel_h = 0; | |||
| int64_t kernel_w = 0; | |||
| ValuePtr kernel_size_attr = primitive->GetAttr("kernel_size"); | |||
| MS_EXCEPTION_IF_NULL(kernel_size_attr); | |||
| if (kernel_size_attr->isa<ValueTuple>()) { | |||
| std::vector<ValuePtr> kernel_size_vec = kernel_size_attr->cast<ValueTuplePtr>()->value(); | |||
| kernel_h = GetValue<int64_t>(kernel_size_vec[0]); | |||
| kernel_w = GetValue<int64_t>(kernel_size_vec[1]); | |||
| } else { | |||
| int64_t kernel_size = kernel_size_attr->cast<Int64ImmPtr>()->value(); | |||
| kernel_h = kernel_size; | |||
| kernel_w = kernel_size; | |||
| } | |||
| if ((w_shape[2] != Shape::SHP_ANY) && (w_shape[2] != kernel_h)) { | |||
| MS_LOG(EXCEPTION) << "weight height, w_shape[2] = " << w_shape[2] << ", must equal to = " << kernel_h; | |||
| } | |||
| if ((w_shape[3] != Shape::SHP_ANY) && (w_shape[3] != kernel_w)) { | |||
| MS_LOG(EXCEPTION) << "weight width, w_shape[3] = " << w_shape[3] << ", must equal to = " << kernel_w; | |||
| } | |||
| int64_t stride_h = 0; | |||
| int64_t stride_w = 0; | |||
| ValuePtr stride_attr = primitive->GetAttr("stride"); | |||
| MS_EXCEPTION_IF_NULL(stride_attr); | |||
| if (stride_attr->isa<ValueTuple>()) { | |||
| std::vector<ValuePtr> stride_vec = stride_attr->cast<ValueTuplePtr>()->value(); | |||
| stride_h = GetValue<int64_t>(stride_vec[2]); | |||
| stride_w = GetValue<int64_t>(stride_vec[3]); | |||
| } else { | |||
| int64_t stride = stride_attr->cast<Int64ImmPtr>()->value(); | |||
| stride_h = stride; | |||
| stride_w = stride; | |||
| } | |||
| int64_t dilation_h = 0; | |||
| int64_t dilation_w = 0; | |||
| ValuePtr dilation_attr = primitive->GetAttr("dilation"); | |||
| MS_EXCEPTION_IF_NULL(dilation_attr); | |||
| if (dilation_attr->isa<ValueTuple>()) { | |||
| std::vector<ValuePtr> dilation_vec = dilation_attr->cast<ValueTuplePtr>()->value(); | |||
| dilation_h = GetValue<int64_t>(dilation_vec[2]); | |||
| dilation_w = GetValue<int64_t>(dilation_vec[3]); | |||
| } else { | |||
| int64_t dilation = dilation_attr->cast<Int64ImmPtr>()->value(); | |||
| dilation_h = dilation; | |||
| dilation_w = dilation; | |||
| } | |||
| std::vector<int64_t> padding; | |||
| ValuePtr padding_attr = primitive->GetAttr("pad"); | |||
| MS_EXCEPTION_IF_NULL(padding_attr); | |||
| if (padding_attr->isa<ValueTuple>()) { | |||
| std::vector<ValuePtr> padding_vec = padding_attr->cast<ValueTuplePtr>()->value(); | |||
| (void)std::transform(std::begin(padding_vec), std::end(padding_vec), std::back_inserter(padding), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| } else { | |||
| int64_t padding_val = padding_attr->cast<Int64ImmPtr>()->value(); | |||
| padding = {padding_val, padding_val, padding_val, padding_val}; | |||
| } | |||
| std::set<std::string> available_pad_mode{"pad", "same", "valid"}; | |||
| ValuePtr pad_mode_attr = primitive->GetAttr("pad_mode"); | |||
| MS_EXCEPTION_IF_NULL(pad_mode_attr); | |||
| auto pad_mode = pad_mode_attr->cast<StringImmPtr>()->value(); | |||
| if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) { | |||
| MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid"; | |||
| } | |||
| std::function<void(int64_t, int64_t, std::vector<int64_t> &, std::vector<int64_t> &)> pad_function = | |||
| [kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, pad_mode, padding]( | |||
| int64_t x_h, int64_t x_w, std::vector<int64_t> &output_hw, std::vector<int64_t> &pad_list) { | |||
| if (pad_mode == "valid") { | |||
| output_hw.push_back(std::ceil(((x_h * 1.0) - dilation_h * (kernel_h - 1)) / stride_h)); | |||
| output_hw.push_back(std::ceil(((x_w * 1.0) - dilation_w * (kernel_w - 1)) / stride_w)); | |||
| pad_list = {0, 0, 0, 0}; | |||
| } else if (pad_mode == "same") { | |||
| output_hw.push_back(std::ceil((x_h * 1.0) / stride_h)); | |||
| output_hw.push_back(std::ceil((x_w * 1.0) / stride_w)); | |||
| int64_t pad_needed_h = (output_hw[0] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_h; | |||
| pad_needed_h = std::max((int64_t)0, pad_needed_h); | |||
| pad_list.push_back(std::floor(pad_needed_h / 2)); | |||
| pad_list.push_back(pad_needed_h - pad_list[0]); | |||
| int64_t pad_needed_w = (output_hw[1] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_w; | |||
| pad_needed_w = std::max((int64_t)0, pad_needed_w); | |||
| pad_list.push_back(std::floor(pad_needed_w / 2)); | |||
| pad_list.push_back(pad_needed_w - pad_list[2]); | |||
| } else if (pad_mode == "pad") { | |||
| pad_list = padding; | |||
| output_hw.push_back(std::floor( | |||
| 1 + ((x_h * 1.0) + pad_list[0] + pad_list[1] - kernel_h - (kernel_h - 1) * (dilation_h - 1)) / stride_h)); | |||
| output_hw.push_back(std::floor( | |||
| 1 + ((x_w * 1.0) + pad_list[2] + pad_list[3] - kernel_w - (kernel_w - 1) * (dilation_w - 1)) / stride_w)); | |||
| } | |||
| }; | |||
| std::vector<int64_t> output_hw; | |||
| std::vector<int64_t> pad_list; | |||
| std::vector<int64_t> output_hw_min; | |||
| std::vector<int64_t> pad_list_min; | |||
| std::vector<int64_t> output_hw_max; | |||
| std::vector<int64_t> pad_list_max; | |||
| pad_function(x_shape[h_axis], x_shape[w_axis], output_hw, pad_list); | |||
| if (x_shape[h_axis] == Shape::SHP_ANY) { | |||
| output_hw[0] = Shape::SHP_ANY; | |||
| } | |||
| if (x_shape[w_axis] == Shape::SHP_ANY) { | |||
| output_hw[1] = Shape::SHP_ANY; | |||
| } | |||
| pad_function(x_min_shape[h_axis], x_min_shape[w_axis], output_hw_min, pad_list_min); | |||
| pad_function(x_max_shape[h_axis], x_max_shape[w_axis], output_hw_max, pad_list_max); | |||
| std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]), | |||
| MakeValue(pad_list[3])}; | |||
| primitive->set_attr("pad_list", MakeValue(pad_list_val)); | |||
| ShapeVector output_shape; | |||
| ShapeVector output_shape_min; | |||
| ShapeVector output_shape_max; | |||
| if (data_format == "NHWC") { | |||
| output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel}; | |||
| output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel}; | |||
| output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel}; | |||
| } else { | |||
| output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]}; | |||
| output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]}; | |||
| output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]}; | |||
| } | |||
| ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max); | |||
| return std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr); | |||
| } | |||
| AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: three tensors(doutput, input, filters). | |||
| @@ -111,6 +111,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | |||
| {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, | |||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | |||
| {prim::kPrimConv2D, {InferImplConv2D, true}}, | |||
| {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, | |||
| {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, | |||
| {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, | |||
| @@ -1207,7 +1207,7 @@ class BatchNorm(PrimitiveWithInfer): | |||
| return (input_x, scale, bias, input_x, input_x) | |||
| class Conv2D(PrimitiveWithInfer): | |||
| class Conv2D(PrimitiveWithCheck): | |||
| r""" | |||
| 2D convolution layer. | |||
| @@ -1314,65 +1314,16 @@ class Conv2D(PrimitiveWithInfer): | |||
| self.add_prim_attr('groups', self.group) | |||
| self.add_prim_attr('offset_a', 0) | |||
| def infer_shape(self, x_shape, w_shape, b_shape=None): | |||
| def check_shape(self, x_shape, w_shape, b_shape=None): | |||
| x_shape_norm = x_shape if self.format == "NCHW" else (x_shape[0], x_shape[3], x_shape[1], x_shape[2]) | |||
| w_shape_norm = w_shape if self.format == "NCHW" else (w_shape[0], w_shape[3], w_shape[1], w_shape[2]) | |||
| validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name) | |||
| validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name) | |||
| validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \ | |||
| Rel.EQ, self.name) | |||
| validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name) | |||
| validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name) | |||
| kernel_size_h = w_shape_norm[2] | |||
| kernel_size_w = w_shape_norm[3] | |||
| stride_h = self.stride[2] | |||
| stride_w = self.stride[3] | |||
| dilation_h = self.dilation[2] | |||
| dilation_w = self.dilation[3] | |||
| if self.pad_mode == "valid": | |||
| h_out = math.ceil((x_shape_norm[2] - dilation_h * (kernel_size_h - 1)) / stride_h) | |||
| w_out = math.ceil((x_shape_norm[3] - dilation_w * (kernel_size_w - 1)) / stride_w) | |||
| pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0 | |||
| elif self.pad_mode == "same": | |||
| h_out = math.ceil(x_shape_norm[2] / stride_h) | |||
| w_out = math.ceil(x_shape_norm[3] / stride_w) | |||
| pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape_norm[2]) | |||
| pad_top = math.floor(pad_needed_h / 2) | |||
| pad_bottom = pad_needed_h - pad_top | |||
| pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape_norm[3]) | |||
| pad_left = math.floor(pad_needed_w / 2) | |||
| pad_right = pad_needed_w - pad_left | |||
| elif self.pad_mode == 'pad': | |||
| pad_top, pad_bottom, pad_left, pad_right = self.padding | |||
| h_out = 1 + (x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \ | |||
| * (dilation_h - 1)) / stride_h | |||
| w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \ | |||
| * (dilation_w - 1)) / stride_w | |||
| h_out = math.floor(h_out) | |||
| w_out = math.floor(w_out) | |||
| self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] | |||
| self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) | |||
| out_channel = self.out_channel | |||
| out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else \ | |||
| [x_shape_norm[0], h_out, w_out, out_channel] | |||
| _check_shape('output', out_shape, self.name) | |||
| return out_shape | |||
| def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): | |||
| def check_dtype(self, x_dtype, w_dtype, b_dtype=None): | |||
| args = {'x': x_dtype, 'w': w_dtype} | |||
| valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||
| if x_dtype.element_type() == mstype.int8: | |||
| return mstype.tensor_type(mstype.int32) | |||
| return x_dtype | |||
| class DepthwiseConv2dNative(PrimitiveWithInfer): | |||
| @@ -20,6 +20,9 @@ import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| class NetConv2d(nn.Cell): | |||
| @@ -61,3 +64,171 @@ def test_conv2d(): | |||
| conv2d = NetConv2d() | |||
| output = conv2d(x, w) | |||
| assert (output.asnumpy() == expect).all() | |||
| class NetConv(nn.Cell): | |||
| def __init__(self, weight, x): | |||
| super(NetConv, self).__init__() | |||
| self.conv = nn.Conv2d(in_channels=3, | |||
| out_channels=3, | |||
| kernel_size=(5, 3), | |||
| stride=2, | |||
| pad_mode='same', | |||
| padding=(0, 0, 0, 0), | |||
| dilation=(1, 1), | |||
| group=1, | |||
| has_bias=False, | |||
| weight_init=Tensor(weight) | |||
| ) | |||
| self.x = Parameter(initializer(Tensor(x), [1, 3, 4, 2]), name="x") | |||
| def construct(self): | |||
| return self.conv(self.x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_conv(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| weight = np.array([[[[0.38968208, 0.14398979, 0.7962463], | |||
| [-2.1836321, -0.63823014, -0.50588065], | |||
| [0.6660469, 0.64673275, -0.13160042], | |||
| [1.3683757, 1.4005762, -0.37235805], | |||
| [-0.22638111, 0.45427424, -0.10293389]], | |||
| [[1.4985064, -0.29318333, -0.92694616], | |||
| [1.539068, 0.8937254, -1.2598171], | |||
| [0.9658142, -0.63945454, -0.23185322], | |||
| [1.363089, -0.41694695, -2.2750475], | |||
| [-0.4865508, -1.6938025, 0.609849]], | |||
| [[1.1844803, 0.99874926, -1.9475793], | |||
| [0.4987858, 0.5307887, -0.04226681], | |||
| [0.4529779, -1.1960793, 0.9456575], | |||
| [3.133675, 0.2309789, -0.29201075], | |||
| [-0.59632736, -0.0789804, -0.69486314]]], | |||
| [[[-0.5606142, 0.6420862, 0.2478745], | |||
| [0.02717604, 1.5483379, -0.9373383], | |||
| [-1.1017276, -0.259478, 1.0311872], | |||
| [1.8387799, 0.16468556, 0.33392152], | |||
| [-1.8781787, 1.0158662, 1.6527579]], | |||
| [[0.45696944, -0.5652523, -1.5618048], | |||
| [-0.30304828, 0.1331878, -0.36955845], | |||
| [0.91655576, 0.66612357, 0.3068175], | |||
| [-0.45732066, 0.8923335, 1.0542952], | |||
| [-0.73519516, 1.0518405, -1.0273266]], | |||
| [[-0.79712886, -0.26814285, 0.12779616], | |||
| [1.0367643, -1.6180774, 0.42999932], | |||
| [-0.81818223, -0.81502074, 0.882194], | |||
| [0.53640485, 0.4178927, 1.6037121], | |||
| [0.9256354, -1.1006796, 0.16614541]]], | |||
| [[[-1.5216796, -1.2473261, 0.6549515], | |||
| [0.63627815, 0.7221449, 0.02977821], | |||
| [-0.61331123, -0.49451825, 0.33852202], | |||
| [1.4510741, -1.3818305, -0.791747], | |||
| [0.6989747, 0.49558765, 1.0813237]], | |||
| [[-0.03969796, 0.71586496, 0.8326594], | |||
| [-0.15443641, 1.0389746, -0.59301984], | |||
| [0.7197836, 0.03257621, 1.8398637], | |||
| [0.6111736, -0.16166899, -2.4869773], | |||
| [1.3066711, -1.8003578, 0.17412892]], | |||
| [[-0.31470737, -0.5938182, -1.1311078], | |||
| [-0.99081016, 0.4005125, 0.44154453], | |||
| [1.0876914, -2.5958562, -0.5914863], | |||
| [1.3759689, -0.7741513, 0.19928917], | |||
| [1.6792973, 2.2744863, -0.04308867]]]]).astype(np.float32) | |||
| x = np.array([[[[-1.4311737, 1.015344], | |||
| [0.04431088, -2.2886624], | |||
| [1.4832113, 1.240908], | |||
| [0.67040104, 0.15266363]], | |||
| [[0.44226435, 1.1461105], | |||
| [1.194218, 1.5547837], | |||
| [0.23152256, 1.5911953], | |||
| [0.11206784, 0.17978816]], | |||
| [[-0.57803905, 0.8039611], | |||
| [0.0823025, -0.6134477], | |||
| [-1.4171146, 1.6269946], | |||
| [0.48878875, 0.9117505]]]]).astype(np.float32) | |||
| conv2d = NetConv(weight, x) | |||
| output = conv2d() | |||
| expected = np.array([[[[2.3498724], | |||
| [-1.9199573]], | |||
| [[5.376562], | |||
| [-5.425745]], | |||
| [[5.9105043], | |||
| [7.469034]]]]).astype(np.float32) | |||
| loss = np.abs(expected - output.asnumpy()) | |||
| error = 1e-4 * np.ones(loss.shape) | |||
| assert (loss < error).all() | |||
| class NetConv2dDynamic(nn.Cell): | |||
| def __init__(self, axis=0, out_nums=1): | |||
| super(NetConv2dDynamic, self).__init__() | |||
| self.dynshape = inner.GpuConvertToDynamicShape() | |||
| out_channel = 2 | |||
| kernel_size = 1 | |||
| self.conv = P.Conv2D(out_channel, | |||
| kernel_size, | |||
| mode=1, | |||
| pad_mode="valid", | |||
| pad=0, | |||
| stride=1, | |||
| dilation=1, | |||
| group=1) | |||
| def construct(self, x, w): | |||
| x_dyn = self.dynshape(x) | |||
| w_dyn = self.dynshape(w) | |||
| x_conv = self.conv(x_dyn, w_dyn) | |||
| return x_conv | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_conv2d_dynamic(): | |||
| x1 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)) | |||
| w1 = Tensor(np.arange(2 * 3 * 1 * 1).reshape(2, 3, 1, 1).astype(np.float32)) | |||
| expect1 = np.array([[[[45, 48, 51], | |||
| [54, 57, 60], | |||
| [63, 66, 69]], | |||
| [[126, 138, 150], | |||
| [162, 174, 186], | |||
| [198, 210, 222]]]]).astype(np.float32) | |||
| x2 = Tensor(np.arange(5 * 1 * 2 * 2).reshape(5, 1, 2, 2).astype(np.float32)) | |||
| w2 = Tensor(np.arange(2 * 1 * 1 * 1).reshape(2, 1, 1, 1).astype(np.float32)) | |||
| expect2 = np.array([[[[0., 0.], | |||
| [0., 0.]], | |||
| [[0., 1.], | |||
| [2., 3.]]], | |||
| [[[0., 0.], | |||
| [0., 0.]], | |||
| [[4., 5.], | |||
| [6., 7.]]], | |||
| [[[0., 0.], | |||
| [0., 0.]], | |||
| [[8., 9.], | |||
| [10., 11.]]], | |||
| [[[0., 0.], | |||
| [0., 0.]], | |||
| [[12., 13.], | |||
| [14., 15.]]], | |||
| [[[0., 0.], | |||
| [0., 0.]], | |||
| [[16., 17.], | |||
| [18., 19.]]]]).astype(np.float32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| conv2d = NetConv2dDynamic() | |||
| output1 = conv2d(x1, w1) | |||
| assert (output1.asnumpy() == expect1).all() | |||
| output2 = conv2d(x2, w2) | |||
| assert (output2.asnumpy() == expect2).all() | |||