Browse Source

Infer SSD with decoder

tags/v1.2.0-rc1
chenhaozhe 4 years ago
parent
commit
8476fe2c5e
8 changed files with 74 additions and 35 deletions
  1. +25
    -25
      model_zoo/official/cv/ssd/README.md
  2. +4
    -1
      model_zoo/official/cv/ssd/eval.py
  3. +3
    -1
      model_zoo/official/cv/ssd/export.py
  4. +1
    -1
      model_zoo/official/cv/ssd/scripts/run_distribute_train.sh
  5. +1
    -1
      model_zoo/official/cv/ssd/scripts/run_distribute_train_gpu.sh
  6. +4
    -4
      model_zoo/official/cv/ssd/src/dataset.py
  7. +0
    -2
      model_zoo/official/cv/ssd/src/eval_utils.py
  8. +36
    -0
      model_zoo/official/cv/ssd/src/ssd.py

+ 25
- 25
model_zoo/official/cv/ssd/README.md View File

@@ -371,34 +371,34 @@ The ckpt_file parameter is required.


#### Evaluation Performance #### Evaluation Performance


| Parameters | Ascend | GPU |
| -------------------------- | -------------------------------------------------------------| -------------------------------------------------------------|
| Model Version | SSD V1 | SSD V1 |
| Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G | NV SMX2 V100-16G |
| uploaded Date | 09/15/2020 (month/day/year) | 09/24/2020 (month/day/year) |
| MindSpore Version | 1.0.0 | 1.0.0 |
| Dataset | COCO2017 | COCO2017 |
| Training Parameters | epoch = 500, batch_size = 32 | epoch = 800, batch_size = 32 |
| Optimizer | Momentum | Momentum |
| Loss Function | Sigmoid Cross Entropy,SmoothL1Loss | Sigmoid Cross Entropy,SmoothL1Loss |
| Speed | 8pcs: 90ms/step | 8pcs: 121ms/step |
| Total time | 8pcs: 4.81hours | 8pcs: 12.31hours |
| Parameters (M) | 34 | 34 |
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> |
| Parameters | Ascend | GPU | Ascend |
| ------------------- | ----------------------------------------------------------------------------- | ----------------------------------------------------------------------------- | ----------------------------------------------------------------------------- |
| Model Version | SSD V1 | SSD V1 | SSD-Mobilenet-V1-Fpn |
| Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G | NV SMX2 V100-16G | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G |
| uploaded Date | 09/15/2020 (month/day/year) | 09/24/2020 (month/day/year) | 01/13/2021 (month/day/year) |
| MindSpore Version | 1.0.0 | 1.0.0 | 1.1.0 |
| Dataset | COCO2017 | COCO2017 | COCO2017 |
| Training Parameters | epoch = 500, batch_size = 32 | epoch = 800, batch_size = 32 | epoch = 60, batch_size = 32 |
| Optimizer | Momentum | Momentum | Momentum |
| Loss Function | Sigmoid Cross Entropy,SmoothL1Loss | Sigmoid Cross Entropy,SmoothL1Loss | Sigmoid Cross Entropy,SmoothL1Loss |
| Speed | 8pcs: 90ms/step | 8pcs: 121ms/step | 8pcs: 547ms/step |
| Total time | 8pcs: 4.81hours | 8pcs: 12.31hours | 8pcs: 4.22hours |
| Parameters (M) | 34 | 34 | 48M |
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd> |


#### Inference Performance #### Inference Performance


| Parameters | Ascend | GPU |
| ------------------- | ----------------------------| ----------------------------|
| Model Version | SSD V1 | SSD V1 |
| Resource | Ascend 910 | GPU |
| Uploaded Date | 09/15/2020 (month/day/year) | 09/24/2020 (month/day/year) |
| MindSpore Version | 1.0.0 | 1.0.0 |
| Dataset | COCO2017 | COCO2017 |
| batch_size | 1 | 1 |
| outputs | mAP | mAP |
| Accuracy | IoU=0.50: 23.8% | IoU=0.50: 22.4% |
| Model for inference | 34M(.ckpt file) | 34M(.ckpt file) |
| Parameters | Ascend | GPU | Ascend |
| ------------------- | --------------------------- | --------------------------- | --------------------------- |
| Model Version | SSD V1 | SSD V1 | SSD-Mobilenet-V1-Fpn |
| Resource | Ascend 910 | GPU | Ascend 910 |
| Uploaded Date | 09/15/2020 (month/day/year) | 09/24/2020 (month/day/year) | 09/24/2020 (month/day/year) |
| MindSpore Version | 1.0.0 | 1.0.0 | 1.1.0 |
| Dataset | COCO2017 | COCO2017 | COCO2017 |
| batch_size | 1 | 1 | 1 |
| outputs | mAP | mAP | mAP |
| Accuracy | IoU=0.50: 23.8% | IoU=0.50: 22.4% | Iout=0.50: 30% |
| Model for inference | 34M(.ckpt file) | 34M(.ckpt file) | 48M(.ckpt file) |


## [Description of Random Situation](#contents) ## [Description of Random Situation](#contents)




+ 4
- 1
model_zoo/official/cv/ssd/eval.py View File

@@ -21,10 +21,11 @@ import time
import numpy as np import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn
from src.dataset import create_ssd_dataset, create_mindrecord from src.dataset import create_ssd_dataset, create_mindrecord
from src.config import config from src.config import config
from src.eval_utils import metrics from src.eval_utils import metrics
from src.box_utils import default_boxes


def ssd_eval(dataset_path, ckpt_path, anno_json): def ssd_eval(dataset_path, ckpt_path, anno_json):
"""SSD evaluation.""" """SSD evaluation."""
@@ -35,6 +36,8 @@ def ssd_eval(dataset_path, ckpt_path, anno_json):
net = SSD300(ssd_mobilenet_v2(), config, is_training=False) net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
else: else:
net = ssd_mobilenet_v1_fpn(config=config) net = ssd_mobilenet_v1_fpn(config=config)
net = SsdInferWithDecoder(net, Tensor(default_boxes), config)

print("Load Checkpoint!") print("Load Checkpoint!")
param_dict = load_checkpoint(ckpt_path) param_dict = load_checkpoint(ckpt_path)
net.init_parameters_data() net.init_parameters_data()


+ 3
- 1
model_zoo/official/cv/ssd/export.py View File

@@ -19,8 +19,9 @@ import numpy as np
import mindspore import mindspore
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.ssd import SSD300, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn
from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn
from src.config import config from src.config import config
from src.box_utils import default_boxes


parser = argparse.ArgumentParser(description='SSD export') parser = argparse.ArgumentParser(description='SSD export')
parser.add_argument("--device_id", type=int, default=0, help="Device id") parser.add_argument("--device_id", type=int, default=0, help="Device id")
@@ -41,6 +42,7 @@ if __name__ == '__main__':
net = SSD300(ssd_mobilenet_v2(), config, is_training=False) net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
else: else:
net = ssd_mobilenet_v1_fpn(config=config) net = ssd_mobilenet_v1_fpn(config=config)
net = SsdInferWithDecoder(net, Tensor(default_boxes), config)


param_dict = load_checkpoint(args.ckpt_file) param_dict = load_checkpoint(args.ckpt_file)
net.init_parameters_data() net.init_parameters_data()


+ 1
- 1
model_zoo/official/cv/ssd/scripts/run_distribute_train.sh View File

@@ -31,7 +31,7 @@ fi
# Before start distribute train, first create mindrecord files. # Before start distribute train, first create mindrecord files.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd) BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True
python train.py --only_create_dataset=True --dataset=$4


echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"




+ 1
- 1
model_zoo/official/cv/ssd/scripts/run_distribute_train_gpu.sh View File

@@ -31,7 +31,7 @@ fi
# Before start distribute train, first create mindrecord files. # Before start distribute train, first create mindrecord files.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd) BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True --run_platform="GPU"
python train.py --only_create_dataset=True --run_platform="GPU" --dataset=$4


echo "After running the scipt, the network runs in the background. The log will be generated in LOG/log.txt" echo "After running the scipt, the network runs in the background. The log will be generated in LOG/log.txt"




+ 4
- 4
model_zoo/official/cv/ssd/src/dataset.py View File

@@ -207,10 +207,10 @@ def create_voc_label(is_training):
print(f'Label "{cls_name}" not in "{config.classes}"') print(f'Label "{cls_name}" not in "{config.classes}"')
continue continue
bnd_box = obj.find('bndbox') bnd_box = obj.find('bndbox')
x_min = int(bnd_box.find('xmin').text) - 1
y_min = int(bnd_box.find('ymin').text) - 1
x_max = int(bnd_box.find('xmax').text) - 1
y_max = int(bnd_box.find('ymax').text) - 1
x_min = int(float(bnd_box.find('xmin').text)) - 1
y_min = int(float(bnd_box.find('ymin').text)) - 1
x_max = int(float(bnd_box.find('xmax').text)) - 1
y_max = int(float(bnd_box.find('ymax').text)) - 1
labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]]) labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]])


if not is_training: if not is_training:


+ 0
- 2
model_zoo/official/cv/ssd/src/eval_utils.py View File

@@ -17,7 +17,6 @@
import json import json
import numpy as np import numpy as np
from .config import config from .config import config
from .box_utils import ssd_bboxes_decode




def apply_nms(all_boxes, all_scores, thres, max_boxes): def apply_nms(all_boxes, all_scores, thres, max_boxes):
@@ -81,7 +80,6 @@ def metrics(pred_data, anno_json):
img_id = sample['img_id'] img_id = sample['img_id']
h, w = sample['image_shape'] h, w = sample['image_shape']


pred_boxes = ssd_bboxes_decode(pred_boxes)
final_boxes = [] final_boxes = []
final_label = [] final_label = []
final_score = [] final_score = []


+ 36
- 0
model_zoo/official/cv/ssd/src/ssd.py View File

@@ -569,6 +569,42 @@ class SSDWithMobileNetV2(nn.Cell):
return self.last_channel return self.last_channel




class SsdInferWithDecoder(nn.Cell):
"""
SSD Infer wrapper to decode the bbox locations.

Args:
network (Cell): the origin ssd infer network without bbox decoder.
default_boxes (Tensor): the default_boxes from anchor generator
config (dict): ssd config
Returns:
Tensor, the locations for bbox after decoder representing (y0,x0,y1,x1)
Tensor, the prediction labels.

"""
def __init__(self, network, default_boxes, config):
super(SsdInferWithDecoder, self).__init__()
self.network = network
self.default_boxes = default_boxes
self.prior_scaling_xy = config.prior_scaling[0]
self.prior_scaling_wh = config.prior_scaling[1]

def construct(self, x):
pred_loc, pred_label = self.network(x)

default_bbox_xy = self.default_boxes[..., :2]
default_bbox_wh = self.default_boxes[..., 2:]
pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy
pred_wh = P.Exp()(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh

pred_xy_0 = pred_xy - pred_wh / 2.0
pred_xy_1 = pred_xy + pred_wh / 2.0
pred_xy = P.Concat(-1)((pred_xy_0, pred_xy_1))
pred_xy = P.Maximum()(pred_xy, 0)
pred_xy = P.Minimum()(pred_xy, 1)
return pred_xy, pred_label


def ssd_mobilenet_v1_fpn(**kwargs): def ssd_mobilenet_v1_fpn(**kwargs):
return SsdMobilenetV1Fpn(**kwargs) return SsdMobilenetV1Fpn(**kwargs)




Loading…
Cancel
Save