| @@ -371,34 +371,34 @@ The ckpt_file parameter is required. | |||
| #### Evaluation Performance | |||
| | Parameters | Ascend | GPU | | |||
| | -------------------------- | -------------------------------------------------------------| -------------------------------------------------------------| | |||
| | Model Version | SSD V1 | SSD V1 | | |||
| | Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G | NV SMX2 V100-16G | | |||
| | uploaded Date | 09/15/2020 (month/day/year) | 09/24/2020 (month/day/year) | | |||
| | MindSpore Version | 1.0.0 | 1.0.0 | | |||
| | Dataset | COCO2017 | COCO2017 | | |||
| | Training Parameters | epoch = 500, batch_size = 32 | epoch = 800, batch_size = 32 | | |||
| | Optimizer | Momentum | Momentum | | |||
| | Loss Function | Sigmoid Cross Entropy,SmoothL1Loss | Sigmoid Cross Entropy,SmoothL1Loss | | |||
| | Speed | 8pcs: 90ms/step | 8pcs: 121ms/step | | |||
| | Total time | 8pcs: 4.81hours | 8pcs: 12.31hours | | |||
| | Parameters (M) | 34 | 34 | | |||
| | Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> | | |||
| | Parameters | Ascend | GPU | Ascend | | |||
| | ------------------- | ----------------------------------------------------------------------------- | ----------------------------------------------------------------------------- | ----------------------------------------------------------------------------- | | |||
| | Model Version | SSD V1 | SSD V1 | SSD-Mobilenet-V1-Fpn | | |||
| | Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G | NV SMX2 V100-16G | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G | | |||
| | uploaded Date | 09/15/2020 (month/day/year) | 09/24/2020 (month/day/year) | 01/13/2021 (month/day/year) | | |||
| | MindSpore Version | 1.0.0 | 1.0.0 | 1.1.0 | | |||
| | Dataset | COCO2017 | COCO2017 | COCO2017 | | |||
| | Training Parameters | epoch = 500, batch_size = 32 | epoch = 800, batch_size = 32 | epoch = 60, batch_size = 32 | | |||
| | Optimizer | Momentum | Momentum | Momentum | | |||
| | Loss Function | Sigmoid Cross Entropy,SmoothL1Loss | Sigmoid Cross Entropy,SmoothL1Loss | Sigmoid Cross Entropy,SmoothL1Loss | | |||
| | Speed | 8pcs: 90ms/step | 8pcs: 121ms/step | 8pcs: 547ms/step | | |||
| | Total time | 8pcs: 4.81hours | 8pcs: 12.31hours | 8pcs: 4.22hours | | |||
| | Parameters (M) | 34 | 34 | 48M | | |||
| | Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> | | |||
| #### Inference Performance | |||
| | Parameters | Ascend | GPU | | |||
| | ------------------- | ----------------------------| ----------------------------| | |||
| | Model Version | SSD V1 | SSD V1 | | |||
| | Resource | Ascend 910 | GPU | | |||
| | Uploaded Date | 09/15/2020 (month/day/year) | 09/24/2020 (month/day/year) | | |||
| | MindSpore Version | 1.0.0 | 1.0.0 | | |||
| | Dataset | COCO2017 | COCO2017 | | |||
| | batch_size | 1 | 1 | | |||
| | outputs | mAP | mAP | | |||
| | Accuracy | IoU=0.50: 23.8% | IoU=0.50: 22.4% | | |||
| | Model for inference | 34M(.ckpt file) | 34M(.ckpt file) | | |||
| | Parameters | Ascend | GPU | Ascend | | |||
| | ------------------- | --------------------------- | --------------------------- | --------------------------- | | |||
| | Model Version | SSD V1 | SSD V1 | SSD-Mobilenet-V1-Fpn | | |||
| | Resource | Ascend 910 | GPU | Ascend 910 | | |||
| | Uploaded Date | 09/15/2020 (month/day/year) | 09/24/2020 (month/day/year) | 09/24/2020 (month/day/year) | | |||
| | MindSpore Version | 1.0.0 | 1.0.0 | 1.1.0 | | |||
| | Dataset | COCO2017 | COCO2017 | COCO2017 | | |||
| | batch_size | 1 | 1 | 1 | | |||
| | outputs | mAP | mAP | mAP | | |||
| | Accuracy | IoU=0.50: 23.8% | IoU=0.50: 22.4% | Iout=0.50: 30% | | |||
| | Model for inference | 34M(.ckpt file) | 34M(.ckpt file) | 48M(.ckpt file) | | |||
| ## [Description of Random Situation](#contents) | |||
| @@ -21,10 +21,11 @@ import time | |||
| import numpy as np | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.ssd import SSD300, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn | |||
| from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn | |||
| from src.dataset import create_ssd_dataset, create_mindrecord | |||
| from src.config import config | |||
| from src.eval_utils import metrics | |||
| from src.box_utils import default_boxes | |||
| def ssd_eval(dataset_path, ckpt_path, anno_json): | |||
| """SSD evaluation.""" | |||
| @@ -35,6 +36,8 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): | |||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | |||
| else: | |||
| net = ssd_mobilenet_v1_fpn(config=config) | |||
| net = SsdInferWithDecoder(net, Tensor(default_boxes), config) | |||
| print("Load Checkpoint!") | |||
| param_dict = load_checkpoint(ckpt_path) | |||
| net.init_parameters_data() | |||
| @@ -19,8 +19,9 @@ import numpy as np | |||
| import mindspore | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||
| from src.ssd import SSD300, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn | |||
| from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn | |||
| from src.config import config | |||
| from src.box_utils import default_boxes | |||
| parser = argparse.ArgumentParser(description='SSD export') | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id") | |||
| @@ -41,6 +42,7 @@ if __name__ == '__main__': | |||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | |||
| else: | |||
| net = ssd_mobilenet_v1_fpn(config=config) | |||
| net = SsdInferWithDecoder(net, Tensor(default_boxes), config) | |||
| param_dict = load_checkpoint(args.ckpt_file) | |||
| net.init_parameters_data() | |||
| @@ -31,7 +31,7 @@ fi | |||
| # Before start distribute train, first create mindrecord files. | |||
| BASE_PATH=$(cd "`dirname $0`" || exit; pwd) | |||
| cd $BASE_PATH/../ || exit | |||
| python train.py --only_create_dataset=True | |||
| python train.py --only_create_dataset=True --dataset=$4 | |||
| echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" | |||
| @@ -31,7 +31,7 @@ fi | |||
| # Before start distribute train, first create mindrecord files. | |||
| BASE_PATH=$(cd "`dirname $0`" || exit; pwd) | |||
| cd $BASE_PATH/../ || exit | |||
| python train.py --only_create_dataset=True --run_platform="GPU" | |||
| python train.py --only_create_dataset=True --run_platform="GPU" --dataset=$4 | |||
| echo "After running the scipt, the network runs in the background. The log will be generated in LOG/log.txt" | |||
| @@ -207,10 +207,10 @@ def create_voc_label(is_training): | |||
| print(f'Label "{cls_name}" not in "{config.classes}"') | |||
| continue | |||
| bnd_box = obj.find('bndbox') | |||
| x_min = int(bnd_box.find('xmin').text) - 1 | |||
| y_min = int(bnd_box.find('ymin').text) - 1 | |||
| x_max = int(bnd_box.find('xmax').text) - 1 | |||
| y_max = int(bnd_box.find('ymax').text) - 1 | |||
| x_min = int(float(bnd_box.find('xmin').text)) - 1 | |||
| y_min = int(float(bnd_box.find('ymin').text)) - 1 | |||
| x_max = int(float(bnd_box.find('xmax').text)) - 1 | |||
| y_max = int(float(bnd_box.find('ymax').text)) - 1 | |||
| labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]]) | |||
| if not is_training: | |||
| @@ -17,7 +17,6 @@ | |||
| import json | |||
| import numpy as np | |||
| from .config import config | |||
| from .box_utils import ssd_bboxes_decode | |||
| def apply_nms(all_boxes, all_scores, thres, max_boxes): | |||
| @@ -81,7 +80,6 @@ def metrics(pred_data, anno_json): | |||
| img_id = sample['img_id'] | |||
| h, w = sample['image_shape'] | |||
| pred_boxes = ssd_bboxes_decode(pred_boxes) | |||
| final_boxes = [] | |||
| final_label = [] | |||
| final_score = [] | |||
| @@ -569,6 +569,42 @@ class SSDWithMobileNetV2(nn.Cell): | |||
| return self.last_channel | |||
| class SsdInferWithDecoder(nn.Cell): | |||
| """ | |||
| SSD Infer wrapper to decode the bbox locations. | |||
| Args: | |||
| network (Cell): the origin ssd infer network without bbox decoder. | |||
| default_boxes (Tensor): the default_boxes from anchor generator | |||
| config (dict): ssd config | |||
| Returns: | |||
| Tensor, the locations for bbox after decoder representing (y0,x0,y1,x1) | |||
| Tensor, the prediction labels. | |||
| """ | |||
| def __init__(self, network, default_boxes, config): | |||
| super(SsdInferWithDecoder, self).__init__() | |||
| self.network = network | |||
| self.default_boxes = default_boxes | |||
| self.prior_scaling_xy = config.prior_scaling[0] | |||
| self.prior_scaling_wh = config.prior_scaling[1] | |||
| def construct(self, x): | |||
| pred_loc, pred_label = self.network(x) | |||
| default_bbox_xy = self.default_boxes[..., :2] | |||
| default_bbox_wh = self.default_boxes[..., 2:] | |||
| pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy | |||
| pred_wh = P.Exp()(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh | |||
| pred_xy_0 = pred_xy - pred_wh / 2.0 | |||
| pred_xy_1 = pred_xy + pred_wh / 2.0 | |||
| pred_xy = P.Concat(-1)((pred_xy_0, pred_xy_1)) | |||
| pred_xy = P.Maximum()(pred_xy, 0) | |||
| pred_xy = P.Minimum()(pred_xy, 1) | |||
| return pred_xy, pred_label | |||
| def ssd_mobilenet_v1_fpn(**kwargs): | |||
| return SsdMobilenetV1Fpn(**kwargs) | |||