maskrcnn parameter dtype

tags/v1.2.0-rc1
ttudu, 4 years ago · commit dad3172abb
8 changed files with 78 additions and 62 deletions
  1. model_zoo/official/cv/maskrcnn/src/maskrcnn/fpn_neck.py (+6, -4)
  2. model_zoo/official/cv/maskrcnn/src/maskrcnn/rcnn_cls.py (+10, -10)
  3. model_zoo/official/cv/maskrcnn/src/maskrcnn/rcnn_mask.py (+16, -10)
  4. model_zoo/official/cv/maskrcnn/src/maskrcnn/rpn.py (+7, -7)
  5. model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/fpn_neck.py (+6, -4)
  6. model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_cls.py (+10, -10)
  7. model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_mask.py (+16, -10)
  8. model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rpn.py (+7, -7)

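The change is the same in all eight files: every weight and bias is now created as a float32 Parameter, and half-precision execution is restored by calling Cell.to_float(mstype.float16) on the enclosing cell, so the cast happens at compute time instead of being baked into the stored weights (the usual mixed-precision arrangement of full-precision master weights). A minimal standalone sketch of the pattern, not the model_zoo code itself; channel sizes are illustrative:

```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.common.initializer import initializer

def _conv(in_channels, out_channels, kernel_size=3):
    """Conv2D wrapper: parameters are created in float32, as in this commit."""
    shape = (out_channels, in_channels, kernel_size, kernel_size)
    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32)
    bias = Tensor(np.zeros((out_channels,), np.float32))
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                     weight_init=weights, has_bias=True, bias_init=bias)

# to_float(float16) inserts casts at execution time: the forward pass runs in
# half precision while the stored Parameters stay float32 for optimizer
# updates and checkpoints.
conv = _conv(256, 256).to_float(mstype.float16)
```
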
model_zoo/official/cv/maskrcnn/src/maskrcnn/fpn_neck.py (+6, -4)

```diff
@@ -24,12 +24,12 @@ from mindspore.common.initializer import initializer
 
 def bias_init_zeros(shape):
     """Bias init method."""
-    return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16))
+    return Tensor(np.array(np.zeros(shape).astype(np.float32)), dtype=mstype.float32)
 
 def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
     """Conv2D wrapper."""
     shape = (out_channels, in_channels, kernel_size, kernel_size)
-    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16)
+    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32)
     shape_bias = (out_channels,)
     biass = bias_init_zeros(shape_bias)
     return nn.Conv2d(in_channels, out_channels,
@@ -76,8 +76,10 @@ class FeatPyramidNeck(nn.Cell):
         self.fpn_convs_ = []
 
         for _, channel in enumerate(in_channels):
-            l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
-            fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
+            l_conv = _conv(channel, out_channels, kernel_size=1, stride=1,
+                           padding=0, pad_mode='valid').to_float(mstype.float16)
+            fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1,
+                             padding=0, pad_mode='same').to_float(mstype.float16)
             self.lateral_convs_list_.append(l_conv)
             self.fpn_convs_.append(fpn_conv)
         self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
```

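Note that bias_init_zeros no longer round-trips the zeros through float16; both wrappers now hand nn.Conv2d float32 initializers, and the float16 cast is applied per cell inside FeatPyramidNeck. A quick dtype check under this scheme, a hedged sketch with made-up shapes rather than code from the repo:

```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

# A lateral 1x1 conv like l_conv: float32 weights, per-cell float16 compute.
l_conv = nn.Conv2d(256, 256, 1, pad_mode='valid', weight_init="XavierUniform",
                   has_bias=True).to_float(mstype.float16)
print(l_conv.weight.dtype)              # Float32: master weight keeps precision
x = Tensor(np.ones((1, 256, 32, 32), np.float16))
print(l_conv(x).dtype)                  # Float16: the computation is cast down
```
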

model_zoo/official/cv/maskrcnn/src/maskrcnn/rcnn_cls.py (+10, -10)

```diff
@@ -26,8 +26,8 @@ class DenseNoTranpose(nn.Cell):
     """Dense method"""
     def __init__(self, input_channels, output_channels, weight_init):
         super(DenseNoTranpose, self).__init__()
-        self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16))
-        self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16))
+        self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float32))
+        self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32))
         self.matmul = P.MatMul(transpose_b=False)
         self.bias_add = P.BiasAdd()
 
@@ -41,18 +41,18 @@ class FpnCls(nn.Cell):
         super(FpnCls, self).__init__()
         representation_size = input_channels * pool_size * pool_size
         shape_0 = (output_channels, representation_size)
-        weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16)
+        weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float32)
         shape_1 = (output_channels, output_channels)
-        weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16)
-        self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0)
-        self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1)
+        weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float32)
+        self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0).to_float(mstype.float16)
+        self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1).to_float(mstype.float16)
 
         cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1],
-                                 dtype=mstype.float16)
+                                 dtype=mstype.float32)
         reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1],
-                                 dtype=mstype.float16)
-        self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight)
-        self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight)
+                                 dtype=mstype.float32)
+        self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight).to_float(mstype.float16)
+        self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight).to_float(mstype.float16)
 
         self.relu = P.ReLU()
         self.flatten = P.Flatten()
```

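The diff shows only DenseNoTranpose.__init__; for context, here is a sketch of the whole cell under the new scheme. The construct body is assumed from the surrounding code (matmul, then bias_add), and the channel sizes are illustrative:

```python
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P

class DenseNoTranpose(nn.Cell):
    """Dense layer keeping its weight untransposed, with float32 Parameters."""
    def __init__(self, input_channels, output_channels, weight_init):
        super(DenseNoTranpose, self).__init__()
        self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float32))
        self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32))
        self.matmul = P.MatMul(transpose_b=False)
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        # assumed from context: x @ weight, then bias add
        return self.bias_add(self.matmul(x, self.weight), self.bias)

# per-cell half-precision compute, exactly as the heads do after this commit
fc = DenseNoTranpose(12544, 1024, "XavierUniform").to_float(mstype.float16)
```
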
model_zoo/official/cv/maskrcnn/src/maskrcnn/rcnn_mask.py (+16, -10)

```diff
@@ -24,9 +24,9 @@ from mindspore.common.initializer import initializer
 def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
     """Conv2D wrapper."""
     shape = (out_channels, in_channels, kernel_size, kernel_size)
-    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16)
+    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32)
     shape_bias = (out_channels,)
-    bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16))
+    bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float32))
     return nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias)
@@ -34,9 +34,9 @@ def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mod
 def _convTanspose(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
     """ConvTranspose wrapper."""
     shape = (out_channels, in_channels, kernel_size, kernel_size)
-    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16)
+    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32)
     shape_bias = (out_channels,)
-    bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16))
+    bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float32))
     return nn.Conv2dTranspose(in_channels, out_channels,
                               kernel_size=kernel_size, stride=stride, padding=padding,
                               pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias)
@@ -45,21 +45,27 @@ class FpnMask(nn.Cell):
     """conv layers of mask head"""
     def __init__(self, input_channels, output_channels, num_classes):
         super(FpnMask, self).__init__()
-        self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3, pad_mode="same")
+        self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3,
+                                pad_mode="same").to_float(mstype.float16)
         self.mask_relu1 = P.ReLU()
 
-        self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
+        self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3,
+                                pad_mode="same").to_float(mstype.float16)
         self.mask_relu2 = P.ReLU()
 
-        self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
+        self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3,
+                                pad_mode="same").to_float(mstype.float16)
         self.mask_relu3 = P.ReLU()
 
-        self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
+        self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3,
+                                pad_mode="same").to_float(mstype.float16)
         self.mask_relu4 = P.ReLU()
 
-        self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2, stride=2, pad_mode="valid")
+        self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2,
+                                          stride=2, pad_mode="valid").to_float(mstype.float16)
         self.mask_relu5 = P.ReLU()
-        self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1, pad_mode="valid")
+        self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1,
+                                pad_mode="valid").to_float(mstype.float16)
 
     def construct(self, x):
         x = self.mask_conv1(x)
```

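The mask head gets the same treatment, including its transposed convolution (_convTanspose, spelling as in the source). A standalone sketch of the deconv wrapper after the change; shapes are illustrative, and the 14x14 to 28x28 upsampling follows from a stride-2, kernel-2, valid-padding deconv:

```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.common.initializer import initializer

out_channels, in_channels = 256, 256
weights = initializer("XavierUniform",
                      shape=(out_channels, in_channels, 2, 2),
                      dtype=mstype.float32)
bias = Tensor(np.zeros((out_channels,), np.float32))
deconv = nn.Conv2dTranspose(in_channels, out_channels, kernel_size=2, stride=2,
                            pad_mode="valid", weight_init=weights,
                            has_bias=True, bias_init=bias).to_float(mstype.float16)
# 14x14 RoI features upsample to 28x28, computed in float16
y = deconv(Tensor(np.ones((1, 256, 14, 14), np.float16)))
print(y.shape, y.dtype)   # (1, 256, 28, 28) Float16
```
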

model_zoo/official/cv/maskrcnn/src/maskrcnn/rpn.py (+7, -7)

```diff
@@ -164,23 +164,23 @@ class RPN(nn.Cell):
 
         shp_weight_conv = (feat_channels, in_channels, 3, 3)
         shp_bias_conv = (feat_channels,)
-        weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16)
-        bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16)
+        weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float32)
+        bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float32)
 
         shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1)
         shp_bias_cls = (num_anchors * cls_out_channels,)
-        weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16)
-        bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16)
+        weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float32)
+        bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float32)
 
         shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1)
         shp_bias_reg = (num_anchors * 4,)
-        weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16)
-        bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16)
+        weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float32)
+        bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float32)
 
         for i in range(num_layers):
             rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
                                             weight_conv, bias_conv, weight_cls, \
-                                            bias_cls, weight_reg, bias_reg))
+                                            bias_cls, weight_reg, bias_reg).to_float(mstype.float16))
 
         for i in range(1, num_layers):
             rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
```

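Because the RpnRegClsBlock parameters are now float32 and only the cells are cast, the weight sharing in the second loop still ties every pyramid level to a single master parameter. A hypothetical two-block illustration, with nn.Conv2d standing in for the real block:

```python
import mindspore.nn as nn
import mindspore.common.dtype as mstype

# Two levels, each computing in float16 after to_float.
blocks = [nn.Conv2d(256, 256, 3, has_bias=True).to_float(mstype.float16)
          for _ in range(2)]
# Reassigning the attribute makes both cells hold the same float32 Parameter,
# as rpn.py does with rpn_conv.weight, so one master weight serves all levels.
blocks[1].weight = blocks[0].weight
assert blocks[1].weight is blocks[0].weight
```
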

model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/fpn_neck.py (+6, -4)

```diff
@@ -24,12 +24,12 @@ from mindspore.common.initializer import initializer
 
 def bias_init_zeros(shape):
     """Bias init method."""
-    return Tensor(np.array(np.zeros(shape).astype(np.float32)).astype(np.float16))
+    return Tensor(np.array(np.zeros(shape).astype(np.float32)), dtype=mstype.float32)
 
 def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
     """Conv2D wrapper."""
     shape = (out_channels, in_channels, kernel_size, kernel_size)
-    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16).to_tensor()
+    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32).to_tensor()
     shape_bias = (out_channels,)
     biass = bias_init_zeros(shape_bias)
     return nn.Conv2d(in_channels, out_channels,
@@ -76,8 +76,10 @@ class FeatPyramidNeck(nn.Cell):
         self.fpn_convs_ = []
 
         for _, channel in enumerate(in_channels):
-            l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='valid')
-            fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same')
+            l_conv = _conv(channel, out_channels, kernel_size=1, stride=1, padding=0,
+                           pad_mode='valid').to_float(mstype.float16)
+            fpn_conv = _conv(out_channels, out_channels, kernel_size=3, stride=1, padding=0,
+                             pad_mode='same').to_float(mstype.float16)
             self.lateral_convs_list_.append(l_conv)
             self.fpn_convs_.append(fpn_conv)
         self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
```


model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_cls.py (+10, -10)

```diff
@@ -26,9 +26,9 @@ class DenseNoTranpose(nn.Cell):
     """Dense method"""
     def __init__(self, input_channels, output_channels, weight_init):
         super(DenseNoTranpose, self).__init__()
-        self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float16),
+        self.weight = Parameter(initializer(weight_init, [input_channels, output_channels], mstype.float32),
                                 name="weight")
-        self.bias = Parameter(initializer("zeros", [output_channels], mstype.float16), name="bias")
+        self.bias = Parameter(initializer("zeros", [output_channels], mstype.float32), name="bias")
         self.matmul = P.MatMul(transpose_b=False)
         self.bias_add = P.BiasAdd()
 
@@ -42,18 +42,18 @@ class FpnCls(nn.Cell):
         super(FpnCls, self).__init__()
         representation_size = input_channels * pool_size * pool_size
         shape_0 = (output_channels, representation_size)
-        weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float16)
+        weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float32)
         shape_1 = (output_channels, output_channels)
-        weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float16)
-        self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0)
-        self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1)
+        weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float32)
+        self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0).to_float(mstype.float16)
+        self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1).to_float(mstype.float16)
 
         cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1],
-                                 dtype=mstype.float16)
+                                 dtype=mstype.float32)
         reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1],
-                                 dtype=mstype.float16)
-        self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight)
-        self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight)
+                                 dtype=mstype.float32)
+        self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight).to_float(mstype.float16)
+        self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight).to_float(mstype.float16)
 
         self.relu = P.ReLU()
         self.flatten = P.Flatten()
```


model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_mask.py (+16, -10)

```diff
@@ -24,9 +24,9 @@ from mindspore.common.initializer import initializer
 def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
     """Conv2D wrapper."""
     shape = (out_channels, in_channels, kernel_size, kernel_size)
-    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16)
+    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32)
     shape_bias = (out_channels,)
-    bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16))
+    bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float32))
     return nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias)
@@ -34,9 +34,9 @@ def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mod
 def _convTanspose(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
     """ConvTranspose wrapper."""
     shape = (out_channels, in_channels, kernel_size, kernel_size)
-    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float16)
+    weights = initializer("XavierUniform", shape=shape, dtype=mstype.float32)
     shape_bias = (out_channels,)
-    bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float16))
+    bias = Tensor(np.array(np.zeros(shape_bias)).astype(np.float32))
     return nn.Conv2dTranspose(in_channels, out_channels,
                               kernel_size=kernel_size, stride=stride, padding=padding,
                               pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=bias)
@@ -45,21 +45,27 @@ class FpnMask(nn.Cell):
     """conv layers of mask head"""
     def __init__(self, input_channels, output_channels, num_classes):
         super(FpnMask, self).__init__()
-        self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3, pad_mode="same")
+        self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3,
+                                pad_mode="same").to_float(mstype.float16)
         self.mask_relu1 = P.ReLU()
 
-        self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
+        self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3,
+                                pad_mode="same").to_float(mstype.float16)
         self.mask_relu2 = P.ReLU()
 
-        self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
+        self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3,
+                                pad_mode="same").to_float(mstype.float16)
         self.mask_relu3 = P.ReLU()
 
-        self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3, pad_mode="same")
+        self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3,
+                                pad_mode="same").to_float(mstype.float16)
         self.mask_relu4 = P.ReLU()
 
-        self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2, stride=2, pad_mode="valid")
+        self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2,
+                                          stride=2, pad_mode="valid").to_float(mstype.float16)
         self.mask_relu5 = P.ReLU()
-        self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1, pad_mode="valid")
+        self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1,
+                                pad_mode="valid").to_float(mstype.float16)
 
     def construct(self, x):
         x = self.mask_conv1(x)
```


model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rpn.py (+7, -7)

```diff
@@ -164,23 +164,23 @@ class RPN(nn.Cell):
 
         shp_weight_conv = (feat_channels, in_channels, 3, 3)
         shp_bias_conv = (feat_channels,)
-        weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float16)
-        bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float16)
+        weight_conv = initializer('Normal', shape=shp_weight_conv, dtype=mstype.float32)
+        bias_conv = initializer(0, shape=shp_bias_conv, dtype=mstype.float32)
 
         shp_weight_cls = (num_anchors * cls_out_channels, feat_channels, 1, 1)
         shp_bias_cls = (num_anchors * cls_out_channels,)
-        weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float16)
-        bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float16)
+        weight_cls = initializer('Normal', shape=shp_weight_cls, dtype=mstype.float32)
+        bias_cls = initializer(0, shape=shp_bias_cls, dtype=mstype.float32)
 
         shp_weight_reg = (num_anchors * 4, feat_channels, 1, 1)
         shp_bias_reg = (num_anchors * 4,)
-        weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float16)
-        bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float16)
+        weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float32)
+        bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float32)
 
         for i in range(num_layers):
             rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
                                             weight_conv, bias_conv, weight_cls, \
-                                            bias_cls, weight_reg, bias_reg))
+                                            bias_cls, weight_reg, bias_reg).to_float(mstype.float16))
 
         for i in range(1, num_layers):
             rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
```

