You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cus_conv2d.py 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import math
  17. from functools import reduce
  18. from mindspore.ops import prim_attr_register, PrimitiveWithInfer
  19. from mindspore import Tensor
  20. from mindspore._checkparam import ParamValidator as validator
  21. from mindspore._checkparam import Rel, check_bool, check_int_positive, twice
  22. from mindspore.common import dtype as mstype
  23. class Cus_Conv2D(PrimitiveWithInfer):
  24. r"""
  25. Applies 2D convolution for the input.
  26. Input is typically of shape :math:`(N, C, H, W)`, where :math:`N` is batch size and :math:`C` is channel number.
  27. For each batch of shape :math:`(C, H, W)` the formula (given mode 1) is defined as:
  28. .. math::
  29. out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
  30. where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
  31. from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to i-th channel of the j-th filter and
  32. :math:`out_{j}` corresponds to the j-th channel of the output.
  33. The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
  34. <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
  35. More detailed introduction can be found here: http://cs231n.github.io/convolutional-networks/.
  36. Args:
  37. out_channel (int): The dimensionality of the output space.
  38. kernel_size (Union[int, tuple[int]]): The kernel size of the 2D convolution.
  39. mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
  40. 2 deconvolution, 3 depthwise convolution. Default: 1.
  41. pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
  42. pad (int): The pad value to fill. Default: 0.
  43. stride (int): The stride to apply conv filter. Default: 1.
  44. dilation (int): Specifying the dilation rate to use for dilated convolution. Default: 1.
  45. group (int): Split input into groups. Default: 1.
  46. Returns:
  47. Tensor, the value that applied 2D convolution.
  48. Inputs:
  49. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  50. Outputs:
  51. Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  52. """
  53. @prim_attr_register
  54. def __init__(self,
  55. out_channel,
  56. kernel_size,
  57. mode=1,
  58. pad_mode="valid",
  59. pad=0,
  60. stride=1,
  61. dilation=1,
  62. group=1):
  63. """init Conv2D"""
  64. self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
  65. self.kernel_size = kernel_size
  66. self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple))
  67. if isinstance(self.kernel_size, int):
  68. self.kernel_size = (self.kernel_size, self.kernel_size)
  69. validator.check_integer('length of kernel_size', len(self.kernel_size), 2, Rel.GE)
  70. validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool))
  71. validator.equal('type of pad', type(pad), 'int', isinstance(pad, int))
  72. self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'])
  73. self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad)
  74. if self.pad_mode == 'pad':
  75. validator.check_integer('pad', self.pad, 0, Rel.GE)
  76. self.mode = validator.check_integer('mode', mode, 1, Rel.EQ)
  77. self.add_prim_attr('data_format', "NCHW")
  78. self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT)
  79. self.group = validator.check_integer('group', group, 0, Rel.GT)
  80. self.dilation = validator.check_integer('dilation', dilation, 1, Rel.GE)
  81. validator.check_type('kernel_size', kernel_size, [int, tuple])
  82. if isinstance(kernel_size, int) and kernel_size < 1:
  83. raise ValueError('Attr \'kernel_size\' of \'Conv2D\' Op passed '
  84. + str(self.kernel_size)+', should be a int or tuple and equal to or greater than 1.')
  85. if isinstance(kernel_size, tuple) and (len(kernel_size) != 2 or
  86. (not isinstance(kernel_size[0], int)) or
  87. (not isinstance(kernel_size[1], int)) or
  88. kernel_size[0] < 1 or kernel_size[1] < 1):
  89. raise ValueError('Attr \'kernel_size\' of \'Conv2D\' Op passed '
  90. + str(self.kernel_size)+', should be a int or tuple and equal to or greater than 1.')
  91. self.stride = validator.check_integer('stride', stride, 1, Rel.GE)
  92. from .cus_conv2d_impl import Cus_Conv2D
  93. def infer_shape(self, x_shape, w_shape):
  94. validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ)
  95. validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ)
  96. validator.check_param_equal("x_shape[1]", x_shape[1] // self.group, "w_shape[1]", w_shape[1])
  97. validator.check_param_equal('out_channel', self.out_channel, 'w_shape[0]', w_shape[0])
  98. validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]))
  99. kernel_size_h = w_shape[2]
  100. kernel_size_w = w_shape[3]
  101. if self.pad_mode == "valid":
  102. h_out = math.ceil((x_shape[2] - kernel_size_h + 1) / self.stride)
  103. w_out = math.ceil((x_shape[3] - kernel_size_w + 1) / self.stride)
  104. pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
  105. elif self.pad_mode == "same":
  106. h_out = math.ceil(x_shape[2] / self.stride)
  107. w_out = math.ceil(x_shape[3] / self.stride)
  108. pad_needed_h = max(0, (h_out - 1) * self.stride + kernel_size_h - x_shape[2])
  109. pad_top = math.floor(pad_needed_h / 2)
  110. pad_bottom = pad_needed_h - pad_top
  111. pad_needed_w = max(0, (w_out - 1) * self.stride + kernel_size_w - x_shape[3])
  112. pad_left = math.floor(pad_needed_w / 2)
  113. pad_right = pad_needed_w - pad_left
  114. elif self.pad_mode == 'pad':
  115. pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
  116. h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (self.dilation - 1)) \
  117. / self.stride
  118. w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (self.dilation - 1)) \
  119. / self.stride
  120. h_out = math.floor(h_out)
  121. w_out = math.floor(w_out)
  122. self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
  123. self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
  124. out_channel = self.out_channel
  125. out_shape = [x_shape[0], out_channel, h_out, w_out]
  126. return out_shape
  127. def infer_dtype(self, x_dtype, w_dtype):
  128. args = {'x_dtype': x_dtype, 'w_dtype': w_dtype}
  129. validator.check_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32])
  130. return x_dtype