| @@ -21,7 +21,7 @@ import time | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context, Tensor | from mindspore import context, Tensor | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from src.ssd import SSD300, ssd_mobilenet_v2 | |||||
| from src.ssd import SSD300, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn | |||||
| from src.dataset import create_ssd_dataset, create_mindrecord | from src.dataset import create_ssd_dataset, create_mindrecord | ||||
| from src.config import config | from src.config import config | ||||
| from src.eval_utils import metrics | from src.eval_utils import metrics | ||||
| @@ -31,7 +31,10 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): | |||||
| batch_size = 1 | batch_size = 1 | ||||
| ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, | ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, | ||||
| is_training=False, use_multiprocessing=False) | is_training=False, use_multiprocessing=False) | ||||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | |||||
| if config.model == "ssd300": | |||||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | |||||
| else: | |||||
| net = ssd_mobilenet_v1_fpn(config=config) | |||||
| print("Load Checkpoint!") | print("Load Checkpoint!") | ||||
| param_dict = load_checkpoint(ckpt_path) | param_dict = load_checkpoint(ckpt_path) | ||||
| net.init_parameters_data() | net.init_parameters_data() | ||||
| @@ -0,0 +1,92 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Anchor Generator""" | |||||
| import numpy as np | |||||
class GridAnchorGenerator:
    """
    Multi-scale grid anchor generator.

    For every cell of a feature grid, ``scales_per_octave * len(aspect_ratios)``
    anchors are produced. All returned coordinates are divided by the image
    size, so they are relative to the input image.

    Args:
        image_shape (sequence): Input image size as (height, width), in pixels.
        scale (float): Base anchor size expressed as a multiple of the stride.
        scales_per_octave (int): Number of scales generated per octave; scale
            ``k`` is ``2 ** (k / scales_per_octave)``.
        aspect_ratios (iterable): Width / height ratios of the anchors.
    """

    def __init__(self, image_shape, scale, scales_per_octave, aspect_ratios):
        super(GridAnchorGenerator, self).__init__()
        self.scale = scale
        self.scales_per_octave = scales_per_octave
        self.aspect_ratios = aspect_ratios
        self.image_shape = image_shape

    def generate(self, step):
        """
        Generate the anchors of one feature level.

        Args:
            step (int): Stride of the feature level, in input pixels.

        Returns:
            tuple: ``(bbox_centers, bbox_corners)``, both float32 arrays of
            shape (num_anchors, 4). ``bbox_centers`` rows are
            (cy, cx, h, w); ``bbox_corners`` rows are (y1, x1, y2, x2).
            The same arrays are also stored on ``self``.
        """
        scales = np.array([2 ** (float(octave) / self.scales_per_octave)
                           for octave in range(self.scales_per_octave)]).astype(np.float32)
        aspects = np.array(list(self.aspect_ratios)).astype(np.float32)
        # Every (scale, aspect ratio) combination, flattened to one entry per anchor.
        scales_grid, aspect_ratios_grid = np.meshgrid(scales, aspects)
        scales_grid = scales_grid.reshape([-1])
        aspect_ratios_grid = aspect_ratios_grid.reshape([-1])

        # Height and width are handled separately so non-square images get the
        # right number of grid columns (the width used to reuse image_shape[0]).
        grid_height = self.image_shape[0] / step
        grid_width = self.image_shape[1] / step

        base_size = np.array([self.scale * step, self.scale * step]).astype(np.float32)
        anchor_offset = step / 2.0

        ratio_sqrt = np.sqrt(aspect_ratios_grid)
        heights = scales_grid / ratio_sqrt * base_size[0]
        widths = scales_grid * ratio_sqrt * base_size[1]

        # Anchor centers sit in the middle of each grid cell.
        y_centers = np.arange(grid_height).astype(np.float32) * step + anchor_offset
        x_centers = np.arange(grid_width).astype(np.float32) * step + anchor_offset
        x_centers, y_centers = np.meshgrid(x_centers, y_centers)

        # Pair every grid cell with every anchor size: arrays of shape (H * W, A).
        widths_grid, x_centers_grid = np.meshgrid(widths, x_centers.reshape([-1]))
        heights_grid, y_centers_grid = np.meshgrid(heights, y_centers.reshape([-1]))

        # Flattened row order is (row, column, anchor) with the anchor index
        # varying fastest, for both the centers and the sizes.
        bbox_centers = np.stack([y_centers_grid, x_centers_grid], axis=-1).reshape([-1, 2])
        bbox_sizes = np.stack([heights_grid, widths_grid], axis=-1).reshape([-1, 2])

        # (H, W, H, W) divisor converts pixel values to image-relative values.
        normalizer = np.array([*self.image_shape, *self.image_shape]).astype(np.float32)
        bbox_corners = np.concatenate([bbox_centers - 0.5 * bbox_sizes,
                                       bbox_centers + 0.5 * bbox_sizes], axis=1)
        self.bbox_corners = bbox_corners / normalizer
        self.bbox_centers = np.concatenate([bbox_centers, bbox_sizes], axis=1) / normalizer
        return self.bbox_centers, self.bbox_corners

    def generate_multi_levels(self, steps):
        """
        Generate and concatenate the anchors of several feature levels.

        Args:
            steps (iterable): Strides of the feature levels, in level order.

        Returns:
            tuple: ``(bbox_centers, bbox_corners)`` with the per-level results
            of ``generate`` stacked along axis 0. Also stored on ``self``.
        """
        bbox_centers_list = []
        bbox_corners_list = []
        for step in steps:
            bbox_centers, bbox_corners = self.generate(step)
            bbox_centers_list.append(bbox_centers)
            bbox_corners_list.append(bbox_corners)

        self.bbox_centers = np.concatenate(bbox_centers_list, axis=0)
        self.bbox_corners = np.concatenate(bbox_corners_list, axis=0)
        return self.bbox_centers, self.bbox_corners
| @@ -19,6 +19,7 @@ import math | |||||
| import itertools as it | import itertools as it | ||||
| import numpy as np | import numpy as np | ||||
| from .config import config | from .config import config | ||||
| from .anchor_generator import GridAnchorGenerator | |||||
| class GeneratDefaultBoxes(): | class GeneratDefaultBoxes(): | ||||
| @@ -36,7 +37,7 @@ class GeneratDefaultBoxes(): | |||||
| sk1 = scales[idex] | sk1 = scales[idex] | ||||
| sk2 = scales[idex + 1] | sk2 = scales[idex + 1] | ||||
| sk3 = math.sqrt(sk1 * sk2) | sk3 = math.sqrt(sk1 * sk2) | ||||
| if idex == 0: | |||||
| if idex == 0 and not config.aspect_ratios[idex]: | |||||
| w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2) | w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2) | ||||
| all_sizes = [(0.1, 0.1), (w, h), (h, w)] | all_sizes = [(0.1, 0.1), (w, h), (h, w)] | ||||
| else: | else: | ||||
| @@ -61,9 +62,12 @@ class GeneratDefaultBoxes(): | |||||
| self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32') | self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32') | ||||
| self.default_boxes = np.array(self.default_boxes, dtype='float32') | self.default_boxes = np.array(self.default_boxes, dtype='float32') | ||||
| default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr | |||||
| default_boxes = GeneratDefaultBoxes().default_boxes | |||||
| if 'use_anchor_generator' in config and config.use_anchor_generator: | |||||
| generator = GridAnchorGenerator(config.img_shape, 4, 2, [1.0, 2.0, 0.5]) | |||||
| default_boxes, default_boxes_tlbr = generator.generate_multi_levels(config.steps) | |||||
| else: | |||||
| default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr | |||||
| default_boxes = GeneratDefaultBoxes().default_boxes | |||||
| y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1) | y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1) | ||||
| vol_anchors = (x2 - x1) * (y2 - y1) | vol_anchors = (x2 - x1) * (y2 - y1) | ||||
| matching_threshold = config.match_threshold | matching_threshold = config.match_threshold | ||||
| @@ -15,68 +15,15 @@ | |||||
| """Config parameters for SSD models.""" | """Config parameters for SSD models.""" | ||||
| from easydict import EasyDict as ed | |||||
| from .config_ssd300 import config as config_ssd300 | |||||
| from .config_ssd_mobilenet_v1_fpn import config as config_ssd_mobilenet_v1_fpn | |||||
| config = ed({ | |||||
| "img_shape": [300, 300], | |||||
| "num_ssd_boxes": 1917, | |||||
| "neg_pre_positive": 3, | |||||
| "match_threshold": 0.5, | |||||
| "nms_threshold": 0.6, | |||||
| "min_score": 0.1, | |||||
| "max_boxes": 100, | |||||
| # learing rate settings | |||||
| "lr_init": 0.001, | |||||
| "lr_end_rate": 0.001, | |||||
| "warmup_epochs": 2, | |||||
| "momentum": 0.9, | |||||
| "weight_decay": 1.5e-4, | |||||
| using_model = "ssd300" | |||||
| # network | |||||
| "num_default": [3, 6, 6, 6, 6, 6], | |||||
| "extras_in_channels": [256, 576, 1280, 512, 256, 256], | |||||
| "extras_out_channels": [576, 1280, 512, 256, 256, 128], | |||||
| "extras_strides": [1, 1, 2, 2, 2, 2], | |||||
| "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25], | |||||
| "feature_size": [19, 10, 5, 3, 2, 1], | |||||
| "min_scale": 0.2, | |||||
| "max_scale": 0.95, | |||||
| "aspect_ratios": [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)], | |||||
| "steps": (16, 32, 64, 100, 150, 300), | |||||
| "prior_scaling": (0.1, 0.2), | |||||
| "gamma": 2.0, | |||||
| "alpha": 0.75, | |||||
| config_map = { | |||||
| "ssd300": config_ssd300, | |||||
| "ssd_mobilenet_v1_fpn": config_ssd_mobilenet_v1_fpn | |||||
| } | |||||
| # `mindrecord_dir` and `coco_root` are better to use absolute path. | |||||
| "mindrecord_dir": "/data/MindRecord_COCO", | |||||
| "coco_root": "/data/coco2017", | |||||
| "train_data_type": "train2017", | |||||
| "val_data_type": "val2017", | |||||
| "instances_set": "annotations/instances_{}.json", | |||||
| "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||||
| 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||||
| 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||||
| 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||||
| 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||||
| 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||||
| 'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||||
| 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||||
| 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||||
| 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||||
| 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||||
| 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||||
| 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', | |||||
| 'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||||
| 'teddy bear', 'hair drier', 'toothbrush'), | |||||
| "num_classes": 81, | |||||
| # The annotation.json position of voc validation dataset. | |||||
| "voc_json": "annotations/voc_instances_val.json", | |||||
| # voc original dataset. | |||||
| "voc_root": "/data/voc_dataset", | |||||
| # if coco or voc used, `image_dir` and `anno_path` are useless. | |||||
| "image_dir": "", | |||||
| "anno_path": "", | |||||
| "export_format": "MINDIR", | |||||
| "export_file": "ssd.mindir" | |||||
| }) | |||||
| config = config_map[using_model] | |||||
| @@ -0,0 +1,84 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Config parameters for SSD models.""" | |||||
| from easydict import EasyDict as ed | |||||
# Settings for the SSD300 model (MobileNetV2 backbone).
config = ed({
    "model": "ssd300",
    "img_shape": [300, 300],
    # Equals sum(feature_size[i] ** 2 * num_default[i]) over the six levels.
    "num_ssd_boxes": 1917,
    "neg_pre_positive": 3,
    "match_threshold": 0.5,
    "nms_threshold": 0.6,
    "min_score": 0.1,
    "max_boxes": 100,

    # learning rate settings
    "lr_init": 0.001,
    "lr_end_rate": 0.001,
    "warmup_epochs": 2,
    "momentum": 0.9,
    "weight_decay": 1.5e-4,

    # network
    "num_default": [3, 6, 6, 6, 6, 6],
    "extras_in_channels": [256, 576, 1280, 512, 256, 256],
    "extras_out_channels": [576, 1280, 512, 256, 256, 128],
    "extras_strides": [1, 1, 2, 2, 2, 2],
    "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
    "feature_size": [19, 10, 5, 3, 2, 1],
    "min_scale": 0.2,
    "max_scale": 0.95,
    # The first level deliberately has no extra aspect ratios.
    "aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
    "steps": (16, 32, 64, 100, 150, 300),
    "prior_scaling": (0.1, 0.2),
    "gamma": 2.0,
    "alpha": 0.75,

    # NOTE(review): presumably a pretrained backbone checkpoint path; empty here - confirm.
    "feature_extractor_base_param": "",

    # dataset locations; `mindrecord_dir` and `coco_root` should preferably be absolute paths.
    "mindrecord_dir": "/data/MindRecord_COCO",
    "coco_root": "/data/coco2017",
    "train_data_type": "train2017",
    "val_data_type": "val2017",
    "instances_set": "annotations/instances_{}.json",
    "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush'),
    "num_classes": 81,

    # The annotation.json position of voc validation dataset.
    "voc_json": "annotations/voc_instances_val.json",
    # voc original dataset.
    "voc_root": "/data/voc_dataset",
    # if coco or voc used, `image_dir` and `anno_path` are useless.
    "image_dir": "",
    "anno_path": "",

    # export settings
    "export_format": "MINDIR",
    "export_file": "ssd.mindir"
})
| @@ -0,0 +1,88 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Config parameters for SSD models.""" | |||||
| from easydict import EasyDict as ed | |||||
# Settings for the SSD model with a MobileNetV1 + FPN backbone.
config = ed({
    "model": "ssd_mobilenet_v1_fpn",
    "img_shape": [640, 640],
    # Equals sum(feature_size[i] ** 2 * num_default[i]) over the five levels.
    "num_ssd_boxes": 51150,
    "neg_pre_positive": 3,
    "match_threshold": 0.5,
    "nms_threshold": 0.6,
    "min_score": 0.1,
    "max_boxes": 100,

    # learning rate settings
    "global_step": 0,
    "lr_init": 0.01333,
    "lr_end_rate": 0.0,
    "warmup_epochs": 2,
    "momentum": 0.9,
    "weight_decay": 1.5e-4,

    # network
    "num_default": [6, 6, 6, 6, 6],
    "extras_in_channels": [256, 512, 1024, 256, 256],
    "extras_out_channels": [256, 256, 256, 256, 256],
    # NOTE(review): the next two lists and `aspect_ratios` hold six entries
    # while this model has five feature levels - confirm whether they are used.
    "extras_strides": [1, 1, 2, 2, 2, 2],
    "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
    "feature_size": [80, 40, 20, 10, 5],
    "min_scale": 0.2,
    "max_scale": 0.95,
    "aspect_ratios": [(2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
    "steps": (8, 16, 32, 64, 128),
    "prior_scaling": (0.1, 0.2),
    "gamma": 2.0,
    "alpha": 0.75,
    "num_addition_layers": 4,
    "use_anchor_generator": True,
    "use_global_norm": True,

    # NOTE(review): presumably the pretrained MobileNetV1 checkpoint - confirm.
    "feature_extractor_base_param": "/ckpt/mobilenet_v1.ckpt",

    # dataset locations; `mindrecord_dir` and `coco_root` should preferably be absolute paths.
    "mindrecord_dir": "/data/MindRecord_COCO",
    "coco_root": "/data/coco2017",
    "train_data_type": "train2017",
    "val_data_type": "val2017",
    "instances_set": "annotations/instances_{}.json",
    "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush'),
    "num_classes": 81,

    # The annotation.json position of voc validation dataset.
    "voc_json": "annotations/voc_instances_val.json",
    # voc original dataset.
    "voc_root": "/data/voc_dataset",
    # if coco or voc used, `image_dir` and `anno_path` are useless.
    "image_dir": "",
    "anno_path": "",

    # export settings
    "export_format": "MINDIR",
    "export_file": "ssd.mindir"
})
| @@ -22,14 +22,14 @@ def init_net_param(network, initialize_mode='TruncatedNormal'): | |||||
| for p in params: | for p in params: | ||||
| if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | ||||
| if initialize_mode == 'TruncatedNormal': | if initialize_mode == 'TruncatedNormal': | ||||
| p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype)) | |||||
| p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype)) | |||||
| else: | else: | ||||
| p.set_data(initialize_mode, p.data.shape, p.data.dtype) | p.set_data(initialize_mode, p.data.shape, p.data.dtype) | ||||
| def load_backbone_params(network, param_dict): | def load_backbone_params(network, param_dict): | ||||
| """Init the parameters from pre-train model, default is mobilenetv2.""" | """Init the parameters from pre-train model, default is mobilenetv2.""" | ||||
| for _, param in net.parameters_and_names(): | |||||
| for _, param in network.parameters_and_names(): | |||||
| param_name = param.name.replace('network.backbone.', '') | param_name = param.name.replace('network.backbone.', '') | ||||
| name_split = param_name.split('.') | name_split = param_name.split('.') | ||||
| if 'features_1' in param_name: | if 'features_1' in param_name: | ||||
| @@ -0,0 +1,192 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
    """
    Build a Conv2d -> BatchNorm2d -> (optional) activation stack.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride.
        depthwise (bool): If True the convolution uses one group per input channel.
        activation (str): Name passed to ``nn.get_activation``; a falsy value
            omits the activation layer. Default: 'relu6'.

    Returns:
        nn.SequentialCell, the layers in the order listed above.
    """
    group = in_channel if depthwise else 1
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same", group=group),
        nn.BatchNorm2d(out_channel),
    ]
    if activation:
        layers.append(nn.get_activation(activation))
    return nn.SequentialCell(layers)
class MobileNetV1(nn.Cell):
    """
    MobileNet V1 backbone.

    Args:
        class_num (int): Output size of the final dense layer. Only used when
            ``features_only`` is False. Default: 1001.
        features_only (bool): If True, ``construct`` returns the output of
            every layer as a tuple instead of classification logits.
            Default: False.
    """

    def __init__(self, class_num=1001, features_only=False):
        super(MobileNetV1, self).__init__()
        self.features_only = features_only
        # One row per layer: (in_channel, out_channel, kernel_size, stride, depthwise)
        layer_specs = [
            (3, 32, 3, 2, False),        # Conv0
            (32, 32, 3, 1, True),        # Conv1_depthwise
            (32, 64, 1, 1, False),       # Conv1_pointwise
            (64, 64, 3, 2, True),        # Conv2_depthwise
            (64, 128, 1, 1, False),      # Conv2_pointwise
            (128, 128, 3, 1, True),      # Conv3_depthwise
            (128, 128, 1, 1, False),     # Conv3_pointwise
            (128, 128, 3, 2, True),      # Conv4_depthwise
            (128, 256, 1, 1, False),     # Conv4_pointwise
            (256, 256, 3, 1, True),      # Conv5_depthwise
            (256, 256, 1, 1, False),     # Conv5_pointwise
            (256, 256, 3, 2, True),      # Conv6_depthwise
            (256, 512, 1, 1, False),     # Conv6_pointwise
            (512, 512, 3, 1, True),      # Conv7_depthwise
            (512, 512, 1, 1, False),     # Conv7_pointwise
            (512, 512, 3, 1, True),      # Conv8_depthwise
            (512, 512, 1, 1, False),     # Conv8_pointwise
            (512, 512, 3, 1, True),      # Conv9_depthwise
            (512, 512, 1, 1, False),     # Conv9_pointwise
            (512, 512, 3, 1, True),      # Conv10_depthwise
            (512, 512, 1, 1, False),     # Conv10_pointwise
            (512, 512, 3, 1, True),      # Conv11_depthwise
            (512, 512, 1, 1, False),     # Conv11_pointwise
            (512, 512, 3, 2, True),      # Conv12_depthwise
            (512, 1024, 1, 1, False),    # Conv12_pointwise
            (1024, 1024, 3, 1, True),    # Conv13_depthwise
            (1024, 1024, 1, 1, False),   # Conv13_pointwise
        ]
        cnn = [conv_bn_relu(*spec) for spec in layer_specs]

        if features_only:
            # A CellList lets construct() tap the output of every layer.
            self.network = nn.CellList(cnn)
        else:
            self.network = nn.SequentialCell(cnn)
            self.fc = nn.Dense(1024, class_num)

    def construct(self, x):
        if self.features_only:
            collected = ()
            current = x
            for block in self.network:
                current = block(current)
                collected = collected + (current,)
            return collected

        # Classification path: average over axes 2 and 3, then the dense layer.
        pooled = P.ReduceMean()(self.network(x), (2, 3))
        return self.fc(pooled)
class FpnTopDown(nn.Cell):
    """
    Top-down feature pyramid over a tuple of backbone features.

    Args:
        in_channel_list (list): Channel count of each input feature, in input order.
        out_channels (int): Channel count of every output feature.
    """

    def __init__(self, in_channel_list, out_channels):
        super(FpnTopDown, self).__init__()
        self.lateral_convs_list_ = []
        self.fpn_convs_ = []
        # Per input level: a 1x1 lateral conv, then a 3x3 conv block applied after merging.
        for in_channels in in_channel_list:
            lateral = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                                has_bias=True, padding=0, pad_mode='same')
            smooth = conv_bn_relu(out_channels, out_channels, kernel_size=3, stride=1, depthwise=False)
            self.lateral_convs_list_.append(lateral)
            self.fpn_convs_.append(smooth)
        self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
        self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
        self.num_layers = len(in_channel_list)

    def construct(self, inputs):
        # Project every input with its lateral convolution.
        laterals = ()
        for i, feature in enumerate(inputs):
            laterals = laterals + (self.lateral_convs_list[i](feature),)

        # Walk from the last input towards the first: resize the previous
        # merged map to the next level's spatial size and add its lateral.
        merged = (laterals[-1],)
        num_inputs = len(inputs)
        for i in range(num_inputs - 1):
            down = num_inputs - i - 2
            size = F.shape(inputs[down])
            resized = P.ResizeBilinear((size[2], size[3]))(merged[-1])
            merged = merged + (resized + laterals[down],)

        # `merged` is ordered last-input-first; reverse it while applying the
        # per-level conv blocks so outputs follow the input order.
        outputs = ()
        num_merged = len(merged)
        for i in range(num_merged):
            outputs = outputs + (self.fpn_convs_list[i](merged[num_merged - i - 1]),)
        return outputs
class BottomUp(nn.Cell):
    """
    Extend a feature tuple with extra levels computed from its last element.

    Args:
        levels (int): Number of additional feature levels to append.
        channels (int): Input and output channel count of every added block.
        kernel_size (int): Kernel size of every added block.
        stride (int): Stride of every added block.
    """

    def __init__(self, levels, channels, kernel_size, stride):
        super(BottomUp, self).__init__()
        self.levels = levels
        extra_blocks = []
        for _ in range(levels):
            extra_blocks.append(conv_bn_relu(channels, channels, kernel_size, stride, False))
        self.blocks = nn.CellList(extra_blocks)

    def construct(self, features):
        # Each block consumes the most recently produced feature.
        for block in self.blocks:
            newest = block(features[-1])
            features = features + (newest,)
        return features
class FeatureSelector(nn.Cell):
    """
    Pick a fixed subset of entries out of a feature sequence.

    Args:
        feature_idxes (list): Indexes of the entries to keep, in output order.
    """

    def __init__(self, feature_idxes):
        super(FeatureSelector, self).__init__()
        self.feature_idxes = feature_idxes

    def construct(self, feature_list):
        picked = ()
        for idx in self.feature_idxes:
            picked = picked + (feature_list[idx],)
        return picked
class MobileNetV1Fpn(nn.Cell):
    """
    MobileNetV1 with FPN as SSD backbone.

    Pipeline: all MobileNetV1 layer outputs -> three selected layers ->
    top-down FPN -> two extra stride-2 levels appended by ``BottomUp``.

    Args:
        config: Model config. Accepted for interface compatibility; this
            class does not read it.
    """

    def __init__(self, config):
        super(MobileNetV1Fpn, self).__init__()
        # Outputs of layers 10, 22 and 26 of the MobileNetV1 layer list feed
        # the FPN; their channel counts are 256, 512 and 1024 respectively.
        selected_layers = [10, 22, 26]
        self.mobilenet_v1 = MobileNetV1(features_only=True)
        self.selector = FeatureSelector(selected_layers)
        self.layer_indexs = list(selected_layers)
        self.fpn = FpnTopDown([256, 512, 1024], 256)
        self.bottom_up = BottomUp(2, 256, 3, 2)

    def construct(self, x):
        all_features = self.mobilenet_v1(x)
        chosen = self.selector(all_features)
        pyramid = self.fpn(chosen)
        return self.bottom_up(pyramid)
def mobilenet_v1_fpn(config):
    """Create the MobileNetV1 + FPN backbone, forwarding ``config`` to ``MobileNetV1Fpn``."""
    backbone = MobileNetV1Fpn(config)
    return backbone
def mobilenet_v1(class_num=1001):
    """Create a MobileNetV1 classifier whose final dense layer has ``class_num`` outputs."""
    return MobileNetV1(class_num=class_num)
| @@ -26,6 +26,8 @@ from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from .mobilenet_v1_fpn import mobilenet_v1_fpn | |||||
| def _make_divisible(v, divisor, min_value=None): | def _make_divisible(v, divisor, min_value=None): | ||||
| """nsures that all layers have a channel number that is divisible by 8.""" | """nsures that all layers have a channel number that is divisible by 8.""" | ||||
| @@ -67,6 +69,7 @@ class ConvBNReLU(nn.Cell): | |||||
| kernel_size (int): Input kernel size. | kernel_size (int): Input kernel size. | ||||
| stride (int): Stride size for the first convolutional layer. Default: 1. | stride (int): Stride size for the first convolutional layer. Default: 1. | ||||
| groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. | groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. | ||||
| shared_conv(Cell): Use the weight shared conv, default: None. | |||||
| Returns: | Returns: | ||||
| Tensor, output tensor. | Tensor, output tensor. | ||||
| @@ -74,18 +77,21 @@ class ConvBNReLU(nn.Cell): | |||||
| Examples: | Examples: | ||||
| >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) | >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) | ||||
| """ | """ | ||||
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | |||||
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, shared_conv=None): | |||||
| super(ConvBNReLU, self).__init__() | super(ConvBNReLU, self).__init__() | ||||
| padding = 0 | padding = 0 | ||||
| in_channels = in_planes | in_channels = in_planes | ||||
| out_channels = out_planes | out_channels = out_planes | ||||
| if groups == 1: | |||||
| conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding) | |||||
| if shared_conv is None: | |||||
| if groups == 1: | |||||
| conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding) | |||||
| else: | |||||
| out_channels = in_planes | |||||
| conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', | |||||
| padding=padding, group=in_channels) | |||||
| layers = [conv, _bn(out_planes), nn.ReLU6()] | |||||
| else: | else: | ||||
| out_channels = in_planes | |||||
| conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', | |||||
| padding=padding, group=in_channels) | |||||
| layers = [conv, _bn(out_planes), nn.ReLU6()] | |||||
| layers = [shared_conv, _bn(out_planes), nn.ReLU6()] | |||||
| self.features = nn.SequentialCell(layers) | self.features = nn.SequentialCell(layers) | ||||
| def construct(self, x): | def construct(self, x): | ||||
| @@ -205,6 +211,86 @@ class MultiBox(nn.Cell): | |||||
| return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) | return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) | ||||
class WeightSharedMultiBox(nn.Cell):
    """
    Weight-shared multi-box prediction head.

    Every input feature map is passed through a tower of
    ``config.num_addition_layers`` ConvBNReLU cells and then through a final
    3x3 localization conv and a final 3x3 classification conv. The conv cell
    used at a given tower depth is the same object for every feature map (it
    is handed to ConvBNReLU as its shared conv), and the final ``loc_layers``
    / ``cls_layers`` cells are likewise reused for all feature maps.

    Args:
        config (dict): The default config of SSD. This class reads
            ``num_classes``, ``extras_out_channels[0]``, ``num_default[0]``,
            ``feature_size`` and ``num_addition_layers``.
        loc_cls_shared_addition (bool): Whether the localization predictor and
            the classification predictor share the same addition tower.
            Default: False.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.
    """
    def __init__(self, config, loc_cls_shared_addition=False):
        super(WeightSharedMultiBox, self).__init__()
        num_classes = config.num_classes
        # Only the first entry is read: presumably every feature level uses the
        # same channel count and number of default boxes -- confirm in config.
        out_channels = config.extras_out_channels[0]
        num_default = config.num_default[0]
        num_features = len(config.feature_size)
        num_addition_layers = config.num_addition_layers
        self.loc_cls_shared_addition = loc_cls_shared_addition

        if not loc_cls_shared_addition:
            # One conv per tower depth; each is reused by every feature map.
            loc_convs = [
                _conv2d(out_channels, out_channels, 3, 1) for _ in range(num_addition_layers)
            ]
            cls_convs = [
                _conv2d(out_channels, out_channels, 3, 1) for _ in range(num_addition_layers)
            ]
            addition_loc_layer_list = []
            addition_cls_layer_list = []
            for _ in range(num_features):
                addition_loc_layer = [
                    ConvBNReLU(out_channels, out_channels, 3, 1, 1, loc_convs[x])
                    for x in range(num_addition_layers)
                ]
                addition_cls_layer = [
                    ConvBNReLU(out_channels, out_channels, 3, 1, 1, cls_convs[x])
                    for x in range(num_addition_layers)
                ]
                addition_loc_layer_list.append(nn.SequentialCell(addition_loc_layer))
                addition_cls_layer_list.append(nn.SequentialCell(addition_cls_layer))
            self.addition_layer_loc = nn.CellList(addition_loc_layer_list)
            self.addition_layer_cls = nn.CellList(addition_cls_layer_list)
        else:
            convs = [
                _conv2d(out_channels, out_channels, 3, 1) for _ in range(num_addition_layers)
            ]
            addition_layer_list = []
            for _ in range(num_features):
                addition_layers = [
                    ConvBNReLU(out_channels, out_channels, 3, 1, 1, convs[x])
                    for x in range(num_addition_layers)
                ]
                addition_layer_list.append(nn.SequentialCell(addition_layers))
            # CellList, matching the non-shared branch: the per-feature towers
            # are selected by index in construct and are never chained, so a
            # SequentialCell container would misrepresent how they are used.
            self.addition_layer = nn.CellList(addition_layer_list)

        # Final predictors, applied to the tower output of every feature map.
        loc_layers = [_conv2d(out_channels, 4 * num_default,
                              kernel_size=3, stride=1, pad_mod='same')]
        cls_layers = [_conv2d(out_channels, num_classes * num_default,
                              kernel_size=3, stride=1, pad_mod='same')]
        self.loc_layers = nn.SequentialCell(loc_layers)
        self.cls_layers = nn.SequentialCell(cls_layers)
        self.flatten_concat = FlattenConcat(config)

    def construct(self, inputs):
        loc_outputs = ()
        cls_outputs = ()
        num_heads = len(inputs)
        for i in range(num_heads):
            if self.loc_cls_shared_addition:
                # A single tower output feeds both predictors.
                features = self.addition_layer[i](inputs[i])
                loc_outputs += (self.loc_layers(features),)
                cls_outputs += (self.cls_layers(features),)
            else:
                # Separate towers for localization and classification.
                features = self.addition_layer_loc[i](inputs[i])
                loc_outputs += (self.loc_layers(features),)
                features = self.addition_layer_cls[i](inputs[i])
                cls_outputs += (self.cls_layers(features),)
        return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)
| class SSD300(nn.Cell): | class SSD300(nn.Cell): | ||||
| """ | """ | ||||
| SSD300 Network. Default backbone is resnet34. | SSD300 Network. Default backbone is resnet34. | ||||
| @@ -255,6 +341,40 @@ class SSD300(nn.Cell): | |||||
| return pred_loc, pred_label | return pred_loc, pred_label | ||||
class SsdMobilenetV1Fpn(nn.Cell):
    """
    SSD network that extracts features with MobileNetV1 plus FPN.

    The feature extractor output is fed to a WeightSharedMultiBox head. When
    the network is not in training mode a sigmoid is applied to the class
    predictions. Both outputs are cast to float32 before being returned.

    Args:
        config (dict): The default config of SSD.
        is_training (bool): Used for training, default is True.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.

    Examples:
        SsdMobilenetV1Fpn(config, True)
    """
    def __init__(self, config, is_training=True):
        super().__init__()
        # Sub-cells are created in the same order as before so that parameter
        # ordering inside the cell is unchanged.
        self.multi_box = WeightSharedMultiBox(config)
        self.is_training = is_training
        if not is_training:
            # The activation is only needed (and only defined) for inference.
            self.activation = P.Sigmoid()
        self.feature_extractor = mobilenet_v1_fpn(config)

    def construct(self, x):
        feature_maps = self.feature_extractor(x)
        loc, label = self.multi_box(feature_maps)
        if not self.is_training:
            label = self.activation(label)
        return F.cast(loc, mstype.float32), F.cast(label, mstype.float32)
| class SigmoidFocalClassificationLoss(nn.Cell): | class SigmoidFocalClassificationLoss(nn.Cell): | ||||
| """" | """" | ||||
| Sigmoid focal-loss for classification. | Sigmoid focal-loss for classification. | ||||
| @@ -328,6 +448,12 @@ class SSDWithLossCell(nn.Cell): | |||||
| return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) | return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) | ||||
# Multitype function graph; the ("Tensor", "Tensor") overload is registered below.
grad_scale = C.MultitypeFuncGraph("grad_scale")


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Return ``grad`` multiplied by the reciprocal of ``scale``."""
    reciprocal = P.Reciprocal()
    return grad * reciprocal(scale)
| class TrainingWrapper(nn.Cell): | class TrainingWrapper(nn.Cell): | ||||
| """ | """ | ||||
| Encapsulation class of SSD network training. | Encapsulation class of SSD network training. | ||||
| @@ -339,8 +465,9 @@ class TrainingWrapper(nn.Cell): | |||||
| network (Cell): The training network. Note that loss function should have been added. | network (Cell): The training network. Note that loss function should have been added. | ||||
| optimizer (Optimizer): Optimizer for updating the weights. | optimizer (Optimizer): Optimizer for updating the weights. | ||||
| sens (Number): The adjust parameter. Default: 1.0. | sens (Number): The adjust parameter. Default: 1.0. | ||||
| use_global_norm (bool): Whether to apply global-norm gradient clipping before the optimizer. Default: False. | |||||
| """ | """ | ||||
| def __init__(self, network, optimizer, sens=1.0): | |||||
| def __init__(self, network, optimizer, sens=1.0, use_global_norm=False): | |||||
| super(TrainingWrapper, self).__init__(auto_prefix=False) | super(TrainingWrapper, self).__init__(auto_prefix=False) | ||||
| self.network = network | self.network = network | ||||
| self.network.set_grad() | self.network.set_grad() | ||||
| @@ -350,6 +477,7 @@ class TrainingWrapper(nn.Cell): | |||||
| self.sens = sens | self.sens = sens | ||||
| self.reducer_flag = False | self.reducer_flag = False | ||||
| self.grad_reducer = None | self.grad_reducer = None | ||||
| self.use_global_norm = use_global_norm | |||||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | ||||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | ||||
| self.reducer_flag = True | self.reducer_flag = True | ||||
| @@ -360,6 +488,7 @@ class TrainingWrapper(nn.Cell): | |||||
| else: | else: | ||||
| degree = get_group_size() | degree = get_group_size() | ||||
| self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | ||||
| self.hyper_map = C.HyperMap() | |||||
| def construct(self, *args): | def construct(self, *args): | ||||
| weights = self.weights | weights = self.weights | ||||
| @@ -369,6 +498,9 @@ class TrainingWrapper(nn.Cell): | |||||
| if self.reducer_flag: | if self.reducer_flag: | ||||
| # apply grad reducer on grads | # apply grad reducer on grads | ||||
| grads = self.grad_reducer(grads) | grads = self.grad_reducer(grads) | ||||
| if self.use_global_norm: | |||||
| grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_array(self.sens)), grads) | |||||
| grads = C.clip_by_global_norm(grads) | |||||
| return F.depend(loss, self.optimizer(grads)) | return F.depend(loss, self.optimizer(grads)) | ||||
| @@ -439,5 +571,10 @@ class SSDWithMobileNetV2(nn.Cell): | |||||
def get_out_channels(self):
    """Return the value stored in ``self.last_channel``."""
    return self.last_channel
def ssd_mobilenet_v1_fpn(**kwargs):
    """Build a SsdMobilenetV1Fpn network, forwarding all keyword arguments to its constructor."""
    return SsdMobilenetV1Fpn(**kwargs)
def ssd_mobilenet_v2(**kwargs):
    """Build a SSDWithMobileNetV2 cell, forwarding all keyword arguments to its constructor."""
    return SSDWithMobileNetV2(**kwargs)
| @@ -25,7 +25,7 @@ from mindspore.train import Model | |||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.common import set_seed, dtype | from mindspore.common import set_seed, dtype | ||||
| from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 | |||||
| from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn | |||||
| from src.config import config | from src.config import config | ||||
| from src.dataset import create_ssd_dataset, create_mindrecord | from src.dataset import create_ssd_dataset, create_mindrecord | ||||
| from src.lr_schedule import get_lr | from src.lr_schedule import get_lr | ||||
| @@ -74,63 +74,85 @@ def main(): | |||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | ||||
| device_num=device_num) | device_num=device_num) | ||||
| init() | init() | ||||
| context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89]) | |||||
| rank = get_rank() | rank = get_rank() | ||||
| mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True) | mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True) | ||||
| if not args_opt.only_create_dataset: | |||||
| loss_scale = float(args_opt.loss_scale) | |||||
| if args_opt.run_platform == "CPU": | |||||
| loss_scale = 1.0 | |||||
| # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. | |||||
| use_multiprocessing = (args_opt.run_platform != "CPU") | |||||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size, | |||||
| device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) | |||||
| if args_opt.only_create_dataset: | |||||
| return | |||||
| dataset_size = dataset.get_dataset_size() | |||||
| print("Create dataset done!") | |||||
| loss_scale = float(args_opt.loss_scale) | |||||
| if args_opt.run_platform == "CPU": | |||||
| loss_scale = 1.0 | |||||
| backbone = ssd_mobilenet_v2() | |||||
| ssd = SSD300(backbone=backbone, config=config) | |||||
| if args_opt.run_platform == "GPU": | |||||
| ssd.to_float(dtype.float16) | |||||
| net = SSDWithLossCell(ssd, config) | |||||
| init_net_param(net) | |||||
| # checkpoint | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) | |||||
| save_ckpt_path = './ckpt_' + str(rank) + '/' | |||||
| ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config) | |||||
| if args_opt.pre_trained: | |||||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||||
| if args_opt.filter_weight: | |||||
| filter_checkpoint_parameter(param_dict) | |||||
| load_param_into_net(net, param_dict) | |||||
| if args_opt.freeze_layer == "backbone": | |||||
| for param in backbone.feature_1.trainable_params(): | |||||
| param.requires_grad = False | |||||
| lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, | |||||
| lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, | |||||
| warmup_epochs=config.warmup_epochs, | |||||
| total_epochs=args_opt.epoch_size, | |||||
| steps_per_epoch=dataset_size)) | |||||
| # When creating the MindDataset, use the first mindrecord file, such as ssd.mindrecord0. | |||||
| use_multiprocessing = (args_opt.run_platform != "CPU") | |||||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size, | |||||
| device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) | |||||
| dataset_size = dataset.get_dataset_size() | |||||
| print("Create dataset done!") | |||||
| backbone = ssd_mobilenet_v2() | |||||
| if config.model == "ssd300": | |||||
| ssd = SSD300(backbone=backbone, config=config) | |||||
| elif config.model == "ssd_mobilenet_v1_fpn": | |||||
| ssd = ssd_mobilenet_v1_fpn(config=config) | |||||
| else: | |||||
| raise ValueError(f'config.model: {config.model} is not supported') | |||||
| if args_opt.run_platform == "GPU": | |||||
| ssd.to_float(dtype.float16) | |||||
| net = SSDWithLossCell(ssd, config) | |||||
| init_net_param(net) | |||||
| if config.feature_extractor_base_param != "": | |||||
| param_dict = load_checkpoint(config.feature_extractor_base_param) | |||||
| for x in list(param_dict.keys()): | |||||
| param_dict["network.feature_extractor.mobilenet_v1." + x] = param_dict[x] | |||||
| del param_dict[x] | |||||
| load_param_into_net(ssd.feature_extractor.mobilenet_v1.network, param_dict) | |||||
| # checkpoint | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) | |||||
| save_ckpt_path = './ckpt_' + str(rank) + '/' | |||||
| ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config) | |||||
| if args_opt.pre_trained: | |||||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||||
| if args_opt.filter_weight: | |||||
| filter_checkpoint_parameter(param_dict) | |||||
| load_param_into_net(net, param_dict) | |||||
| if args_opt.freeze_layer == "backbone": | |||||
| for param in backbone.feature_1.trainable_params(): | |||||
| param.requires_grad = False | |||||
| lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, | |||||
| lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, | |||||
| warmup_epochs=config.warmup_epochs, | |||||
| total_epochs=args_opt.epoch_size, | |||||
| steps_per_epoch=dataset_size)) | |||||
| if "use_global_norm" in config and config.use_global_norm: | |||||
| opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, | |||||
| config.momentum, config.weight_decay, 1.0) | |||||
| net = TrainingWrapper(net, opt, loss_scale, True) | |||||
| else: | |||||
| opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, | opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, | ||||
| config.momentum, config.weight_decay, loss_scale) | config.momentum, config.weight_decay, loss_scale) | ||||
| net = TrainingWrapper(net, opt, loss_scale) | net = TrainingWrapper(net, opt, loss_scale) | ||||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||||
| model = Model(net) | |||||
| dataset_sink_mode = False | |||||
| if args_opt.mode == "sink" and args_opt.run_platform != "CPU": | |||||
| print("In sink mode, one epoch return a loss.") | |||||
| dataset_sink_mode = True | |||||
| print("Start train SSD, the first epoch will be slower because of the graph compilation.") | |||||
| model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | |||||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||||
| model = Model(net) | |||||
| dataset_sink_mode = False | |||||
| if args_opt.mode == "sink" and args_opt.run_platform != "CPU": | |||||
| print("In sink mode, one epoch return a loss.") | |||||
| dataset_sink_mode = True | |||||
| print("Start train SSD, the first epoch will be slower because of the graph compilation.") | |||||
| model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | |||||
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()