| @@ -21,7 +21,7 @@ import time | |||
| import numpy as np | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.ssd import SSD300, ssd_mobilenet_v2 | |||
| from src.ssd import SSD300, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn | |||
| from src.dataset import create_ssd_dataset, create_mindrecord | |||
| from src.config import config | |||
| from src.eval_utils import metrics | |||
| @@ -31,7 +31,10 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): | |||
| batch_size = 1 | |||
| ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, | |||
| is_training=False, use_multiprocessing=False) | |||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | |||
| if config.model == "ssd300": | |||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | |||
| else: | |||
| net = ssd_mobilenet_v1_fpn(config=config) | |||
| print("Load Checkpoint!") | |||
| param_dict = load_checkpoint(ckpt_path) | |||
| net.init_parameters_data() | |||
| @@ -0,0 +1,92 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Anchor Generator""" | |||
| import numpy as np | |||
| class GridAnchorGenerator: | |||
| """ | |||
| Anchor Generator | |||
| """ | |||
| def __init__(self, image_shape, scale, scales_per_octave, aspect_ratios): | |||
| super(GridAnchorGenerator, self).__init__() | |||
| self.scale = scale | |||
| self.scales_per_octave = scales_per_octave | |||
| self.aspect_ratios = aspect_ratios | |||
| self.image_shape = image_shape | |||
| def generate(self, step): | |||
| scales = np.array([2**(float(scale) / self.scales_per_octave) | |||
| for scale in range(self.scales_per_octave)]).astype(np.float32) | |||
| aspects = np.array(list(self.aspect_ratios)).astype(np.float32) | |||
| scales_grid, aspect_ratios_grid = np.meshgrid(scales, aspects) | |||
| scales_grid = scales_grid.reshape([-1]) | |||
| aspect_ratios_grid = aspect_ratios_grid.reshape([-1]) | |||
| feature_size = [self.image_shape[0] / step, self.image_shape[0] / step] | |||
| grid_height, grid_width = feature_size | |||
| base_size = np.array([self.scale * step, self.scale * step]).astype(np.float32) | |||
| anchor_offset = step / 2.0 | |||
| ratio_sqrt = np.sqrt(aspect_ratios_grid) | |||
| heights = scales_grid / ratio_sqrt * base_size[0] | |||
| widths = scales_grid * ratio_sqrt * base_size[1] | |||
| y_centers = np.arange(grid_height).astype(np.float32) | |||
| y_centers = y_centers * step + anchor_offset | |||
| x_centers = np.arange(grid_width).astype(np.float32) | |||
| x_centers = x_centers * step + anchor_offset | |||
| x_centers, y_centers = np.meshgrid(x_centers, y_centers) | |||
| x_centers_shape = x_centers.shape | |||
| y_centers_shape = y_centers.shape | |||
| widths_grid, x_centers_grid = np.meshgrid(widths, x_centers.reshape([-1])) | |||
| heights_grid, y_centers_grid = np.meshgrid(heights, y_centers.reshape([-1])) | |||
| x_centers_grid = x_centers_grid.reshape(*x_centers_shape, -1) | |||
| y_centers_grid = y_centers_grid.reshape(*y_centers_shape, -1) | |||
| widths_grid = widths_grid.reshape(-1, *x_centers_shape) | |||
| heights_grid = heights_grid.reshape(-1, *y_centers_shape) | |||
| bbox_centers = np.stack([y_centers_grid, x_centers_grid], axis=3) | |||
| bbox_sizes = np.stack([heights_grid, widths_grid], axis=3) | |||
| bbox_centers = bbox_centers.reshape([-1, 2]) | |||
| bbox_sizes = bbox_sizes.reshape([-1, 2]) | |||
| bbox_corners = np.concatenate([bbox_centers - 0.5 * bbox_sizes, bbox_centers + 0.5 * bbox_sizes], axis=1) | |||
| self.bbox_corners = bbox_corners / np.array([*self.image_shape, *self.image_shape]).astype(np.float32) | |||
| self.bbox_centers = np.concatenate([bbox_centers, bbox_sizes], axis=1) | |||
| self.bbox_centers = self.bbox_centers / np.array([*self.image_shape, *self.image_shape]).astype(np.float32) | |||
| print(self.bbox_centers.shape) | |||
| return self.bbox_centers, self.bbox_corners | |||
| def generate_multi_levels(self, steps): | |||
| bbox_centers_list = [] | |||
| bbox_corners_list = [] | |||
| for step in steps: | |||
| bbox_centers, bbox_corners = self.generate(step) | |||
| bbox_centers_list.append(bbox_centers) | |||
| bbox_corners_list.append(bbox_corners) | |||
| self.bbox_centers = np.concatenate(bbox_centers_list, axis=0) | |||
| self.bbox_corners = np.concatenate(bbox_corners_list, axis=0) | |||
| return self.bbox_centers, self.bbox_corners | |||
| @@ -19,6 +19,7 @@ import math | |||
| import itertools as it | |||
| import numpy as np | |||
| from .config import config | |||
| from .anchor_generator import GridAnchorGenerator | |||
| class GeneratDefaultBoxes(): | |||
| @@ -36,7 +37,7 @@ class GeneratDefaultBoxes(): | |||
| sk1 = scales[idex] | |||
| sk2 = scales[idex + 1] | |||
| sk3 = math.sqrt(sk1 * sk2) | |||
| if idex == 0: | |||
| if idex == 0 and not config.aspect_ratios[idex]: | |||
| w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2) | |||
| all_sizes = [(0.1, 0.1), (w, h), (h, w)] | |||
| else: | |||
| @@ -61,9 +62,12 @@ class GeneratDefaultBoxes(): | |||
| self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32') | |||
| self.default_boxes = np.array(self.default_boxes, dtype='float32') | |||
| default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr | |||
| default_boxes = GeneratDefaultBoxes().default_boxes | |||
| if 'use_anchor_generator' in config and config.use_anchor_generator: | |||
| generator = GridAnchorGenerator(config.img_shape, 4, 2, [1.0, 2.0, 0.5]) | |||
| default_boxes, default_boxes_tlbr = generator.generate_multi_levels(config.steps) | |||
| else: | |||
| default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr | |||
| default_boxes = GeneratDefaultBoxes().default_boxes | |||
| y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1) | |||
| vol_anchors = (x2 - x1) * (y2 - y1) | |||
| matching_threshold = config.match_threshold | |||
| @@ -15,68 +15,15 @@ | |||
| """Config parameters for SSD models.""" | |||
| from easydict import EasyDict as ed | |||
| from .config_ssd300 import config as config_ssd300 | |||
| from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn | |||
| config = ed({ | |||
| "img_shape": [300, 300], | |||
| "num_ssd_boxes": 1917, | |||
| "neg_pre_positive": 3, | |||
| "match_threshold": 0.5, | |||
| "nms_threshold": 0.6, | |||
| "min_score": 0.1, | |||
| "max_boxes": 100, | |||
| # learing rate settings | |||
| "lr_init": 0.001, | |||
| "lr_end_rate": 0.001, | |||
| "warmup_epochs": 2, | |||
| "momentum": 0.9, | |||
| "weight_decay": 1.5e-4, | |||
| using_model = "ssd300" | |||
| # network | |||
| "num_default": [3, 6, 6, 6, 6, 6], | |||
| "extras_in_channels": [256, 576, 1280, 512, 256, 256], | |||
| "extras_out_channels": [576, 1280, 512, 256, 256, 128], | |||
| "extras_strides": [1, 1, 2, 2, 2, 2], | |||
| "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], | |||
| "feature_size": [19, 10, 5, 3, 2, 1], | |||
| "min_scale": 0.2, | |||
| "max_scale": 0.95, | |||
| "aspect_ratios": [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], | |||
| "steps": (16, 32, 64, 100, 150, 300), | |||
| "prior_scaling": (0.1, 0.2), | |||
| "gamma": 2.0, | |||
| "alpha": 0.75, | |||
| config_map = { | |||
| "ssd300": config_ssd300, | |||
| "ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn | |||
| } | |||
| # `mindrecord_dir` and `coco_root` are better to use absolute path. | |||
| "mindrecord_dir": "/data/MindRecord_COCO", | |||
| "coco_root": "/data/coco2017", | |||
| "train_data_type": "train2017", | |||
| "val_data_type": "val2017", | |||
| "instances_set": "annotations/instances_{}.json", | |||
| "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||
| 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||
| 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||
| 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||
| 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||
| 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||
| 'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||
| 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||
| 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
| 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||
| 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||
| 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||
| 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', | |||
| 'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||
| 'teddy bear', 'hair drier', 'toothbrush'), | |||
| "num_classes": 81, | |||
| # The annotation.json position of voc validation dataset. | |||
| "voc_json": "annotations/voc_instances_val.json", | |||
| # voc original dataset. | |||
| "voc_root": "/data/voc_dataset", | |||
| # if coco or voc used, `image_dir` and `anno_path` are useless. | |||
| "image_dir": "", | |||
| "anno_path": "", | |||
| "export_format": "MINDIR", | |||
| "export_file": "ssd.mindir" | |||
| }) | |||
| config = config_map[using_model] | |||
| @@ -0,0 +1,84 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| #" ============================================================================ | |||
| """Config parameters for SSD models.""" | |||
| from easydict import EasyDict as ed | |||
| config = ed({ | |||
| "model": "ssd300", | |||
| "img_shape": [300, 300], | |||
| "num_ssd_boxes": 1917, | |||
| "neg_pre_positive": 3, | |||
| "match_threshold": 0.5, | |||
| "nms_threshold": 0.6, | |||
| "min_score": 0.1, | |||
| "max_boxes": 100, | |||
| # learing rate settings | |||
| "lr_init": 0.001, | |||
| "lr_end_rate": 0.001, | |||
| "warmup_epochs": 2, | |||
| "momentum": 0.9, | |||
| "weight_decay": 1.5e-4, | |||
| # network | |||
| "num_default": [3, 6, 6, 6, 6, 6], | |||
| "extras_in_channels": [256, 576, 1280, 512, 256, 256], | |||
| "extras_out_channels": [576, 1280, 512, 256, 256, 128], | |||
| "extras_strides": [1, 1, 2, 2, 2, 2], | |||
| "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], | |||
| "feature_size": [19, 10, 5, 3, 2, 1], | |||
| "min_scale": 0.2, | |||
| "max_scale": 0.95, | |||
| "aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], | |||
| "steps": (16, 32, 64, 100, 150, 300), | |||
| "prior_scaling": (0.1, 0.2), | |||
| "gamma": 2.0, | |||
| "alpha": 0.75, | |||
| # `mindrecord_dir` and `coco_root` are better to use absolute path. | |||
| "feature_extractor_base_param": "", | |||
| "mindrecord_dir": "/data/MindRecord_COCO", | |||
| "coco_root": "/data/coco2017", | |||
| "train_data_type": "train2017", | |||
| "val_data_type": "val2017", | |||
| "instances_set": "annotations/instances_{}.json", | |||
| "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||
| 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||
| 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||
| 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||
| 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||
| 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||
| 'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||
| 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||
| 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
| 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||
| 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||
| 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||
| 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', | |||
| 'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||
| 'teddy bear', 'hair drier', 'toothbrush'), | |||
| "num_classes": 81, | |||
| # The annotation.json position of voc validation dataset. | |||
| "voc_json": "annotations/voc_instances_val.json", | |||
| # voc original dataset. | |||
| "voc_root": "/data/voc_dataset", | |||
| # if coco or voc used, `image_dir` and `anno_path` are useless. | |||
| "image_dir": "", | |||
| "anno_path": "", | |||
| "export_format": "MINDIR", | |||
| "export_file": "ssd.mindir" | |||
| }) | |||
| @@ -0,0 +1,88 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| #" ============================================================================ | |||
| """Config parameters for SSD models.""" | |||
| from easydict import EasyDict as ed | |||
| config = ed({ | |||
| "model": "ssd_mobilenet_v1_fpn", | |||
| "img_shape": [640, 640], | |||
| "num_ssd_boxes": 51150, | |||
| "neg_pre_positive": 3, | |||
| "match_threshold": 0.5, | |||
| "nms_threshold": 0.6, | |||
| "min_score": 0.1, | |||
| "max_boxes": 100, | |||
| # learning rate settings | |||
| "global_step": 0, | |||
| "lr_init": 0.01333, | |||
| "lr_end_rate": 0.0, | |||
| "warmup_epochs": 2, | |||
| "momentum": 0.9, | |||
| "weight_decay": 1.5e-4, | |||
| # network | |||
| "num_default": [6, 6, 6, 6, 6], | |||
| "extras_in_channels": [256, 512, 1024, 256, 256], | |||
| "extras_out_channels": [256, 256, 256, 256, 256], | |||
| "extras_strides": [1, 1, 2, 2, 2, 2], | |||
| "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], | |||
| "feature_size": [80, 40, 20, 10, 5], | |||
| "min_scale": 0.2, | |||
| "max_scale": 0.95, | |||
| "aspect_ratios": [(2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], | |||
| "steps": (8, 16, 32, 64, 128), | |||
| "prior_scaling": (0.1, 0.2), | |||
| "gamma": 2.0, | |||
| "alpha": 0.75, | |||
| "num_addition_layers": 4, | |||
| "use_anchor_generator": True, | |||
| "use_global_norm": True, | |||
| # `mindrecord_dir` and `coco_root` are better to use absolute path. | |||
| "feature_extractor_base_param": "/ckpt/mobilenet_v1.ckpt", | |||
| "mindrecord_dir": "/data/MindRecord_COCO", | |||
| "coco_root": "/data/coco2017", | |||
| "train_data_type": "train2017", | |||
| "val_data_type": "val2017", | |||
| "instances_set": "annotations/instances_{}.json", | |||
| "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||
| 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||
| 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||
| 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||
| 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||
| 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||
| 'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||
| 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||
| 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
| 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||
| 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||
| 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||
| 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', | |||
| 'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||
| 'teddy bear', 'hair drier', 'toothbrush'), | |||
| "num_classes": 81, | |||
| # The annotation.json position of voc validation dataset. | |||
| "voc_json": "annotations/voc_instances_val.json", | |||
| # voc original dataset. | |||
| "voc_root": "/data/voc_dataset", | |||
| # if coco or voc used, `image_dir` and `anno_path` are useless. | |||
| "image_dir": "", | |||
| "anno_path": "", | |||
| "export_format": "MINDIR", | |||
| "export_file": "ssd.mindir" | |||
| }) | |||
| @@ -22,14 +22,14 @@ def init_net_param(network, initialize_mode='TruncatedNormal'): | |||
| for p in params: | |||
| if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | |||
| if initialize_mode == 'TruncatedNormal': | |||
| p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype)) | |||
| p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype)) | |||
| else: | |||
| p.set_data(initialize_mode, p.data.shape, p.data.dtype) | |||
| def load_backbone_params(network, param_dict): | |||
| """Init the parameters from pre-train model, default is mobilenetv2.""" | |||
| for _, param in net.parameters_and_names(): | |||
| for _, param in network.parameters_and_names(): | |||
| param_name = param.name.replace('network.backbone.', '') | |||
| name_split = param_name.split('.') | |||
| if 'features_1' in param_name: | |||
| @@ -0,0 +1,192 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'): | |||
| output = [] | |||
| output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same", | |||
| group=1 if not depthwise else in_channel)) | |||
| output.append(nn.BatchNorm2d(out_channel)) | |||
| if activation: | |||
| output.append(nn.get_activation(activation)) | |||
| return nn.SequentialCell(output) | |||
| class MobileNetV1(nn.Cell): | |||
| """ | |||
| MobileNet V1 backbone | |||
| """ | |||
| def __init__(self, class_num=1001, features_only=False): | |||
| super(MobileNetV1, self).__init__() | |||
| self.features_only = features_only | |||
| cnn = [ | |||
| conv_bn_relu(3, 32, 3, 2, False), # Conv0 | |||
| conv_bn_relu(32, 32, 3, 1, True), # Conv1_depthwise | |||
| conv_bn_relu(32, 64, 1, 1, False), # Conv1_pointwise | |||
| conv_bn_relu(64, 64, 3, 2, True), # Conv2_depthwise | |||
| conv_bn_relu(64, 128, 1, 1, False), # Conv2_pointwise | |||
| conv_bn_relu(128, 128, 3, 1, True), # Conv3_depthwise | |||
| conv_bn_relu(128, 128, 1, 1, False), # Conv3_pointwise | |||
| conv_bn_relu(128, 128, 3, 2, True), # Conv4_depthwise | |||
| conv_bn_relu(128, 256, 1, 1, False), # Conv4_pointwise | |||
| conv_bn_relu(256, 256, 3, 1, True), # Conv5_depthwise | |||
| conv_bn_relu(256, 256, 1, 1, False), # Conv5_pointwise | |||
| conv_bn_relu(256, 256, 3, 2, True), # Conv6_depthwise | |||
| conv_bn_relu(256, 512, 1, 1, False), # Conv6_pointwise | |||
| conv_bn_relu(512, 512, 3, 1, True), # Conv7_depthwise | |||
| conv_bn_relu(512, 512, 1, 1, False), # Conv7_pointwise | |||
| conv_bn_relu(512, 512, 3, 1, True), # Conv8_depthwise | |||
| conv_bn_relu(512, 512, 1, 1, False), # Conv8_pointwise | |||
| conv_bn_relu(512, 512, 3, 1, True), # Conv9_depthwise | |||
| conv_bn_relu(512, 512, 1, 1, False), # Conv9_pointwise | |||
| conv_bn_relu(512, 512, 3, 1, True), # Conv10_depthwise | |||
| conv_bn_relu(512, 512, 1, 1, False), # Conv10_pointwise | |||
| conv_bn_relu(512, 512, 3, 1, True), # Conv11_depthwise | |||
| conv_bn_relu(512, 512, 1, 1, False), # Conv11_pointwise | |||
| conv_bn_relu(512, 512, 3, 2, True), # Conv12_depthwise | |||
| conv_bn_relu(512, 1024, 1, 1, False), # Conv12_pointwise | |||
| conv_bn_relu(1024, 1024, 3, 1, True), # Conv13_depthwise | |||
| conv_bn_relu(1024, 1024, 1, 1, False), # Conv13_pointwise | |||
| ] | |||
| if self.features_only: | |||
| self.network = nn.CellList(cnn) | |||
| else: | |||
| self.network = nn.SequentialCell(cnn) | |||
| self.fc = nn.Dense(1024, class_num) | |||
| def construct(self, x): | |||
| output = x | |||
| if self.features_only: | |||
| features = () | |||
| for block in self.network: | |||
| output = block(output) | |||
| features = features + (output,) | |||
| return features | |||
| output = self.network(x) | |||
| output = P.ReduceMean()(output, (2, 3)) | |||
| output = self.fc(output) | |||
| return output | |||
| class FpnTopDown(nn.Cell): | |||
| """ | |||
| Fpn to extract features | |||
| """ | |||
| def __init__(self, in_channel_list, out_channels): | |||
| super(FpnTopDown, self).__init__() | |||
| self.lateral_convs_list_ = [] | |||
| self.fpn_convs_ = [] | |||
| for channel in in_channel_list: | |||
| l_conv = nn.Conv2d(channel, out_channels, kernel_size=1, stride=1, | |||
| has_bias=True, padding=0, pad_mode='same') | |||
| fpn_conv = conv_bn_relu(out_channels, out_channels, kernel_size=3, stride=1, depthwise=False) | |||
| self.lateral_convs_list_.append(l_conv) | |||
| self.fpn_convs_.append(fpn_conv) | |||
| self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) | |||
| self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) | |||
| self.num_layers = len(in_channel_list) | |||
| def construct(self, inputs): | |||
| image_features = () | |||
| for i, feature in enumerate(inputs): | |||
| image_features = image_features + (self.lateral_convs_list[i](feature),) | |||
| features = (image_features[-1],) | |||
| for i in range(len(inputs) - 1): | |||
| top = len(inputs) - i - 1 | |||
| down = top - 1 | |||
| size = F.shape(inputs[down]) | |||
| top_down = P.ResizeBilinear((size[2], size[3]))(features[-1]) | |||
| top_down = top_down + image_features[down] | |||
| features = features + (top_down,) | |||
| extract_features = () | |||
| num_features = len(features) | |||
| for i in range(num_features): | |||
| extract_features = extract_features + (self.fpn_convs_list[i](features[num_features - i - 1]),) | |||
| return extract_features | |||
| class BottomUp(nn.Cell): | |||
| """ | |||
| Bottom Up feature extractor | |||
| """ | |||
| def __init__(self, levels, channels, kernel_size, stride): | |||
| super(BottomUp, self).__init__() | |||
| self.levels = levels | |||
| bottom_up_cells = [ | |||
| conv_bn_relu(channels, channels, kernel_size, stride, False) for x in range(self.levels) | |||
| ] | |||
| self.blocks = nn.CellList(bottom_up_cells) | |||
| def construct(self, features): | |||
| for block in self.blocks: | |||
| features = features + (block(features[-1]),) | |||
| return features | |||
| class FeatureSelector(nn.Cell): | |||
| """ | |||
| Select specific layers from an entire feature list | |||
| """ | |||
| def __init__(self, feature_idxes): | |||
| super(FeatureSelector, self).__init__() | |||
| self.feature_idxes = feature_idxes | |||
| def construct(self, feature_list): | |||
| selected = () | |||
| for i in self.feature_idxes: | |||
| selected = selected + (feature_list[i],) | |||
| return selected | |||
| class MobileNetV1Fpn(nn.Cell): | |||
| """ | |||
| MobileNetV1 with FPN as SSD backbone. | |||
| """ | |||
| def __init__(self, config): | |||
| super(MobileNetV1Fpn, self).__init__() | |||
| self.mobilenet_v1 = MobileNetV1(features_only=True) | |||
| self.selector = FeatureSelector([10, 22, 26]) | |||
| self.layer_indexs = [10, 22, 26] | |||
| self.fpn = FpnTopDown([256, 512, 1024], 256) | |||
| self.bottom_up = BottomUp(2, 256, 3, 2) | |||
| def construct(self, x): | |||
| features = self.mobilenet_v1(x) | |||
| features = self.selector(features) | |||
| features = self.fpn(features) | |||
| features = self.bottom_up(features) | |||
| return features | |||
| def mobilenet_v1_fpn(config): | |||
| return MobileNetV1Fpn(config) | |||
| def mobilenet_v1(class_num=1001): | |||
| return MobileNetV1(class_num) | |||
| @@ -26,6 +26,8 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| from .mobilenet_v1_fpn import mobilenet_v1_fpn | |||
| def _make_divisible(v, divisor, min_value=None): | |||
| """nsures that all layers have a channel number that is divisible by 8.""" | |||
| @@ -67,6 +69,7 @@ class ConvBNReLU(nn.Cell): | |||
| kernel_size (int): Input kernel size. | |||
| stride (int): Stride size for the first convolutional layer. Default: 1. | |||
| groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. | |||
| shared_conv(Cell): Use the weight shared conv, default: None. | |||
| Returns: | |||
| Tensor, output tensor. | |||
| @@ -74,18 +77,21 @@ class ConvBNReLU(nn.Cell): | |||
| Examples: | |||
| >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) | |||
| """ | |||
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | |||
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, shared_conv=None): | |||
| super(ConvBNReLU, self).__init__() | |||
| padding = 0 | |||
| in_channels = in_planes | |||
| out_channels = out_planes | |||
| if groups == 1: | |||
| conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding) | |||
| if shared_conv is None: | |||
| if groups == 1: | |||
| conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding) | |||
| else: | |||
| out_channels = in_planes | |||
| conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', | |||
| padding=padding, group=in_channels) | |||
| layers = [conv, _bn(out_planes), nn.ReLU6()] | |||
| else: | |||
| out_channels = in_planes | |||
| conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', | |||
| padding=padding, group=in_channels) | |||
| layers = [conv, _bn(out_planes), nn.ReLU6()] | |||
| layers = [shared_conv, _bn(out_planes), nn.ReLU6()] | |||
| self.features = nn.SequentialCell(layers) | |||
| def construct(self, x): | |||
| @@ -205,6 +211,86 @@ class MultiBox(nn.Cell): | |||
| return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) | |||
| class WeightSharedMultiBox(nn.Cell): | |||
| """ | |||
| Weight shared Multi-box conv layers. Each multi-box layer contains class conf scores and localization predictions. | |||
| All box predictors shares the same conv weight in different features. | |||
| Args: | |||
| config (dict): The default config of SSD. | |||
| loc_cls_shared_addition(bool): Whether the location predictor and classifier prediction share the | |||
| same addition layer. | |||
| Returns: | |||
| Tensor, localization predictions. | |||
| Tensor, class conf scores. | |||
| """ | |||
| def __init__(self, config, loc_cls_shared_addition=False): | |||
| super(WeightSharedMultiBox, self).__init__() | |||
| num_classes = config.num_classes | |||
| out_channels = config.extras_out_channels[0] | |||
| num_default = config.num_default[0] | |||
| num_features = len(config.feature_size) | |||
| num_addition_layers = config.num_addition_layers | |||
| self.loc_cls_shared_addition = loc_cls_shared_addition | |||
| if not loc_cls_shared_addition: | |||
| loc_convs = [ | |||
| _conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers) | |||
| ] | |||
| cls_convs = [ | |||
| _conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers) | |||
| ] | |||
| addition_loc_layer_list = [] | |||
| addition_cls_layer_list = [] | |||
| for _ in range(num_features): | |||
| addition_loc_layer = [ | |||
| ConvBNReLU(out_channels, out_channels, 3, 1, 1, loc_convs[x]) for x in range(num_addition_layers) | |||
| ] | |||
| addition_cls_layer = [ | |||
| ConvBNReLU(out_channels, out_channels, 3, 1, 1, cls_convs[x]) for x in range(num_addition_layers) | |||
| ] | |||
| addition_loc_layer_list.append(nn.SequentialCell(addition_loc_layer)) | |||
| addition_cls_layer_list.append(nn.SequentialCell(addition_cls_layer)) | |||
| self.addition_layer_loc = nn.CellList(addition_loc_layer_list) | |||
| self.addition_layer_cls = nn.CellList(addition_cls_layer_list) | |||
| else: | |||
| convs = [ | |||
| _conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers) | |||
| ] | |||
| addition_layer_list = [] | |||
| for _ in range(num_features): | |||
| addition_layers = [ | |||
| ConvBNReLU(out_channels, out_channels, 3, 1, 1, convs[x]) for x in range(num_addition_layers) | |||
| ] | |||
| addition_layer_list.append(nn.SequentialCell(addition_layers)) | |||
| self.addition_layer = nn.SequentialCell(addition_layer_list) | |||
| loc_layers = [_conv2d(out_channels, 4 * num_default, | |||
| kernel_size=3, stride=1, pad_mod='same')] | |||
| cls_layers = [_conv2d(out_channels, num_classes * num_default, | |||
| kernel_size=3, stride=1, pad_mod='same')] | |||
| self.loc_layers = nn.SequentialCell(loc_layers) | |||
| self.cls_layers = nn.SequentialCell(cls_layers) | |||
| self.flatten_concat = FlattenConcat(config) | |||
| def construct(self, inputs): | |||
| loc_outputs = () | |||
| cls_outputs = () | |||
| num_heads = len(inputs) | |||
| for i in range(num_heads): | |||
| if self.loc_cls_shared_addition: | |||
| features = self.addition_layer[i](inputs[i]) | |||
| loc_outputs += (self.loc_layers(features),) | |||
| cls_outputs += (self.cls_layers(features),) | |||
| else: | |||
| features = self.addition_layer_loc[i](inputs[i]) | |||
| loc_outputs += (self.loc_layers(features),) | |||
| features = self.addition_layer_cls[i](inputs[i]) | |||
| cls_outputs += (self.cls_layers(features),) | |||
| return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) | |||
| class SSD300(nn.Cell): | |||
| """ | |||
| SSD300 Network. Default backbone is resnet34. | |||
| @@ -255,6 +341,40 @@ class SSD300(nn.Cell): | |||
| return pred_loc, pred_label | |||
| class SsdMobilenetV1Fpn(nn.Cell): | |||
| """ | |||
| SSD Network using mobilenetV1 with fpn to extract features | |||
| Args: | |||
| config (dict): The default config of SSD. | |||
| is_training (bool): Used for training, default is True. | |||
| Returns: | |||
| Tensor, localization predictions. | |||
| Tensor, class conf scores. | |||
| Examples:backbone | |||
| SsdMobilenetV1Fpn(config, True). | |||
| """ | |||
| def __init__(self, config, is_training=True): | |||
| super(SsdMobilenetV1Fpn, self).__init__() | |||
| self.multi_box = WeightSharedMultiBox(config) | |||
| self.is_training = is_training | |||
| if not is_training: | |||
| self.activation = P.Sigmoid() | |||
| self.feature_extractor = mobilenet_v1_fpn(config) | |||
| def construct(self, x): | |||
| features = self.feature_extractor(x) | |||
| pred_loc, pred_label = self.multi_box(features) | |||
| if not self.is_training: | |||
| pred_label = self.activation(pred_label) | |||
| pred_loc = F.cast(pred_loc, mstype.float32) | |||
| pred_label = F.cast(pred_label, mstype.float32) | |||
| return pred_loc, pred_label | |||
| class SigmoidFocalClassificationLoss(nn.Cell): | |||
| """" | |||
| Sigmoid focal-loss for classification. | |||
| @@ -328,6 +448,12 @@ class SSDWithLossCell(nn.Cell): | |||
| return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) | |||
| grad_scale = C.MultitypeFuncGraph("grad_scale") | |||
| @grad_scale.register("Tensor", "Tensor") | |||
| def tensor_grad_scale(scale, grad): | |||
| return grad * P.Reciprocal()(scale) | |||
| class TrainingWrapper(nn.Cell): | |||
| """ | |||
| Encapsulation class of SSD network training. | |||
| @@ -339,8 +465,9 @@ class TrainingWrapper(nn.Cell): | |||
| network (Cell): The training network. Note that loss function should have been added. | |||
| optimizer (Optimizer): Optimizer for updating the weights. | |||
| sens (Number): The adjust parameter. Default: 1.0. | |||
| use_global_nrom(bool): Whether apply global norm before optimizer. Default: False | |||
| """ | |||
| def __init__(self, network, optimizer, sens=1.0): | |||
| def __init__(self, network, optimizer, sens=1.0, use_global_norm=False): | |||
| super(TrainingWrapper, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.network.set_grad() | |||
| @@ -350,6 +477,7 @@ class TrainingWrapper(nn.Cell): | |||
| self.sens = sens | |||
| self.reducer_flag = False | |||
| self.grad_reducer = None | |||
| self.use_global_norm = use_global_norm | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| @@ -360,6 +488,7 @@ class TrainingWrapper(nn.Cell): | |||
| else: | |||
| degree = get_group_size() | |||
| self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| self.hyper_map = C.HyperMap() | |||
| def construct(self, *args): | |||
| weights = self.weights | |||
| @@ -369,6 +498,9 @@ class TrainingWrapper(nn.Cell): | |||
| if self.reducer_flag: | |||
| # apply grad reducer on grads | |||
| grads = self.grad_reducer(grads) | |||
| if self.use_global_norm: | |||
| grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_array(self.sens)), grads) | |||
| grads = C.clip_by_global_norm(grads) | |||
| return F.depend(loss, self.optimizer(grads)) | |||
| @@ -439,5 +571,10 @@ class SSDWithMobileNetV2(nn.Cell): | |||
| def get_out_channels(self): | |||
| return self.last_channel | |||
| def ssd_mobilenet_v1_fpn(**kwargs): | |||
| return SsdMobilenetV1Fpn(**kwargs) | |||
| def ssd_mobilenet_v2(**kwargs): | |||
| return SSDWithMobileNetV2(**kwargs) | |||
| @@ -25,7 +25,7 @@ from mindspore.train import Model | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.common import set_seed, dtype | |||
| from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 | |||
| from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn | |||
| from src.config import config | |||
| from src.dataset import create_ssd_dataset, create_mindrecord | |||
| from src.lr_schedule import get_lr | |||
| @@ -74,63 +74,85 @@ def main(): | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||
| device_num=device_num) | |||
| init() | |||
| context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89]) | |||
| rank = get_rank() | |||
| mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True) | |||
| if not args_opt.only_create_dataset: | |||
| loss_scale = float(args_opt.loss_scale) | |||
| if args_opt.run_platform == "CPU": | |||
| loss_scale = 1.0 | |||
| # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. | |||
| use_multiprocessing = (args_opt.run_platform != "CPU") | |||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size, | |||
| device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) | |||
| if args_opt.only_create_dataset: | |||
| return | |||
| dataset_size = dataset.get_dataset_size() | |||
| print("Create dataset done!") | |||
| loss_scale = float(args_opt.loss_scale) | |||
| if args_opt.run_platform == "CPU": | |||
| loss_scale = 1.0 | |||
| backbone = ssd_mobilenet_v2() | |||
| ssd = SSD300(backbone=backbone, config=config) | |||
| if args_opt.run_platform == "GPU": | |||
| ssd.to_float(dtype.float16) | |||
| net = SSDWithLossCell(ssd, config) | |||
| init_net_param(net) | |||
| # checkpoint | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) | |||
| save_ckpt_path = './ckpt_' + str(rank) + '/' | |||
| ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config) | |||
| if args_opt.pre_trained: | |||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||
| if args_opt.filter_weight: | |||
| filter_checkpoint_parameter(param_dict) | |||
| load_param_into_net(net, param_dict) | |||
| if args_opt.freeze_layer == "backbone": | |||
| for param in backbone.feature_1.trainable_params(): | |||
| param.requires_grad = False | |||
| lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, | |||
| lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, | |||
| warmup_epochs=config.warmup_epochs, | |||
| total_epochs=args_opt.epoch_size, | |||
| steps_per_epoch=dataset_size)) | |||
| # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. | |||
| use_multiprocessing = (args_opt.run_platform != "CPU") | |||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size, | |||
| device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print("Create dataset done!") | |||
| backbone = ssd_mobilenet_v2() | |||
| if config.model == "ssd300": | |||
| ssd = SSD300(backbone=backbone, config=config) | |||
| elif config.model == "ssd_mobilenet_v1_fpn": | |||
| ssd = ssd_mobilenet_v1_fpn(config=config) | |||
| else: | |||
| raise ValueError(f'config.model: {config.model} is not supported') | |||
| if args_opt.run_platform == "GPU": | |||
| ssd.to_float(dtype.float16) | |||
| net = SSDWithLossCell(ssd, config) | |||
| init_net_param(net) | |||
| if config.feature_extractor_base_param != "": | |||
| param_dict = load_checkpoint(config.feature_extractor_base_param) | |||
| for x in list(param_dict.keys()): | |||
| param_dict["network.feature_extractor.mobilenet_v1." + x] = param_dict[x] | |||
| del param_dict[x] | |||
| load_param_into_net(ssd.feature_extractor.mobilenet_v1.network, param_dict) | |||
| # checkpoint | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) | |||
| save_ckpt_path = './ckpt_' + str(rank) + '/' | |||
| ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config) | |||
| if args_opt.pre_trained: | |||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||
| if args_opt.filter_weight: | |||
| filter_checkpoint_parameter(param_dict) | |||
| load_param_into_net(net, param_dict) | |||
| if args_opt.freeze_layer == "backbone": | |||
| for param in backbone.feature_1.trainable_params(): | |||
| param.requires_grad = False | |||
| lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, | |||
| lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, | |||
| warmup_epochs=config.warmup_epochs, | |||
| total_epochs=args_opt.epoch_size, | |||
| steps_per_epoch=dataset_size)) | |||
| if "use_global_norm" in config and config.use_global_norm: | |||
| opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, | |||
| config.momentum, config.weight_decay, 1.0) | |||
| net = TrainingWrapper(net, opt, loss_scale, True) | |||
| else: | |||
| opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, | |||
| config.momentum, config.weight_decay, loss_scale) | |||
| net = TrainingWrapper(net, opt, loss_scale) | |||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||
| model = Model(net) | |||
| dataset_sink_mode = False | |||
| if args_opt.mode == "sink" and args_opt.run_platform != "CPU": | |||
| print("In sink mode, one epoch return a loss.") | |||
| dataset_sink_mode = True | |||
| print("Start train SSD, the first epoch will be slower because of the graph compilation.") | |||
| model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | |||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||
| model = Model(net) | |||
| dataset_sink_mode = False | |||
| if args_opt.mode == "sink" and args_opt.run_platform != "CPU": | |||
| print("In sink mode, one epoch return a loss.") | |||
| dataset_sink_mode = True | |||
| print("Start train SSD, the first epoch will be slower because of the graph compilation.") | |||
| model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | |||
| if __name__ == '__main__': | |||
| main() | |||