| @@ -35,11 +35,12 @@ SSD discretizes the output space of bounding boxes into a set of default boxes o | |||||
| The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high quality image classification, which is called the base network. Auxiliary structure is then added to the network to produce detections. | The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high quality image classification, which is called the base network. Auxiliary structure is then added to the network to produce detections. | ||||
| We present three different base architecture. | |||||
| We present four different base architectures. | |||||
| - **ssd300**, reference from the paper. Using mobilenetv2 as backbone and the same bbox predictor as the paper presents. | - **ssd300**, reference from the paper. Using mobilenetv2 as backbone and the same bbox predictor as the paper presents. | ||||
| - **ssd-mobilenet-v1-fpn**, using mobilenet-v1 and FPN as feature extractor with weight-shared box predictors. | - **ssd-mobilenet-v1-fpn**, using mobilenet-v1 and FPN as feature extractor with weight-shared box predictors. | ||||
| - **ssd-resnet50-fpn**, using resnet50 and FPN as feature extractor with weight-shared box predictors. | - **ssd-resnet50-fpn**, using resnet50 and FPN as feature extractor with weight-shared box predictors. | ||||
| - **ssd-vgg16**, reference from the paper. Using vgg16 as backbone and the same bbox predictor as the paper presents. | |||||
| ## [Dataset](#contents) | ## [Dataset](#contents) | ||||
| @@ -21,7 +21,7 @@ import time | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context, Tensor | from mindspore import context, Tensor | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn | |||||
| from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 | |||||
| from src.dataset import create_ssd_dataset, create_mindrecord | from src.dataset import create_ssd_dataset, create_mindrecord | ||||
| from src.config import config | from src.config import config | ||||
| from src.eval_utils import metrics | from src.eval_utils import metrics | ||||
| @@ -34,6 +34,8 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): | |||||
| is_training=False, use_multiprocessing=False) | is_training=False, use_multiprocessing=False) | ||||
| if config.model == "ssd300": | if config.model == "ssd300": | ||||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | ||||
| elif config.model == "ssd_vgg16": | |||||
| net = ssd_vgg16(config=config) | |||||
| elif config.model == "ssd_mobilenet_v1_fpn": | elif config.model == "ssd_mobilenet_v1_fpn": | ||||
| net = ssd_mobilenet_v1_fpn(config=config) | net = ssd_mobilenet_v1_fpn(config=config) | ||||
| elif config.model == "ssd_resnet50_fpn": | elif config.model == "ssd_resnet50_fpn": | ||||
| @@ -19,7 +19,7 @@ import numpy as np | |||||
| import mindspore | import mindspore | ||||
| from mindspore import context, Tensor | from mindspore import context, Tensor | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | ||||
| from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn | |||||
| from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 | |||||
| from src.config import config | from src.config import config | ||||
| from src.box_utils import default_boxes | from src.box_utils import default_boxes | ||||
| @@ -40,6 +40,8 @@ if args.device_target == "Ascend": | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| if config.model == "ssd300": | if config.model == "ssd300": | ||||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | ||||
| elif config.model == "ssd_vgg16": | |||||
| net = ssd_vgg16(config=config) | |||||
| elif config.model == "ssd_mobilenet_v1_fpn": | elif config.model == "ssd_mobilenet_v1_fpn": | ||||
| net = ssd_mobilenet_v1_fpn(config=config) | net = ssd_mobilenet_v1_fpn(config=config) | ||||
| elif config.model == "ssd_resnet50_fpn": | elif config.model == "ssd_resnet50_fpn": | ||||
| @@ -11,18 +11,20 @@ | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| #" ============================================================================ | |||||
| # ============================================================================ | |||||
| """Config parameters for SSD models.""" | """Config parameters for SSD models.""" | ||||
| from .config_ssd300 import config as config_ssd300 | from .config_ssd300 import config as config_ssd300 | ||||
| from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn | from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn | ||||
| from .config_ssd_resnet50_fpn import config as config_ssd_resnet50_fpn | from .config_ssd_resnet50_fpn import config as config_ssd_resnet50_fpn | ||||
| from .config_ssd_vgg16 import config as config_ssd_vgg16 | |||||
# Name of the model configuration selected by default.
# NOTE(review): presumably used further down (outside this hunk) to index
# config_map -- confirm against the rest of the module.
using_model = "ssd300"

# Lookup table from model name to the config object imported above for it.
config_map = {
    "ssd300": config_ssd300,
    "ssd_vgg16": config_ssd_vgg16,
    "ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn,
    "ssd_resnet50_fpn": config_ssd_resnet50_fpn
}
| @@ -0,0 +1,84 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Config parameters for SSD models.""" | |||||
| from easydict import EasyDict as ed | |||||
# Configuration for the "ssd_vgg16" model, wrapped in an EasyDict so entries
# can be read as attributes (e.g. config.ssd_vgg_bn).
config = ed({
    "model": "ssd_vgg16",
    "img_shape": [300, 300],
    # 7308 is consistent with sum(feature_size[i] ** 2 * num_default[i]) over
    # the six feature maps below: 4332 + 2166 + 600 + 150 + 54 + 6.
    "num_ssd_boxes": 7308,
    "match_threshold": 0.5,
    "nms_threshold": 0.6,
    "min_score": 0.1,
    "max_boxes": 100,
    # Read by src/vgg16.py at import time: adds BatchNorm2d after every
    # backbone convolution when True.
    "ssd_vgg_bn": False,

    # learning rate settings
    "lr_init": 0.001,
    "lr_end_rate": 0.001,
    "warmup_epochs": 2,
    "momentum": 0.9,
    "weight_decay": 1.5e-4,

    # network
    "num_default": [3, 6, 6, 6, 6, 6],
    "extras_in_channels": [256, 512, 1024, 512, 256, 256],
    "extras_out_channels": [512, 1024, 512, 256, 256, 256],
    "extras_strides": [1, 1, 2, 2, 2, 2],
    "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
    "feature_size": [38, 19, 10, 5, 3, 1],
    "min_scale": 0.2,
    "max_scale": 0.95,
    "aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
    "steps": (8, 16, 32, 64, 100, 300),
    "prior_scaling": (0.1, 0.2),
    "gamma": 2.0,
    "alpha": 0.75,

    # Checkpoint of a pretrained VGG16; train.py loads it into the backbone
    # only when this string is non-empty.
    "feature_extractor_base_param": "",
    # Read by src/vgg16.py at import time: selects which index table is used
    # for the parameter names of the pretrained checkpoint.
    "pretrain_vgg_bn": False,
    "checkpoint_filter_list": ['multi_loc_layers', 'multi_cls_layers'],

    # `mindrecord_dir` and `coco_root` should preferably be absolute paths.
    "mindrecord_dir": "/data/MindRecord_COCO",
    "coco_root": "/data/coco2017",
    "train_data_type": "train2017",
    "val_data_type": "val2017",
    "instances_set": "annotations/instances_{}.json",
    "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush'),
    # Length of `classes` above, including the 'background' entry.
    "num_classes": 81,
    # The annotation.json position of voc validation dataset.
    "voc_json": "annotations/voc_instances_val.json",
    # voc original dataset.
    "voc_root": "/data/voc_dataset",
    # if coco or voc used, `image_dir` and `anno_path` are useless.
    "image_dir": "",
    "anno_path": ""
})
| @@ -27,6 +27,7 @@ from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from .fpn import mobilenet_v1_fpn, resnet50_fpn | from .fpn import mobilenet_v1_fpn, resnet50_fpn | ||||
| from .vgg16 import vgg16 | |||||
| def _make_divisible(v, divisor, min_value=None): | def _make_divisible(v, divisor, min_value=None): | ||||
| @@ -641,3 +642,78 @@ def ssd_resnet50_fpn(**kwargs): | |||||
| def ssd_mobilenet_v2(**kwargs): | def ssd_mobilenet_v2(**kwargs): | ||||
| return SSDWithMobileNetV2(**kwargs) | return SSDWithMobileNetV2(**kwargs) | ||||
class SSD300VGG16(nn.Cell):
    """SSD detector built on the VGG16 backbone.

    The network chains the VGG16 backbone (blocks 1-5), two SSD conversion
    blocks (6-7) and four extra feature blocks (8-11). Six intermediate
    outputs (block4, block7, block8, block9, block10, block11) are handed to
    ``MultiBox`` to predict box locations and class scores.

    Args:
        config: Model configuration object; it is forwarded unchanged to
            ``MultiBox``.
    """

    def __init__(self, config):
        super(SSD300VGG16, self).__init__()

        # VGG16 backbone: block1~5
        self.backbone = vgg16()

        # SSD blocks: block6~7
        # block6 is a dilated 3x3 convolution (dilation 6, explicit padding 6).
        self.b6_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6, pad_mode='pad')
        self.b6_2 = nn.Dropout(0.5)

        self.b7_1 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1)
        self.b7_2 = nn.Dropout(0.5)

        # Extra Feature Layers: block8~11
        # Each block is a 1x1 channel-reducing conv followed by a 3x3 conv.
        # Blocks 8 and 9 pad by 1 in the 1x1 conv and use stride 2 with
        # 'valid' padding in the 3x3 conv; blocks 10 and 11 use stride 1.
        self.b8_1 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, padding=1, pad_mode='pad')
        self.b8_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, pad_mode='valid')

        self.b9_1 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=1, padding=1, pad_mode='pad')
        self.b9_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, pad_mode='valid')

        self.b10_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1)
        self.b10_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid')

        self.b11_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1)
        self.b11_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, pad_mode='valid')

        # boxes
        self.multi_box = MultiBox(config)
        # This check runs once, at construction time, against the cell's mode
        # at that moment.
        # NOTE(review): confirm nn.Cell's default mode so that `activation`
        # exists whenever construct() runs in eval mode.
        if not self.training:
            self.activation = P.Sigmoid()

    def construct(self, x):
        """Run the detector on an input batch.

        Args:
            x: Input tensor fed to the VGG16 backbone, whose first
                convolution expects 3 input channels (assumes NCHW layout --
                TODO confirm).

        Returns:
            Tuple ``(pred_loc, pred_label)``, both cast to float32. When the
            cell is not in training mode, ``pred_label`` has been passed
            through a sigmoid first.
        """
        # VGG16 backbone: block1~5
        block4, x = self.backbone(x)

        # SSD blocks: block6~7
        x = self.b6_1(x)  # 1024
        x = self.b6_2(x)

        x = self.b7_1(x)  # 1024
        x = self.b7_2(x)
        block7 = x

        # Extra Feature Layers: block8~11
        x = self.b8_1(x)  # 256
        x = self.b8_2(x)  # 512
        block8 = x

        x = self.b9_1(x)  # 128
        x = self.b9_2(x)  # 256
        block9 = x

        x = self.b10_1(x)  # 128
        x = self.b10_2(x)  # 256
        block10 = x

        x = self.b11_1(x)  # 128
        x = self.b11_2(x)  # 256
        block11 = x

        # boxes
        multi_feature = (block4, block7, block8, block9, block10, block11)
        pred_loc, pred_label = self.multi_box(multi_feature)
        # Raw class outputs are turned into sigmoid scores only outside training.
        if not self.training:
            pred_label = self.activation(pred_label)
        pred_loc = F.cast(pred_loc, mstype.float32)
        pred_label = F.cast(pred_label, mstype.float32)
        return pred_loc, pred_label
def ssd_vgg16(**kwargs):
    """Build an ``SSD300VGG16`` network, forwarding all keyword arguments to it."""
    return SSD300VGG16(**kwargs)
| @@ -0,0 +1,99 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """VGG16 backbone for SSD""" | |||||
| from mindspore import nn | |||||
| from .config_ssd_vgg16 import config | |||||
# Batch-norm switches, read once from the ssd_vgg16 config at import time.
# pretrain_vgg_bn picks the index table for the pretrained checkpoint's
# parameter names; ssd_vgg_bn picks the index table for this module's names
# and decides whether _make_layer inserts BatchNorm2d after each convolution.
pretrain_vgg_bn = config.pretrain_vgg_bn
ssd_vgg_bn = config.ssd_vgg_bn
| def _get_key_mapper(): | |||||
| vgg_key_num = [1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5] | |||||
| size = len(vgg_key_num) | |||||
| pretrain_vgg_bn_false = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28] | |||||
| pretrain_vgg_bn_true = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40] | |||||
| ssd_vgg_bn_false = [0, 2, 0, 2, 0, 2, 4, 0, 2, 4, 0, 2, 4] | |||||
| ssd_vgg_bn_true = [0, 3, 0, 3, 0, 3, 6, 0, 3, 6, 0, 3, 6] | |||||
| pretrain_vgg_keys = pretrain_vgg_bn_true if pretrain_vgg_bn else pretrain_vgg_bn_false | |||||
| ssd_vgg_keys = ssd_vgg_bn_true if ssd_vgg_bn else ssd_vgg_bn_false | |||||
| pretrain_vgg_keys = ['layers.' + str(pretrain_vgg_keys[i]) for i in range(size)] | |||||
| ssd_vgg_keys = ['b' + str(vgg_key_num[i]) + '.' + str(ssd_vgg_keys[i]) for i in range(size)] | |||||
| return {pretrain_vgg_keys[i]: ssd_vgg_keys[i] for i in range(size)} | |||||
# Mapping built once at import time; train.py imports it to rename the
# pretrained VGG16 weight keys before loading them into the backbone.
ssd_vgg_key_mapper = _get_key_mapper()
def _make_layer(channels):
    """Assemble one VGG block as a sequential cell.

    ``channels`` lists the block's channel widths: the first entry is the
    input width and every later entry adds one 3x3 convolution producing
    that many channels, followed by BatchNorm2d (only when the module flag
    ``ssd_vgg_bn`` is set) and a ReLU.
    """
    previous_width = channels[0]
    cells = []
    for width in channels[1:]:
        conv = nn.Conv2d(in_channels=previous_width, out_channels=width, kernel_size=3)
        cells += [conv]
        if ssd_vgg_bn:
            cells += [nn.BatchNorm2d(width)]
        cells += [nn.ReLU()]
        # The next convolution consumes what this one produced.
        previous_width = width
    return nn.SequentialCell(cells)
class VGG16(nn.Cell):
    """VGG16 feature extractor used as the SSD backbone.

    Five convolution blocks (``b1``..``b5``) are each followed by a max-pool
    (``m1``..``m5``). The attribute names ``b1``..``b5`` are the targets of
    the ``'b<block>.<index>'`` prefixes produced by ``_get_key_mapper``, so
    they must stay in sync with that mapping.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        # Convolution blocks: first number is the input width, the rest are
        # the output widths of the successive 3x3 convolutions.
        self.b1 = _make_layer([3, 64, 64])
        self.b2 = _make_layer([64, 128, 128])
        self.b3 = _make_layer([128, 256, 256, 256])
        self.b4 = _make_layer([256, 512, 512, 512])
        self.b5 = _make_layer([512, 512, 512, 512])

        # m1-m4 halve the resolution (2x2, stride 2); m5 is a 3x3 pool with
        # stride 1, so it does not downsample.
        self.m1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
        self.m2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
        self.m3 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
        self.m4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
        self.m5 = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='SAME')

    def construct(self, x):
        """Run the backbone.

        Args:
            x: Input tensor with 3 channels (assumes NCHW layout -- TODO
                confirm).

        Returns:
            Tuple ``(block4, x)``: the output of block 4 taken before its
            max-pool, and the final output after block 5 and its pool.
        """
        # block1
        x = self.b1(x)
        x = self.m1(x)

        # block2
        x = self.b2(x)
        x = self.m2(x)

        # block3
        x = self.b3(x)
        x = self.m3(x)

        # block4
        x = self.b4(x)
        # Captured before pooling so the caller gets the higher-resolution map.
        block4 = x
        x = self.m4(x)

        # block5
        x = self.b5(x)
        x = self.m5(x)

        return block4, x
def vgg16():
    """Return a new ``VGG16`` backbone instance."""
    return VGG16()
| @@ -25,7 +25,7 @@ from mindspore.train import Model | |||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.common import set_seed, dtype | from mindspore.common import set_seed, dtype | ||||
| from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn | |||||
| from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 | |||||
| from src.config import config | from src.config import config | ||||
| from src.dataset import create_ssd_dataset, create_mindrecord | from src.dataset import create_ssd_dataset, create_mindrecord | ||||
| from src.lr_schedule import get_lr | from src.lr_schedule import get_lr | ||||
| @@ -86,6 +86,17 @@ def ssd_model_build(args_opt): | |||||
| param_dict["network.feature_extractor.resnet." + x] = param_dict[x] | param_dict["network.feature_extractor.resnet." + x] = param_dict[x] | ||||
| del param_dict[x] | del param_dict[x] | ||||
| load_param_into_net(ssd.feature_extractor.resnet, param_dict) | load_param_into_net(ssd.feature_extractor.resnet, param_dict) | ||||
| elif config.model == "ssd_vgg16": | |||||
| ssd = ssd_vgg16(config=config) | |||||
| init_net_param(ssd) | |||||
| if config.feature_extractor_base_param != "": | |||||
| param_dict = load_checkpoint(config.feature_extractor_base_param) | |||||
| from src.vgg16 import ssd_vgg_key_mapper | |||||
| for k in ssd_vgg_key_mapper: | |||||
| v = ssd_vgg_key_mapper[k] | |||||
| param_dict["network.backbone." + v + ".weight"] = param_dict[k + ".weight"] | |||||
| del param_dict[k + ".weight"] | |||||
| load_param_into_net(ssd.backbone, param_dict) | |||||
| else: | else: | ||||
| raise ValueError(f'config.model: {config.model} is not supported') | raise ValueError(f'config.model: {config.model} is not supported') | ||||
| return ssd | return ssd | ||||
| @@ -106,6 +117,8 @@ def main(): | |||||
| init() | init() | ||||
| if config.model == "ssd_resnet50_fpn": | if config.model == "ssd_resnet50_fpn": | ||||
| context.set_auto_parallel_context(all_reduce_fusion_config=[90, 183, 279]) | context.set_auto_parallel_context(all_reduce_fusion_config=[90, 183, 279]) | ||||
| if config.model == "ssd_vgg16": | |||||
| context.set_auto_parallel_context(all_reduce_fusion_config=[20, 41, 62]) | |||||
| else: | else: | ||||
| context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89]) | context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89]) | ||||
| rank = get_rank() | rank = get_rank() | ||||