|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import mindspore.nn as nn
- from mindspore import Parameter
- from mindspore import dtype as mstype
- from mindspore.ops import operations as P
- from mindspore.ops.operations import nn_ops as nps
- from mindspore.common.initializer import initializer
-
def weight_variable(shape):
    """Build a float32 ``Parameter`` of the given shape using the 'Normal' initializer.

    Args:
        shape: Shape passed straight through to ``initializer``.

    Returns:
        A ``Parameter`` wrapping the initialized value.
    """
    return Parameter(initializer('Normal', shape, mstype.float32))
-
class Conv3D(nn.Cell):
    """Cell wrapping the ``nps.Conv3D`` primitive with its own weight and optional bias.

    The weight is created by ``weight_variable`` and the bias (when enabled)
    is added through ``P.BiasAdd`` after the convolution.

    Args:
        in_channel: Number of input channels; used to size the weight.
        out_channel: Number of output channels; forwarded to ``nps.Conv3D``
            and used to size the weight and the bias.
        kernel_size: Either a single int (applied to all three kernel
            dimensions) or a sequence of three ints.
        mode, pad_mode, pad, stride, dilation, group, data_format:
            Forwarded unchanged to ``nps.Conv3D``.
        bias_init: Initializer spec for the bias, passed to ``initializer``.
        has_bias: When true, a bias of shape ``[out_channel]`` is created and
            added to the convolution output.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 mode=1,
                 pad_mode="valid",
                 pad=0,
                 stride=1,
                 dilation=1,
                 group=1,
                 data_format="NCDHW",
                 bias_init="zeros",
                 has_bias=True):
        super().__init__()
        # The primitive accepts a bare int for kernel_size; expand it here so
        # the weight shape can be built without indexing into an int.
        if isinstance(kernel_size, int):
            kernel_dims = (kernel_size,) * 3
        else:
            kernel_dims = (kernel_size[0], kernel_size[1], kernel_size[2])
        # Per-group input channels: identical to in_channel for the default
        # group=1, and matches the primitive's documented weight layout
        # (C_out, C_in / group, kD, kH, kW) otherwise.
        self.weight_shape = (out_channel, in_channel // group) + kernel_dims
        self.weight = weight_variable(self.weight_shape)
        self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode,
                               pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation,
                               group=group, data_format=data_format)
        self.bias_init = bias_init
        self.has_bias = has_bias
        self.bias_add = P.BiasAdd(data_format=data_format)
        if self.has_bias:
            self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')

    def construct(self, x):
        """Apply the convolution to ``x`` and, if enabled, add the bias."""
        output = self.conv(x, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        return output
-
class Conv3DTranspose(nn.Cell):
    """Cell wrapping the ``nps.Conv3DTranspose`` primitive with weight and optional bias.

    The weight is created by ``weight_variable`` and the bias (when enabled)
    is added through ``P.BiasAdd`` after the transposed convolution.

    Args:
        in_channel: Number of input channels; forwarded to the primitive and
            used to size the weight.
        out_channel: Number of output channels; forwarded to the primitive
            and used to size the weight and the bias.
        kernel_size: Either a single int (applied to all three kernel
            dimensions) or a sequence of three ints.
        mode, pad, stride, dilation, group, output_padding, data_format:
            Forwarded unchanged to ``nps.Conv3DTranspose``.
        bias_init: Initializer spec for the bias, passed to ``initializer``.
        has_bias: When true, a bias of shape ``[out_channel]`` is created and
            added to the output.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 mode=1,
                 pad=0,
                 stride=1,
                 dilation=1,
                 group=1,
                 output_padding=0,
                 data_format="NCDHW",
                 bias_init="zeros",
                 has_bias=True):
        super().__init__()
        # The primitive accepts a bare int for kernel_size; expand it here so
        # the weight shape can be built without indexing into an int.
        if isinstance(kernel_size, int):
            kernel_dims = (kernel_size,) * 3
        else:
            kernel_dims = (kernel_size[0], kernel_size[1], kernel_size[2])
        # Per-group output channels: identical to out_channel for the default
        # group=1, and matches the primitive's documented weight layout
        # (C_in, C_out // group, kD, kH, kW) otherwise.
        self.weight_shape = (in_channel, out_channel // group) + kernel_dims
        self.weight = weight_variable(self.weight_shape)
        self.conv_transpose = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel,
                                                  kernel_size=kernel_size, mode=mode, pad=pad,
                                                  stride=stride, dilation=dilation, group=group,
                                                  output_padding=output_padding,
                                                  data_format=data_format)
        self.bias_init = bias_init
        self.has_bias = has_bias
        self.bias_add = P.BiasAdd(data_format=data_format)
        if self.has_bias:
            self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')

    def construct(self, x):
        """Apply the transposed convolution to ``x`` and, if enabled, add the bias."""
        output = self.conv_transpose(x, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        return output
|