| @@ -0,0 +1,88 @@ | |||
| # SSD Example | |||
| ## Description | |||
| SSD network based on MobileNetV2, with support for training and evaluation. | |||
| ## Requirements | |||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||
| - Dataset | |||
| We use coco2017 as training dataset in this example by default, and you can also use your own datasets. | |||
| 1. If coco dataset is used. **Select dataset to coco when run script.** | |||
| Download coco2017: [train2017](http://images.cocodataset.org/zips/train2017.zip), [val2017](http://images.cocodataset.org/zips/val2017.zip), [test2017](http://images.cocodataset.org/zips/test2017.zip), [annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip). Install pycocotool. | |||
| ``` | |||
| pip install Cython | |||
| pip install pycocotools | |||
| ``` | |||
| And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: | |||
| ``` | |||
| └─coco2017 | |||
| ├── annotations # annotation jsons | |||
| ├── train2017 # train dataset | |||
| └── val2017 # infer dataset | |||
| ``` | |||
| 2. If your own dataset is used. **Select dataset to other when run script.** | |||
| Organize the dataset infomation into a TXT file, each row in the file is as follows: | |||
| ``` | |||
| train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2 | |||
| ``` | |||
| Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`. | |||
| ## Running the example | |||
| ### Training | |||
| To train the model, run `train.py`. If the `MINDRECORD_DIR` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files by `COCO_ROOT`(coco dataset) or `IMAGE_DIR` and `ANNO_PATH`(own dataset). **Note if MINDRECORD_DIR isn't empty, it will use MINDRECORD_DIR instead of raw images.** | |||
| - Stand alone mode | |||
| ``` | |||
| python train.py --dataset coco | |||
| ``` | |||
| You can run ```python train.py -h``` to get more information. | |||
| - Distribute mode | |||
| ``` | |||
| sh run_distribute_train.sh 8 150 coco /data/hccl.json | |||
| ``` | |||
| The input parameters are device numbers, epoch size, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.** | |||
| You will get the loss value of each step as following: | |||
| ``` | |||
| epoch: 1 step: 455, loss is 5.8653416 | |||
| epoch: 2 step: 455, loss is 5.4292373 | |||
| epoch: 3 step: 455, loss is 5.458992 | |||
| ... | |||
| epoch: 148 step: 455, loss is 1.8340507 | |||
| epoch: 149 step: 455, loss is 2.0876894 | |||
| epoch: 150 step: 455, loss is 2.239692 | |||
| ``` | |||
| ### Evaluation | |||
| for evaluation , run `eval.py` with `ckpt_path`. `ckpt_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file. | |||
| ``` | |||
| python eval.py --ckpt_path ssd.ckpt --dataset coco | |||
| ``` | |||
| You can run ```python eval.py -h``` to get more information. | |||
| @@ -0,0 +1,64 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Config parameters for SSD models.""" | |||
| class ConfigSSD: | |||
| """ | |||
| Config parameters for SSD. | |||
| Examples: | |||
| ConfigSSD(). | |||
| """ | |||
| IMG_SHAPE = [300, 300] | |||
| NUM_SSD_BOXES = 1917 | |||
| NEG_PRE_POSITIVE = 3 | |||
| MATCH_THRESHOLD = 0.5 | |||
| NUM_DEFAULT = [3, 6, 6, 6, 6, 6] | |||
| EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256] | |||
| EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128] | |||
| EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2] | |||
| EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25] | |||
| FEATURE_SIZE = [19, 10, 5, 3, 2, 1] | |||
| SCALES = [21, 45, 99, 153, 207, 261, 315] | |||
| ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)] | |||
| STEPS = (16, 32, 64, 100, 150, 300) | |||
| PRIOR_SCALING = (0.1, 0.2) | |||
| # `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path. | |||
| MINDRECORD_DIR = "MindRecord_COCO" | |||
| COCO_ROOT = "coco2017" | |||
| TRAIN_DATA_TYPE = "train2017" | |||
| VAL_DATA_TYPE = "val2017" | |||
| INSTANCES_SET = "annotations/instances_{}.json" | |||
| COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||
| 'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant', | |||
| 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||
| 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||
| 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||
| 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||
| 'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||
| 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||
| 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
| 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||
| 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||
| 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||
| 'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink', | |||
| 'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||
| 'teddy bear', 'hair drier', 'toothbrush') | |||
| NUM_CLASSES = len(COCO_CLASSES) | |||
| @@ -0,0 +1,375 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """SSD dataset""" | |||
| from __future__ import division | |||
| import os | |||
| import math | |||
| import itertools as it | |||
| import numpy as np | |||
| import cv2 | |||
| import mindspore.dataset as de | |||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||
| from mindspore.mindrecord import FileWriter | |||
| from config import ConfigSSD | |||
| config = ConfigSSD() | |||
| class GeneratDefaultBoxes(): | |||
| """ | |||
| Generate Default boxes for SSD, follows the order of (W, H, archor_sizes). | |||
| `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [x, y, w, h]. | |||
| `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [x1, y1, x2, y2]. | |||
| """ | |||
| def __init__(self): | |||
| fk = config.IMG_SHAPE[0] / np.array(config.STEPS) | |||
| self.default_boxes = [] | |||
| for idex, feature_size in enumerate(config.FEATURE_SIZE): | |||
| sk1 = config.SCALES[idex] / config.IMG_SHAPE[0] | |||
| sk2 = config.SCALES[idex + 1] / config.IMG_SHAPE[0] | |||
| sk3 = math.sqrt(sk1 * sk2) | |||
| if config.NUM_DEFAULT[idex] == 3: | |||
| all_sizes = [(0.5, 1.0), (1.0, 1.0), (1.0, 0.5)] | |||
| else: | |||
| all_sizes = [(sk1, sk1), (sk3, sk3)] | |||
| for aspect_ratio in config.ASPECT_RATIOS[idex]: | |||
| w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio) | |||
| all_sizes.append((w, h)) | |||
| all_sizes.append((h, w)) | |||
| assert len(all_sizes) == config.NUM_DEFAULT[idex] | |||
| for i, j in it.product(range(feature_size), repeat=2): | |||
| for w, h in all_sizes: | |||
| cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] | |||
| box = [np.clip(k, 0, 1) for k in (cx, cy, w, h)] | |||
| self.default_boxes.append(box) | |||
| def to_ltrb(cx, cy, w, h): | |||
| return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2 | |||
| # For IoU calculation | |||
| self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32') | |||
| self.default_boxes = np.array(self.default_boxes, dtype='float32') | |||
| default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb | |||
| default_boxes = GeneratDefaultBoxes().default_boxes | |||
| x1, y1, x2, y2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1) | |||
| vol_anchors = (x2 - x1) * (y2 - y1) | |||
| matching_threshold = config.MATCH_THRESHOLD | |||
| def ssd_bboxes_encode(boxes): | |||
| """ | |||
| Labels anchors with ground truth inputs. | |||
| Args: | |||
| boxex: ground truth with shape [N, 5], for each row, it stores [x, y, w, h, cls]. | |||
| Returns: | |||
| gt_loc: location ground truth with shape [num_anchors, 4]. | |||
| gt_label: class ground truth with shape [num_anchors, 1]. | |||
| num_matched_boxes: number of positives in an image. | |||
| """ | |||
| def jaccard_with_anchors(bbox): | |||
| """Compute jaccard score a box and the anchors.""" | |||
| # Intersection bbox and volume. | |||
| xmin = np.maximum(x1, bbox[0]) | |||
| ymin = np.maximum(y1, bbox[1]) | |||
| xmax = np.minimum(x2, bbox[2]) | |||
| ymax = np.minimum(y2, bbox[3]) | |||
| w = np.maximum(xmax - xmin, 0.) | |||
| h = np.maximum(ymax - ymin, 0.) | |||
| # Volumes. | |||
| inter_vol = h * w | |||
| union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol | |||
| jaccard = inter_vol / union_vol | |||
| return np.squeeze(jaccard) | |||
| pre_scores = np.zeros((config.NUM_SSD_BOXES), dtype=np.float32) | |||
| t_boxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32) | |||
| t_label = np.zeros((config.NUM_SSD_BOXES), dtype=np.int64) | |||
| for bbox in boxes: | |||
| label = int(bbox[4]) | |||
| scores = jaccard_with_anchors(bbox) | |||
| mask = (scores > matching_threshold) | |||
| if not np.any(mask): | |||
| mask[np.argmax(scores)] = True | |||
| mask = mask & (scores > pre_scores) | |||
| pre_scores = np.maximum(pre_scores, scores) | |||
| t_label = mask * label + (1 - mask) * t_label | |||
| for i in range(4): | |||
| t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i] | |||
| index = np.nonzero(t_label) | |||
| # Transform to ltrb. | |||
| bboxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32) | |||
| bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2 | |||
| bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]] | |||
| # Encode features. | |||
| bboxes_t = bboxes[index] | |||
| default_boxes_t = default_boxes[index] | |||
| bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.PRIOR_SCALING[0]) | |||
| bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1] | |||
| bboxes[index] = bboxes_t | |||
| num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) | |||
| return bboxes, t_label.astype(np.int32), num_match_num | |||
| def ssd_bboxes_decode(boxes, index, image_shape): | |||
| """Decode predict boxes to [x, y, w, h]""" | |||
| boxes_t = boxes[index] | |||
| default_boxes_t = default_boxes[index] | |||
| boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2] | |||
| boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4] | |||
| bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32) | |||
| bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2 | |||
| bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2 | |||
| return bboxes | |||
| def preprocess_fn(image, box, is_training): | |||
| """Preprocess function for dataset.""" | |||
| def _rand(a=0., b=1.): | |||
| """Generate random.""" | |||
| return np.random.rand() * (b - a) + a | |||
| def _infer_data(image, input_shape, box): | |||
| img_h, img_w, _ = image.shape | |||
| input_h, input_w = input_shape | |||
| scale = min(float(input_w) / float(img_w), float(input_h) / float(img_h)) | |||
| nw = int(img_w * scale) | |||
| nh = int(img_h * scale) | |||
| image = cv2.resize(image, (nw, nh)) | |||
| new_image = np.zeros((input_h, input_w, 3), np.float32) | |||
| dh = (input_h - nh) // 2 | |||
| dw = (input_w - nw) // 2 | |||
| new_image[dh: (nh + dh), dw: (nw + dw), :] = image | |||
| image = new_image | |||
| #When the channels of image is 1 | |||
| if len(image.shape) == 2: | |||
| image = np.expand_dims(image, axis=-1) | |||
| image = np.concatenate([image, image, image], axis=-1) | |||
| box = box.astype(np.float32) | |||
| box[:, [0, 2]] = (box[:, [0, 2]] * scale + dw) / input_w | |||
| box[:, [1, 3]] = (box[:, [1, 3]] * scale + dh) / input_h | |||
| return image, np.array((img_h, img_w), np.float32), box | |||
| def _data_aug(image, box, is_training, image_size=(300, 300)): | |||
| """Data augmentation function.""" | |||
| ih, iw, _ = image.shape | |||
| w, h = image_size | |||
| if not is_training: | |||
| return _infer_data(image, image_size, box) | |||
| # Random settings | |||
| scale_w = _rand(0.75, 1.25) | |||
| scale_h = _rand(0.75, 1.25) | |||
| flip = _rand() < .5 | |||
| nw = iw * scale_w | |||
| nh = ih * scale_h | |||
| scale = min(w / nw, h / nh) | |||
| nw = int(scale * nw) | |||
| nh = int(scale * nh) | |||
| # Resize image | |||
| image = cv2.resize(image, (nw, nh)) | |||
| # place image | |||
| new_image = np.zeros((h, w, 3), dtype=np.float32) | |||
| dw = (w - nw) // 2 | |||
| dh = (h - nh) // 2 | |||
| new_image[dh:dh + nh, dw:dw + nw, :] = image | |||
| image = new_image | |||
| # Flip image or not | |||
| if flip: | |||
| image = cv2.flip(image, 1, dst=None) | |||
| # Convert image to gray or not | |||
| gray = _rand() < .25 | |||
| if gray: | |||
| image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) | |||
| # When the channels of image is 1 | |||
| if len(image.shape) == 2: | |||
| image = np.expand_dims(image, axis=-1) | |||
| image = np.concatenate([image, image, image], axis=-1) | |||
| box = box.astype(np.float32) | |||
| # Transform box with shape[x1, y1, x2, y2]. | |||
| box[:, [0, 2]] = (box[:, [0, 2]] * scale * scale_w + dw) / w | |||
| box[:, [1, 3]] = (box[:, [1, 3]] * scale * scale_h + dh) / h | |||
| if flip: | |||
| box[:, [0, 2]] = 1 - box[:, [2, 0]] | |||
| box, label, num_match_num = ssd_bboxes_encode(box) | |||
| return image, box, label, num_match_num | |||
| return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE) | |||
| def create_coco_label(is_training): | |||
| """Get image path and annotation from COCO.""" | |||
| from pycocotools.coco import COCO | |||
| coco_root = config.COCO_ROOT | |||
| data_type = config.VAL_DATA_TYPE | |||
| if is_training: | |||
| data_type = config.TRAIN_DATA_TYPE | |||
| #Classes need to train or test. | |||
| train_cls = config.COCO_CLASSES | |||
| train_cls_dict = {} | |||
| for i, cls in enumerate(train_cls): | |||
| train_cls_dict[cls] = i | |||
| anno_json = os.path.join(coco_root, config.INSTANCES_SET.format(data_type)) | |||
| coco = COCO(anno_json) | |||
| classs_dict = {} | |||
| cat_ids = coco.loadCats(coco.getCatIds()) | |||
| for cat in cat_ids: | |||
| classs_dict[cat["id"]] = cat["name"] | |||
| image_ids = coco.getImgIds() | |||
| image_files = [] | |||
| image_anno_dict = {} | |||
| for img_id in image_ids: | |||
| image_info = coco.loadImgs(img_id) | |||
| file_name = image_info[0]["file_name"] | |||
| anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) | |||
| anno = coco.loadAnns(anno_ids) | |||
| image_path = os.path.join(coco_root, data_type, file_name) | |||
| annos = [] | |||
| for label in anno: | |||
| bbox = label["bbox"] | |||
| class_name = classs_dict[label["category_id"]] | |||
| if class_name in train_cls: | |||
| x_min, x_max = bbox[0], bbox[0] + bbox[2] | |||
| y_min, y_max = bbox[1], bbox[1] + bbox[3] | |||
| annos.append(list(map(round, [x_min, y_min, x_max, y_max])) + [train_cls_dict[class_name]]) | |||
| if len(annos) >= 1: | |||
| image_files.append(image_path) | |||
| image_anno_dict[image_path] = np.array(annos) | |||
| return image_files, image_anno_dict | |||
| def anno_parser(annos_str): | |||
| """Parse annotation from string to list.""" | |||
| annos = [] | |||
| for anno_str in annos_str: | |||
| anno = list(map(int, anno_str.strip().split(','))) | |||
| annos.append(anno) | |||
| return annos | |||
| def filter_valid_data(image_dir, anno_path): | |||
| """Filter valid image file, which both in image_dir and anno_path.""" | |||
| image_files = [] | |||
| image_anno_dict = {} | |||
| if not os.path.isdir(image_dir): | |||
| raise RuntimeError("Path given is not valid.") | |||
| if not os.path.isfile(anno_path): | |||
| raise RuntimeError("Annotation file is not valid.") | |||
| with open(anno_path, "rb") as f: | |||
| lines = f.readlines() | |||
| for line in lines: | |||
| line_str = line.decode("utf-8").strip() | |||
| line_split = str(line_str).split(' ') | |||
| file_name = line_split[0] | |||
| image_path = os.path.join(image_dir, file_name) | |||
| if os.path.isfile(image_path): | |||
| image_anno_dict[image_path] = anno_parser(line_split[1:]) | |||
| image_files.append(image_path) | |||
| return image_files, image_anno_dict | |||
| def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8): | |||
| """Create MindRecord file.""" | |||
| mindrecord_dir = config.MINDRECORD_DIR | |||
| mindrecord_path = os.path.join(mindrecord_dir, prefix) | |||
| writer = FileWriter(mindrecord_path, file_num) | |||
| if dataset == "coco": | |||
| image_files, image_anno_dict = create_coco_label(is_training) | |||
| else: | |||
| image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH) | |||
| ssd_json = { | |||
| "image": {"type": "bytes"}, | |||
| "annotation": {"type": "int32", "shape": [-1, 5]}, | |||
| } | |||
| writer.add_schema(ssd_json, "ssd_json") | |||
| for image_name in image_files: | |||
| with open(image_name, 'rb') as f: | |||
| img = f.read() | |||
| annos = np.array(image_anno_dict[image_name], dtype=np.int32) | |||
| row = {"image": img, "annotation": annos} | |||
| writer.write_raw_data([row]) | |||
| writer.commit() | |||
| def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, | |||
| is_training=True, num_parallel_workers=4): | |||
| """Creatr SSD dataset with MindDataset.""" | |||
| ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank, | |||
| num_parallel_workers=num_parallel_workers, shuffle=is_training) | |||
| decode = C.Decode() | |||
| ds = ds.map(input_columns=["image"], operations=decode) | |||
| compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) | |||
| if is_training: | |||
| hwc_to_chw = C.HWC2CHW() | |||
| ds = ds.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "box", "label", "num_match_num"], | |||
| columns_order=["image", "box", "label", "num_match_num"], | |||
| operations=compose_map_func, python_multiprocessing=True, num_parallel_workers=num_parallel_workers) | |||
| ds = ds.map(input_columns=["image"], operations=hwc_to_chw, python_multiprocessing=True, | |||
| num_parallel_workers=num_parallel_workers) | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_num) | |||
| else: | |||
| hwc_to_chw = C.HWC2CHW() | |||
| ds = ds.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "image_shape", "annotation"], | |||
| columns_order=["image", "image_shape", "annotation"], | |||
| operations=compose_map_func) | |||
| ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers) | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_num) | |||
| return ds | |||
| @@ -0,0 +1,99 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # less required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation for SSD""" | |||
| import os | |||
| import argparse | |||
| import time | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2 | |||
| from dataset import create_ssd_dataset, data_to_mindrecord_byte_image | |||
| from config import ConfigSSD | |||
| from util import metrics | |||
| def ssd_eval(dataset_path, ckpt_path): | |||
| """SSD evaluation.""" | |||
| ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False) | |||
| net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False) | |||
| print("Load Checkpoint!") | |||
| param_dict = load_checkpoint(ckpt_path) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| i = 1. | |||
| total = ds.get_dataset_size() | |||
| start = time.time() | |||
| pred_data = [] | |||
| print("\n========================================\n") | |||
| print("total images num: ", total) | |||
| print("Processing, please wait a moment.") | |||
| for data in ds.create_dict_iterator(): | |||
| img_np = data['image'] | |||
| image_shape = data['image_shape'] | |||
| annotation = data['annotation'] | |||
| output = net(Tensor(img_np)) | |||
| for batch_idx in range(img_np.shape[0]): | |||
| pred_data.append({"boxes": output[0].asnumpy()[batch_idx], | |||
| "box_scores": output[1].asnumpy()[batch_idx], | |||
| "annotation": annotation, | |||
| "image_shape": image_shape}) | |||
| percent = round(i / total * 100, 2) | |||
| print(f' {str(percent)} [{i}/{total}]', end='\r') | |||
| i += 1 | |||
| cost_time = int((time.time() - start) * 1000) | |||
| print(f' 100% [{total}/{total}] cost {cost_time} ms') | |||
| mAP = metrics(pred_data) | |||
| print("\n========================================\n") | |||
| print(f"mAP: {mAP}") | |||
| if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser(description='SSD evaluation') | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | |||
| parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") | |||
| parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||
| context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True) | |||
| config = ConfigSSD() | |||
| prefix = "ssd_eval.mindrecord" | |||
| mindrecord_dir = config.MINDRECORD_DIR | |||
| mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | |||
| if not os.path.exists(mindrecord_file): | |||
| if not os.path.isdir(mindrecord_dir): | |||
| os.makedirs(mindrecord_dir) | |||
| if args_opt.dataset == "coco": | |||
| if os.path.isdir(config.COCO_ROOT): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("coco", False, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("COCO_ROOT not exits.") | |||
| else: | |||
| if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("other", False, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("IMAGE_DIR or ANNO_PATH not exits.") | |||
| print("Start Eval!") | |||
| ssd_eval(mindrecord_file, args_opt.checkpoint_path) | |||
| @@ -0,0 +1,54 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the scipt as: " | |||
| echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH" | |||
| echo "for example: sh run_distribute_train.sh 8 150 coco /data/hccl.json" | |||
| echo "It is better to use absolute path." | |||
| echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script." | |||
| echo "==============================================================================================================" | |||
| # Before start distribute train, first create mindrecord files. | |||
| python train.py --only_create_dataset=1 | |||
| echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" | |||
| export RANK_SIZE=$1 | |||
| EPOCH_SIZE=$2 | |||
| DATASET=$3 | |||
| export MINDSPORE_HCCL_CONFIG_PATH=$4 | |||
| for((i=0;i<RANK_SIZE;i++)) | |||
| do | |||
| export DEVICE_ID=$i | |||
| rm -rf LOG$i | |||
| mkdir ./LOG$i | |||
| cp *.py ./LOG$i | |||
| cd ./LOG$i || exit | |||
| export RANK_ID=$i | |||
| echo "start training for rank $i, device $DEVICE_ID" | |||
| env > env.log | |||
| python ../train.py \ | |||
| --distribute=1 \ | |||
| --lr=0.4 \ | |||
| --dataset=$DATASET \ | |||
| --device_num=$RANK_SIZE \ | |||
| --device_id=$DEVICE_ID \ | |||
| --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & | |||
| cd ../ | |||
| done | |||
| @@ -0,0 +1,176 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # less required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """train SSD and get checkpoint files.""" | |||
| import os | |||
| import math | |||
| import argparse | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import context, Tensor | |||
| from mindspore.communication.management import init | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor | |||
| from mindspore.train import Model, ParallelMode | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 | |||
| from config import ConfigSSD | |||
| from dataset import create_ssd_dataset, data_to_mindrecord_byte_image | |||
| def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch): | |||
| """ | |||
| generate learning rate array | |||
| Args: | |||
| global_step(int): total steps of the training | |||
| lr_init(float): init learning rate | |||
| lr_end(float): end learning rate | |||
| lr_max(float): max learning rate | |||
| warmup_epochs(int): number of warmup epochs | |||
| total_epochs(int): total epoch of training | |||
| steps_per_epoch(int): steps of one epoch | |||
| Returns: | |||
| np.array, learning rate array | |||
| """ | |||
| lr_each_step = [] | |||
| total_steps = steps_per_epoch * total_epochs | |||
| warmup_steps = steps_per_epoch * warmup_epochs | |||
| for i in range(total_steps): | |||
| if i < warmup_steps: | |||
| lr = lr_init + (lr_max - lr_init) * i / warmup_steps | |||
| else: | |||
| lr = lr_end + (lr_max - lr_end) * \ | |||
| (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2. | |||
| if lr < 0.0: | |||
| lr = 0.0 | |||
| lr_each_step.append(lr) | |||
| current_step = global_step | |||
| lr_each_step = np.array(lr_each_step).astype(np.float32) | |||
| learning_rate = lr_each_step[current_step:] | |||
| return learning_rate | |||
| def init_net_param(network, initialize_mode='XavierUniform'): | |||
| """Init the parameters in net.""" | |||
| params = network.trainable_params() | |||
| for p in params: | |||
| if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | |||
| p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype())) | |||
| def main(): | |||
| parser = argparse.ArgumentParser(description="SSD training") | |||
| parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " | |||
| "Mindrecord, default is false.") | |||
| parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | |||
| parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") | |||
| parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.") | |||
| parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") | |||
| parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") | |||
| parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.") | |||
| parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") | |||
| parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.") | |||
| parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") | |||
| parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||
| context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True) | |||
| if args_opt.distribute: | |||
| device_num = args_opt.device_num | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, | |||
| device_num=device_num) | |||
| init() | |||
| rank = args_opt.device_id % device_num | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| print("Start create dataset!") | |||
| # It will generate mindrecord file in args_opt.mindrecord_dir, | |||
| # and the file name is ssd.mindrecord0, 1, ... file_num. | |||
| config = ConfigSSD() | |||
| prefix = "ssd.mindrecord" | |||
| mindrecord_dir = config.MINDRECORD_DIR | |||
| mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | |||
| if not os.path.exists(mindrecord_file): | |||
| if not os.path.isdir(mindrecord_dir): | |||
| os.makedirs(mindrecord_dir) | |||
| if args_opt.dataset == "coco": | |||
| if os.path.isdir(config.COCO_ROOT): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("coco", True, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("COCO_ROOT not exits.") | |||
| else: | |||
| if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("other", True, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("IMAGE_DIR or ANNO_PATH not exits.") | |||
| if not args_opt.only_create_dataset: | |||
| loss_scale = float(args_opt.loss_scale) | |||
| # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. | |||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size, | |||
| batch_size=args_opt.batch_size, device_num=device_num, rank=rank) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print("Create dataset done!") | |||
| ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config) | |||
| net = SSDWithLossCell(ssd, config) | |||
| init_net_param(net) | |||
| # checkpoint | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) | |||
| ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config) | |||
| lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr, | |||
| warmup_epochs=max(args_opt.epoch_size // 20, 1), | |||
| total_epochs=args_opt.epoch_size, | |||
| steps_per_epoch=dataset_size)) | |||
| opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale) | |||
| net = TrainingWrapper(net, opt, loss_scale) | |||
| if args_opt.checkpoint_path != "": | |||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||
| load_param_into_net(net, param_dict) | |||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||
| model = Model(net) | |||
| dataset_sink_mode = False | |||
| if args_opt.mode == "sink": | |||
| print("In sink mode, one epoch return a loss.") | |||
| dataset_sink_mode = True | |||
| print("Start train SSD, the first epoch will be slower because of the graph compilation.") | |||
| model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -0,0 +1,208 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """metrics utils""" | |||
| import numpy as np | |||
| from config import ConfigSSD | |||
| from dataset import ssd_bboxes_decode | |||
| def calc_iou(bbox_pred, bbox_ground): | |||
| """Calculate iou of predicted bbox and ground truth.""" | |||
| bbox_pred = np.expand_dims(bbox_pred, axis=0) | |||
| pred_w = bbox_pred[:, 2] - bbox_pred[:, 0] | |||
| pred_h = bbox_pred[:, 3] - bbox_pred[:, 1] | |||
| pred_area = pred_w * pred_h | |||
| gt_w = bbox_ground[:, 2] - bbox_ground[:, 0] | |||
| gt_h = bbox_ground[:, 3] - bbox_ground[:, 1] | |||
| gt_area = gt_w * gt_h | |||
| iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0]) | |||
| ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1]) | |||
| iw = np.maximum(iw, 0) | |||
| ih = np.maximum(ih, 0) | |||
| intersection_area = iw * ih | |||
| union_area = pred_area + gt_area - intersection_area | |||
| union_area = np.maximum(union_area, np.finfo(float).eps) | |||
| iou = intersection_area * 1. / union_area | |||
| return iou | |||
| def apply_nms(all_boxes, all_scores, thres, max_boxes): | |||
| """Apply NMS to bboxes.""" | |||
| x1 = all_boxes[:, 0] | |||
| y1 = all_boxes[:, 1] | |||
| x2 = all_boxes[:, 2] | |||
| y2 = all_boxes[:, 3] | |||
| areas = (x2 - x1 + 1) * (y2 - y1 + 1) | |||
| order = all_scores.argsort()[::-1] | |||
| keep = [] | |||
| while order.size > 0: | |||
| i = order[0] | |||
| keep.append(i) | |||
| if len(keep) >= max_boxes: | |||
| break | |||
| xx1 = np.maximum(x1[i], x1[order[1:]]) | |||
| yy1 = np.maximum(y1[i], y1[order[1:]]) | |||
| xx2 = np.minimum(x2[i], x2[order[1:]]) | |||
| yy2 = np.minimum(y2[i], y2[order[1:]]) | |||
| w = np.maximum(0.0, xx2 - xx1 + 1) | |||
| h = np.maximum(0.0, yy2 - yy1 + 1) | |||
| inter = w * h | |||
| ovr = inter / (areas[i] + areas[order[1:]] - inter) | |||
| inds = np.where(ovr <= thres)[0] | |||
| order = order[inds + 1] | |||
| return keep | |||
| def calc_ap(recall, precision): | |||
| """Calculate AP.""" | |||
| correct_recall = np.concatenate(([0.], recall, [1.])) | |||
| correct_precision = np.concatenate(([0.], precision, [0.])) | |||
| for i in range(correct_recall.size - 1, 0, -1): | |||
| correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i]) | |||
| i = np.where(correct_recall[1:] != correct_recall[:-1])[0] | |||
| ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1]) | |||
| return ap | |||
| def metrics(pred_data): | |||
| """Calculate mAP of predicted bboxes.""" | |||
| config = ConfigSSD() | |||
| num_classes = config.NUM_CLASSES | |||
| all_detections = [None for i in range(num_classes)] | |||
| all_pred_scores = [None for i in range(num_classes)] | |||
| all_annotations = [None for i in range(num_classes)] | |||
| average_precisions = {} | |||
| num = [0 for i in range(num_classes)] | |||
| accurate_num = [0 for i in range(num_classes)] | |||
| for sample in pred_data: | |||
| pred_boxes = sample['boxes'] | |||
| boxes_scores = sample['box_scores'] | |||
| annotation = sample['annotation'] | |||
| image_shape = sample['image_shape'] | |||
| annotation = np.squeeze(annotation, axis=0) | |||
| image_shape = np.squeeze(image_shape, axis=0) | |||
| pred_labels = np.argmax(boxes_scores, axis=-1) | |||
| index = np.nonzero(pred_labels) | |||
| pred_boxes = ssd_bboxes_decode(pred_boxes, index, image_shape) | |||
| pred_boxes = pred_boxes.clip(0, 1) | |||
| boxes_scores = np.max(boxes_scores, axis=-1) | |||
| boxes_scores = boxes_scores[index] | |||
| pred_labels = pred_labels[index] | |||
| top_k = 50 | |||
| for c in range(1, num_classes): | |||
| if len(pred_labels) >= 1: | |||
| class_box_scores = boxes_scores[pred_labels == c] | |||
| class_boxes = pred_boxes[pred_labels == c] | |||
| nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k) | |||
| class_boxes = class_boxes[nms_index] | |||
| class_box_scores = class_box_scores[nms_index] | |||
| cmask = class_box_scores > 0.5 | |||
| class_boxes = class_boxes[cmask] | |||
| class_box_scores = class_box_scores[cmask] | |||
| all_detections[c] = class_boxes | |||
| all_pred_scores[c] = class_box_scores | |||
| for c in range(1, num_classes): | |||
| if len(annotation) >= 1: | |||
| all_annotations[c] = annotation[annotation[:, 4] == c, :4] | |||
| for c in range(1, num_classes): | |||
| false_positives = np.zeros((0,)) | |||
| true_positives = np.zeros((0,)) | |||
| scores = np.zeros((0,)) | |||
| num_annotations = 0.0 | |||
| annotations = all_annotations[c] | |||
| num_annotations += annotations.shape[0] | |||
| detections = all_detections[c] | |||
| pred_scores = all_pred_scores[c] | |||
| for index, detection in enumerate(detections): | |||
| scores = np.append(scores, pred_scores[index]) | |||
| if len(annotations) >= 1: | |||
| IoUs = calc_iou(detection, annotations) | |||
| assigned_anno = np.argmax(IoUs) | |||
| max_overlap = IoUs[assigned_anno] | |||
| if max_overlap >= 0.5: | |||
| false_positives = np.append(false_positives, 0) | |||
| true_positives = np.append(true_positives, 1) | |||
| else: | |||
| false_positives = np.append(false_positives, 1) | |||
| true_positives = np.append(true_positives, 0) | |||
| else: | |||
| false_positives = np.append(false_positives, 1) | |||
| true_positives = np.append(true_positives, 0) | |||
| if num_annotations == 0: | |||
| if c not in average_precisions.keys(): | |||
| average_precisions[c] = 0 | |||
| continue | |||
| accurate_num[c] = 1 | |||
| indices = np.argsort(-scores) | |||
| false_positives = false_positives[indices] | |||
| true_positives = true_positives[indices] | |||
| false_positives = np.cumsum(false_positives) | |||
| true_positives = np.cumsum(true_positives) | |||
| recall = true_positives * 1. / num_annotations | |||
| precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps) | |||
| average_precision = calc_ap(recall, precision) | |||
| if c not in average_precisions.keys(): | |||
| average_precisions[c] = average_precision | |||
| else: | |||
| average_precisions[c] += average_precision | |||
| num[c] += 1 | |||
| count = 0 | |||
| for key in average_precisions: | |||
| if num[key] != 0: | |||
| count += (average_precisions[key] / num[key]) | |||
| mAP = count * 1. / accurate_num.count(1) | |||
| return mAP | |||
| @@ -0,0 +1,367 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """SSD net based MobilenetV2.""" | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.communication.management import get_group_size | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common.initializer import initializer | |||
| from .mobilenet import InvertedResidual, ConvBNReLU | |||
| def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'): | |||
| weight_shape = (out_channel, in_channel, kernel_size, kernel_size) | |||
| weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32) | |||
| return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, | |||
| padding=0, pad_mode=pad_mod, weight_init=weight) | |||
| def _make_divisible(v, divisor, min_value=None): | |||
| """nsures that all layers have a channel number that is divisible by 8.""" | |||
| if min_value is None: | |||
| min_value = divisor | |||
| new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) | |||
| # Make sure that round down does not go down by more than 10%. | |||
| if new_v < 0.9 * v: | |||
| new_v += divisor | |||
| return new_v | |||
| class FlattenConcat(nn.Cell): | |||
| """ | |||
| Concatenate predictions into a single tensor. | |||
| Args: | |||
| config (Class): The default config of SSD. | |||
| Returns: | |||
| Tensor, flatten predictions. | |||
| """ | |||
| def __init__(self, config): | |||
| super(FlattenConcat, self).__init__() | |||
| self.sizes = config.FEATURE_SIZE | |||
| self.length = len(self.sizes) | |||
| self.num_default = config.NUM_DEFAULT | |||
| self.concat = P.Concat(axis=-1) | |||
| self.transpose = P.Transpose() | |||
| def construct(self, x): | |||
| output = () | |||
| for i in range(self.length): | |||
| shape = F.shape(x[i]) | |||
| mid_shape = (shape[0], -1, self.num_default[i], self.sizes[i], self.sizes[i]) | |||
| final_shape = (shape[0], -1, self.num_default[i] * self.sizes[i] * self.sizes[i]) | |||
| output += (F.reshape(F.reshape(x[i], mid_shape), final_shape),) | |||
| res = self.concat(output) | |||
| return self.transpose(res, (0, 2, 1)) | |||
| class MultiBox(nn.Cell): | |||
| """ | |||
| Multibox conv layers. Each multibox layer contains class conf scores and localization predictions. | |||
| Args: | |||
| config (Class): The default config of SSD. | |||
| Returns: | |||
| Tensor, localization predictions. | |||
| Tensor, class conf scores. | |||
| """ | |||
| def __init__(self, config): | |||
| super(MultiBox, self).__init__() | |||
| num_classes = config.NUM_CLASSES | |||
| out_channels = config.EXTRAS_OUT_CHANNELS | |||
| num_default = config.NUM_DEFAULT | |||
| loc_layers = [] | |||
| cls_layers = [] | |||
| for k, out_channel in enumerate(out_channels): | |||
| loc_layers += [_conv2d(out_channel, 4 * num_default[k], | |||
| kernel_size=3, stride=1, pad_mod='same')] | |||
| cls_layers += [_conv2d(out_channel, num_classes * num_default[k], | |||
| kernel_size=3, stride=1, pad_mod='same')] | |||
| self.multi_loc_layers = nn.layer.CellList(loc_layers) | |||
| self.multi_cls_layers = nn.layer.CellList(cls_layers) | |||
| self.flatten_concat = FlattenConcat(config) | |||
| def construct(self, inputs): | |||
| loc_outputs = () | |||
| cls_outputs = () | |||
| for i in range(len(self.multi_loc_layers)): | |||
| loc_outputs += (self.multi_loc_layers[i](inputs[i]),) | |||
| cls_outputs += (self.multi_cls_layers[i](inputs[i]),) | |||
| return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) | |||
| class SSD300(nn.Cell): | |||
| """ | |||
| SSD300 Network. Default backbone is resnet34. | |||
| Args: | |||
| backbone (Cell): Backbone Network. | |||
| config (Class): The default config of SSD. | |||
| Returns: | |||
| Tensor, localization predictions. | |||
| Tensor, class conf scores. | |||
| Examples:backbone | |||
| SSD300(backbone=resnet34(num_classes=None), | |||
| config=ConfigSSDResNet34()). | |||
| """ | |||
| def __init__(self, backbone, config, is_training=True): | |||
| super(SSD300, self).__init__() | |||
| self.backbone = backbone | |||
| in_channels = config.EXTRAS_IN_CHANNELS | |||
| out_channels = config.EXTRAS_OUT_CHANNELS | |||
| ratios = config.EXTRAS_RATIO | |||
| strides = config.EXTRAS_STRIDES | |||
| residual_list = [] | |||
| for i in range(2, len(in_channels)): | |||
| residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i], expand_ratio=ratios[i]) | |||
| residual_list.append(residual) | |||
| self.multi_residual = nn.layer.CellList(residual_list) | |||
| self.multi_box = MultiBox(config) | |||
| self.is_training = is_training | |||
| if not is_training: | |||
| self.softmax = P.Softmax() | |||
| def construct(self, x): | |||
| layer_out_13, output = self.backbone(x) | |||
| multi_feature = (layer_out_13, output) | |||
| feature = output | |||
| for residual in self.multi_residual: | |||
| feature = residual(feature) | |||
| multi_feature += (feature,) | |||
| pred_loc, pred_label = self.multi_box(multi_feature) | |||
| if not self.is_training: | |||
| pred_label = self.softmax(pred_label) | |||
| return pred_loc, pred_label | |||
| class LocalizationLoss(nn.Cell): | |||
| """" | |||
| Computes the localization loss with SmoothL1Loss. | |||
| Returns: | |||
| Tensor, box regression loss. | |||
| """ | |||
| def __init__(self): | |||
| super(LocalizationLoss, self).__init__() | |||
| self.reduce_sum = P.ReduceSum() | |||
| self.reduce_mean = P.ReduceMean() | |||
| self.loss = nn.SmoothL1Loss() | |||
| self.expand_dims = P.ExpandDims() | |||
| self.less = P.Less() | |||
| def construct(self, pred_loc, gt_loc, gt_label, num_matched_boxes): | |||
| mask = F.cast(self.less(0, gt_label), mstype.float32) | |||
| mask = self.expand_dims(mask, -1) | |||
| smooth_l1 = self.loss(gt_loc, pred_loc) * mask | |||
| box_loss = self.reduce_sum(smooth_l1, 1) | |||
| return self.reduce_mean(box_loss / F.cast(num_matched_boxes, mstype.float32), (0, 1)) | |||
| class ClassificationLoss(nn.Cell): | |||
| """" | |||
| Computes the classification loss with hard example mining. | |||
| Args: | |||
| config (Class): The default config of SSD. | |||
| Returns: | |||
| Tensor, classification loss. | |||
| """ | |||
| def __init__(self, config): | |||
| super(ClassificationLoss, self).__init__() | |||
| self.num_classes = config.NUM_CLASSES | |||
| self.num_boxes = config.NUM_SSD_BOXES | |||
| self.neg_pre_positive = config.NEG_PRE_POSITIVE | |||
| self.minimum = P.Minimum() | |||
| self.less = P.Less() | |||
| self.sort = P.TopK() | |||
| self.tile = P.Tile() | |||
| self.reduce_sum = P.ReduceSum() | |||
| self.reduce_mean = P.ReduceMean() | |||
| self.expand_dims = P.ExpandDims() | |||
| self.sort_descend = P.TopK(True) | |||
| self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | |||
| def construct(self, pred_label, gt_label, num_matched_boxes): | |||
| gt_label = F.cast(gt_label, mstype.int32) | |||
| mask = F.cast(self.less(0, gt_label), mstype.float32) | |||
| gt_label_shape = F.shape(gt_label) | |||
| pred_label = F.reshape(pred_label, (-1, self.num_classes)) | |||
| gt_label = F.reshape(gt_label, (-1,)) | |||
| cross_entropy = self.cross_entropy(pred_label, gt_label) | |||
| cross_entropy = F.reshape(cross_entropy, gt_label_shape) | |||
| # Hard example mining | |||
| num_matched_boxes = F.reshape(num_matched_boxes, (-1,)) | |||
| neg_masked_cross_entropy = F.cast(cross_entropy * (1- mask), mstype.float16) | |||
| _, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes) | |||
| _, relative_position = self.sort(F.cast(loss_idx, mstype.float16), self.num_boxes) | |||
| num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes) | |||
| tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes)) | |||
| top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32) | |||
| class_loss = self.reduce_sum(cross_entropy * (mask + top_k_neg_mask), 1) | |||
| return self.reduce_mean(class_loss / F.cast(num_matched_boxes, mstype.float32), 0) | |||
| class SSDWithLossCell(nn.Cell): | |||
| """" | |||
| Provide SSD training loss through network. | |||
| Args: | |||
| network (Cell): The training network. | |||
| config (Class): SSD config. | |||
| Returns: | |||
| Tensor, the loss of the network. | |||
| """ | |||
| def __init__(self, network, config): | |||
| super(SSDWithLossCell, self).__init__() | |||
| self.network = network | |||
| self.class_loss = ClassificationLoss(config) | |||
| self.box_loss = LocalizationLoss() | |||
| def construct(self, x, gt_loc, gt_label, num_matched_boxes): | |||
| pred_loc, pred_label = self.network(x) | |||
| loss_cls = self.class_loss(pred_label, gt_label, num_matched_boxes) | |||
| loss_loc = self.box_loss(pred_loc, gt_loc, gt_label, num_matched_boxes) | |||
| return loss_cls + loss_loc | |||
| class TrainingWrapper(nn.Cell): | |||
| """ | |||
| Encapsulation class of SSD network training. | |||
| Append an optimizer to the training network after that the construct | |||
| function can be called to create the backward graph. | |||
| Args: | |||
| network (Cell): The training network. Note that loss function should have been added. | |||
| optimizer (Optimizer): Optimizer for updating the weights. | |||
| sens (Number): The adjust parameter. Default: 1.0. | |||
| """ | |||
| def __init__(self, network, optimizer, sens=1.0): | |||
| super(TrainingWrapper, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.weights = ms.ParameterTuple(network.trainable_params()) | |||
| self.optimizer = optimizer | |||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| self.reducer_flag = False | |||
| self.grad_reducer = None | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| if self.reducer_flag: | |||
| mean = context.get_auto_parallel_context("mirror_mean") | |||
| if auto_parallel_context().get_device_num_is_set(): | |||
| degree = context.get_auto_parallel_context("device_num") | |||
| else: | |||
| degree = get_group_size() | |||
| self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| def construct(self, *args): | |||
| weights = self.weights | |||
| loss = self.network(*args) | |||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
| grads = self.grad(self.network, weights)(*args, sens) | |||
| if self.reducer_flag: | |||
| # apply grad reducer on grads | |||
| grads = self.grad_reducer(grads) | |||
| return F.depend(loss, self.optimizer(grads)) | |||
| class SSDWithMobileNetV2(nn.Cell): | |||
| """ | |||
| MobileNetV2 architecture for SSD backbone. | |||
| Args: | |||
| width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. | |||
| inverted_residual_setting (list): Inverted residual settings. Default is None | |||
| round_nearest (list): Channel round to. Default is 8 | |||
| Returns: | |||
| Tensor, the 13th feature after ConvBNReLU in MobileNetV2. | |||
| Tensor, the last feature in MobileNetV2. | |||
| Examples: | |||
| >>> SSDWithMobileNetV2() | |||
| """ | |||
| def __init__(self, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): | |||
| super(SSDWithMobileNetV2, self).__init__() | |||
| block = InvertedResidual | |||
| input_channel = 32 | |||
| last_channel = 1280 | |||
| if inverted_residual_setting is None: | |||
| inverted_residual_setting = [ | |||
| # t, c, n, s | |||
| [1, 16, 1, 1], | |||
| [6, 24, 2, 2], | |||
| [6, 32, 3, 2], | |||
| [6, 64, 4, 2], | |||
| [6, 96, 3, 1], | |||
| [6, 160, 3, 2], | |||
| [6, 320, 1, 1], | |||
| ] | |||
| if len(inverted_residual_setting[0]) != 4: | |||
| raise ValueError("inverted_residual_setting should be non-empty " | |||
| "or a 4-element list, got {}".format(inverted_residual_setting)) | |||
| #building first layer | |||
| input_channel = _make_divisible(input_channel * width_mult, round_nearest) | |||
| self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) | |||
| features = [ConvBNReLU(3, input_channel, stride=2)] | |||
| # building inverted residual blocks | |||
| layer_index = 0 | |||
| for t, c, n, s in inverted_residual_setting: | |||
| output_channel = _make_divisible(c * width_mult, round_nearest) | |||
| for i in range(n): | |||
| if layer_index == 13: | |||
| hidden_dim = int(round(input_channel * t)) | |||
| self.expand_layer_conv_13 = ConvBNReLU(input_channel, hidden_dim, kernel_size=1) | |||
| stride = s if i == 0 else 1 | |||
| features.append(block(input_channel, output_channel, stride, expand_ratio=t)) | |||
| input_channel = output_channel | |||
| layer_index += 1 | |||
| # building last several layers | |||
| features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) | |||
| self.features_1 = nn.SequentialCell(features[:14]) | |||
| self.features_2 = nn.SequentialCell(features[14:]) | |||
| def construct(self, x): | |||
| out = self.features_1(x) | |||
| expand_layer_conv_13 = self.expand_layer_conv_13(out) | |||
| out = self.features_2(out) | |||
| return expand_layer_conv_13, out | |||
| def get_out_channels(self): | |||
| return self.last_channel | |||
| def ssd_mobilenet_v2(**kwargs): | |||
| return SSDWithMobileNetV2(**kwargs) | |||