You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv.py 3.8 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import mindspore.nn as nn
  16. from mindspore import Parameter
  17. from mindspore import dtype as mstype
  18. from mindspore.ops import operations as P
  19. from mindspore.ops.operations import nn_ops as nps
  20. from mindspore.common.initializer import initializer
  21. def weight_variable(shape):
  22. init_value = initializer('Normal', shape, mstype.float32)
  23. return Parameter(init_value)
  24. class Conv3D(nn.Cell):
  25. def __init__(self,
  26. in_channel,
  27. out_channel,
  28. kernel_size,
  29. mode=1,
  30. pad_mode="valid",
  31. pad=0,
  32. stride=1,
  33. dilation=1,
  34. group=1,
  35. data_format="NCDHW",
  36. bias_init="zeros",
  37. has_bias=True):
  38. super().__init__()
  39. self.weight_shape = (out_channel, in_channel, kernel_size[0], kernel_size[1], kernel_size[2])
  40. self.weight = weight_variable(self.weight_shape)
  41. self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, \
  42. pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation, \
  43. group=group, data_format=data_format)
  44. self.bias_init = bias_init
  45. self.has_bias = has_bias
  46. self.bias_add = P.BiasAdd(data_format=data_format)
  47. if self.has_bias:
  48. self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
  49. def construct(self, x):
  50. output = self.conv(x, self.weight)
  51. if self.has_bias:
  52. output = self.bias_add(output, self.bias)
  53. return output
  54. class Conv3DTranspose(nn.Cell):
  55. def __init__(self,
  56. in_channel,
  57. out_channel,
  58. kernel_size,
  59. mode=1,
  60. pad=0,
  61. stride=1,
  62. dilation=1,
  63. group=1,
  64. output_padding=0,
  65. data_format="NCDHW",
  66. bias_init="zeros",
  67. has_bias=True):
  68. super().__init__()
  69. self.weight_shape = (in_channel, out_channel, kernel_size[0], kernel_size[1], kernel_size[2])
  70. self.weight = weight_variable(self.weight_shape)
  71. self.conv_transpose = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel,\
  72. kernel_size=kernel_size, mode=mode, pad=pad, stride=stride, \
  73. dilation=dilation, group=group, output_padding=output_padding, \
  74. data_format=data_format)
  75. self.bias_init = bias_init
  76. self.has_bias = has_bias
  77. self.bias_add = P.BiasAdd(data_format=data_format)
  78. if self.has_bias:
  79. self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
  80. def construct(self, x):
  81. output = self.conv_transpose(x, self.weight)
  82. if self.has_bias:
  83. output = self.bias_add(output, self.bias)
  84. return output