From f4b3bafb70790a57e86c337238ebaeef253330d0 Mon Sep 17 00:00:00 2001 From: CaoJian Date: Thu, 25 Feb 2021 23:17:49 +0800 Subject: [PATCH] add ssd vgg backbone support --- model_zoo/official/cv/ssd/README.md | 3 +- model_zoo/official/cv/ssd/eval.py | 4 +- model_zoo/official/cv/ssd/export.py | 4 +- model_zoo/official/cv/ssd/src/config.py | 4 +- .../official/cv/ssd/src/config_ssd_vgg16.py | 84 ++++++++++++++++ model_zoo/official/cv/ssd/src/ssd.py | 76 ++++++++++++++ model_zoo/official/cv/ssd/src/vgg16.py | 99 +++++++++++++++++++ model_zoo/official/cv/ssd/train.py | 15 ++- 8 files changed, 284 insertions(+), 5 deletions(-) create mode 100644 model_zoo/official/cv/ssd/src/config_ssd_vgg16.py create mode 100644 model_zoo/official/cv/ssd/src/vgg16.py diff --git a/model_zoo/official/cv/ssd/README.md b/model_zoo/official/cv/ssd/README.md index aa0a7f2ed0..6b9f548000 100644 --- a/model_zoo/official/cv/ssd/README.md +++ b/model_zoo/official/cv/ssd/README.md @@ -35,11 +35,12 @@ SSD discretizes the output space of bounding boxes into a set of default boxes o The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high quality image classification, which is called the base network. Then add auxiliary structure to the network to produce detections. -We present three different base architecture. +We present four different base architecture. - **ssd300**, reference from the paper. Using mobilenetv2 as backbone and the same bbox predictor as the paper present. - ***ssd-mobilenet-v1-fpn**, using mobilenet-v1 and FPN as feature extractor with weight-shared box predcitors. - ***ssd-resnet50-fpn**, using resnet50 and FPN as feature extractor with weight-shared box predcitors. +- **ssd-vgg16**, reference from the paper. Using vgg16 as backbone and the same bbox predictor as the paper present. ## [Dataset](#contents) diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py index 459162f173..ca4bea9cee 100644 --- a/model_zoo/official/cv/ssd/eval.py +++ b/model_zoo/official/cv/ssd/eval.py @@ -21,7 +21,7 @@ import time import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn +from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 from src.dataset import create_ssd_dataset, create_mindrecord from src.config import config from src.eval_utils import metrics @@ -34,6 +34,8 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): is_training=False, use_multiprocessing=False) if config.model == "ssd300": net = SSD300(ssd_mobilenet_v2(), config, is_training=False) + elif config.model == "ssd_vgg16": + net = ssd_vgg16(config=config) elif config.model == "ssd_mobilenet_v1_fpn": net = ssd_mobilenet_v1_fpn(config=config) elif config.model == "ssd_resnet50_fpn": diff --git a/model_zoo/official/cv/ssd/export.py b/model_zoo/official/cv/ssd/export.py index 83a23f9157..4aff4e553d 100644 --- a/model_zoo/official/cv/ssd/export.py +++ b/model_zoo/official/cv/ssd/export.py @@ -19,7 +19,7 @@ import numpy as np import mindspore from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net, export -from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn +from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 from src.config import config from src.box_utils import default_boxes @@ -40,6 +40,8 @@ if args.device_target == "Ascend": if __name__ == '__main__': if config.model == "ssd300": net = SSD300(ssd_mobilenet_v2(), config, is_training=False) + elif config.model == "ssd_vgg16": + net = ssd_vgg16(config=config) elif config.model == "ssd_mobilenet_v1_fpn": net = ssd_mobilenet_v1_fpn(config=config) elif config.model == "ssd_resnet50_fpn": diff --git a/model_zoo/official/cv/ssd/src/config.py b/model_zoo/official/cv/ssd/src/config.py index 3c0ae01d66..014edd0ffc 100644 --- a/model_zoo/official/cv/ssd/src/config.py +++ b/model_zoo/official/cv/ssd/src/config.py @@ -11,18 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -#" ============================================================================ +# ============================================================================ """Config parameters for SSD models.""" from .config_ssd300 import config as config_ssd300 from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn from .config_ssd_resnet50_fpn import config as config_ssd_resnet50_fpn +from .config_ssd_vgg16 import config as config_ssd_vgg16 using_model = "ssd300" config_map = { "ssd300": config_ssd300, + "ssd_vgg16": config_ssd_vgg16, "ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn, "ssd_resnet50_fpn": config_ssd_resnet50_fpn } diff --git a/model_zoo/official/cv/ssd/src/config_ssd_vgg16.py b/model_zoo/official/cv/ssd/src/config_ssd_vgg16.py new file mode 100644 index 0000000000..e0d2a005f5 --- /dev/null +++ b/model_zoo/official/cv/ssd/src/config_ssd_vgg16.py @@ -0,0 +1,84 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Config parameters for SSD models.""" + +from easydict import EasyDict as ed + +config = ed({ + "model": "ssd_vgg16", + "img_shape": [300, 300], + "num_ssd_boxes": 7308, + "match_threshold": 0.5, + "nms_threshold": 0.6, + "min_score": 0.1, + "max_boxes": 100, + "ssd_vgg_bn": False, + + # learing rate settings + "lr_init": 0.001, + "lr_end_rate": 0.001, + "warmup_epochs": 2, + "momentum": 0.9, + "weight_decay": 1.5e-4, + + # network + "num_default": [3, 6, 6, 6, 6, 6], + "extras_in_channels": [256, 512, 1024, 512, 256, 256], + "extras_out_channels": [512, 1024, 512, 256, 256, 256], + "extras_strides": [1, 1, 2, 2, 2, 2], + "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], + "feature_size": [38, 19, 10, 5, 3, 1], + "min_scale": 0.2, + "max_scale": 0.95, + "aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], + "steps": (8, 16, 32, 64, 100, 300), + "prior_scaling": (0.1, 0.2), + "gamma": 2.0, + "alpha": 0.75, + + # `mindrecord_dir` and `coco_root` are better to use absolute path. + "feature_extractor_base_param": "", + "pretrain_vgg_bn": False, + "checkpoint_filter_list": ['multi_loc_layers', 'multi_cls_layers'], + "mindrecord_dir": "/data/MindRecord_COCO", + "coco_root": "/data/coco2017", + "train_data_type": "train2017", + "val_data_type": "val2017", + "instances_set": "annotations/instances_{}.json", + "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), + "num_classes": 81, + # The annotation.json position of voc validation dataset. + "voc_json": "annotations/voc_instances_val.json", + # voc original dataset. + "voc_root": "/data/voc_dataset", + # if coco or voc used, `image_dir` and `anno_path` are useless. + "image_dir": "", + "anno_path": "" +}) diff --git a/model_zoo/official/cv/ssd/src/ssd.py b/model_zoo/official/cv/ssd/src/ssd.py index e4a889e939..f030f75dcb 100644 --- a/model_zoo/official/cv/ssd/src/ssd.py +++ b/model_zoo/official/cv/ssd/src/ssd.py @@ -27,6 +27,7 @@ from mindspore.ops import functional as F from mindspore.ops import composite as C from .fpn import mobilenet_v1_fpn, resnet50_fpn +from .vgg16 import vgg16 def _make_divisible(v, divisor, min_value=None): @@ -641,3 +642,78 @@ def ssd_resnet50_fpn(**kwargs): def ssd_mobilenet_v2(**kwargs): return SSDWithMobileNetV2(**kwargs) + + +class SSD300VGG16(nn.Cell): + def __init__(self, config): + super(SSD300VGG16, self).__init__() + + # VGG16 backbone: block1~5 + self.backbone = vgg16() + + # SSD blocks: block6~7 + self.b6_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6, pad_mode='pad') + self.b6_2 = nn.Dropout(0.5) + + self.b7_1 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1) + self.b7_2 = nn.Dropout(0.5) + + # Extra Feature Layers: block8~11 + self.b8_1 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, padding=1, pad_mode='pad') + self.b8_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, pad_mode='valid') + + self.b9_1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, padding=1, pad_mode='pad') + self.b9_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, pad_mode='valid') + + self.b10_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1) + self.b10_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid') + + self.b11_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1) + self.b11_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid') + + # boxes + self.multi_box = MultiBox(config) + if not self.training: + self.activation = P.Sigmoid() + + def construct(self, x): + # VGG16 backbone: block1~5 + block4, x = self.backbone(x) + + # SSD blocks: block6~7 + x = self.b6_1(x) # 1024 + x = self.b6_2(x) + + x = self.b7_1(x) # 1024 + x = self.b7_2(x) + block7 = x + + # Extra Feature Layers: block8~11 + x = self.b8_1(x) # 256 + x = self.b8_2(x) # 512 + block8 = x + + x = self.b9_1(x) # 128 + x = self.b9_2(x) # 256 + block9 = x + + x = self.b10_1(x) # 128 + x = self.b10_2(x) # 256 + block10 = x + + x = self.b11_1(x) # 128 + x = self.b11_2(x) # 256 + block11 = x + + # boxes + multi_feature = (block4, block7, block8, block9, block10, block11) + pred_loc, pred_label = self.multi_box(multi_feature) + if not self.training: + pred_label = self.activation(pred_label) + pred_loc = F.cast(pred_loc, mstype.float32) + pred_label = F.cast(pred_label, mstype.float32) + return pred_loc, pred_label + + +def ssd_vgg16(**kwargs): + return SSD300VGG16(**kwargs) diff --git a/model_zoo/official/cv/ssd/src/vgg16.py b/model_zoo/official/cv/ssd/src/vgg16.py new file mode 100644 index 0000000000..a5e25f1a82 --- /dev/null +++ b/model_zoo/official/cv/ssd/src/vgg16.py @@ -0,0 +1,99 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""VGG16 backbone for SSD""" + +from mindspore import nn +from .config_ssd_vgg16 import config + +pretrain_vgg_bn = config.pretrain_vgg_bn +ssd_vgg_bn = config.ssd_vgg_bn + + +def _get_key_mapper(): + vgg_key_num = [1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5] + size = len(vgg_key_num) + + pretrain_vgg_bn_false = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28] + pretrain_vgg_bn_true = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40] + ssd_vgg_bn_false = [0, 2, 0, 2, 0, 2, 4, 0, 2, 4, 0, 2, 4] + ssd_vgg_bn_true = [0, 3, 0, 3, 0, 3, 6, 0, 3, 6, 0, 3, 6] + + pretrain_vgg_keys = pretrain_vgg_bn_true if pretrain_vgg_bn else pretrain_vgg_bn_false + ssd_vgg_keys = ssd_vgg_bn_true if ssd_vgg_bn else ssd_vgg_bn_false + + pretrain_vgg_keys = ['layers.' + str(pretrain_vgg_keys[i]) for i in range(size)] + ssd_vgg_keys = ['b' + str(vgg_key_num[i]) + '.' + str(ssd_vgg_keys[i]) for i in range(size)] + + return {pretrain_vgg_keys[i]: ssd_vgg_keys[i] for i in range(size)} + + +ssd_vgg_key_mapper = _get_key_mapper() + + +def _make_layer(channels): + in_channels = channels[0] + layers = [] + for out_channels in channels[1:]: + layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)) + if ssd_vgg_bn: + layers.append(nn.BatchNorm2d(out_channels)) + layers.append(nn.ReLU()) + in_channels = out_channels + return nn.SequentialCell(layers) + + +class VGG16(nn.Cell): + def __init__(self): + super(VGG16, self).__init__() + self.b1 = _make_layer([3, 64, 64]) + self.b2 = _make_layer([64, 128, 128]) + self.b3 = _make_layer([128, 256, 256, 256]) + self.b4 = _make_layer([256, 512, 512, 512]) + self.b5 = _make_layer([512, 512, 512, 512]) + + self.m1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME') + self.m2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME') + self.m3 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME') + self.m4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME') + self.m5 = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='SAME') + + def construct(self, x): + # block1 + x = self.b1(x) + x = self.m1(x) + + # block2 + x = self.b2(x) + x = self.m2(x) + + # block3 + x = self.b3(x) + x = self.m3(x) + + # block4 + x = self.b4(x) + block4 = x + x = self.m4(x) + + # block5 + x = self.b5(x) + x = self.m5(x) + + return block4, x + + +def vgg16(): + return VGG16() diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py index fbeee76244..cf182e3792 100644 --- a/model_zoo/official/cv/ssd/train.py +++ b/model_zoo/official/cv/ssd/train.py @@ -25,7 +25,7 @@ from mindspore.train import Model from mindspore.context import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed, dtype -from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn +from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 from src.config import config from src.dataset import create_ssd_dataset, create_mindrecord from src.lr_schedule import get_lr @@ -86,6 +86,17 @@ def ssd_model_build(args_opt): param_dict["network.feature_extractor.resnet." + x] = param_dict[x] del param_dict[x] load_param_into_net(ssd.feature_extractor.resnet, param_dict) + elif config.model == "ssd_vgg16": + ssd = ssd_vgg16(config=config) + init_net_param(ssd) + if config.feature_extractor_base_param != "": + param_dict = load_checkpoint(config.feature_extractor_base_param) + from src.vgg16 import ssd_vgg_key_mapper + for k in ssd_vgg_key_mapper: + v = ssd_vgg_key_mapper[k] + param_dict["network.backbone." + v + ".weight"] = param_dict[k + ".weight"] + del param_dict[k + ".weight"] + load_param_into_net(ssd.backbone, param_dict) else: raise ValueError(f'config.model: {config.model} is not supported') return ssd @@ -106,6 +117,8 @@ def main(): init() if config.model == "ssd_resnet50_fpn": context.set_auto_parallel_context(all_reduce_fusion_config=[90, 183, 279]) + if config.model == "ssd_vgg16": + context.set_auto_parallel_context(all_reduce_fusion_config=[20, 41, 62]) else: context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89]) rank = get_rank()