diff --git a/model_zoo/official/cv/ssd/README.md b/model_zoo/official/cv/ssd/README.md
index 0d9ee3c707..1105b42fc1 100644
--- a/model_zoo/official/cv/ssd/README.md
+++ b/model_zoo/official/cv/ssd/README.md
@@ -12,6 +12,7 @@
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
+ - [Export MindIR](#export-mindir)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
@@ -49,21 +50,23 @@ Dataset used: [COCO2017]()
- Download the dataset COCO2017.
- We use COCO2017 as training dataset in this example by default, and you can also use your own datasets.
+ First, install Cython ,pycocotool and opencv to process data and to get evaluation result.
- 1. If coco dataset is used. **Select dataset to coco when run script.**
- Install Cython and pycocotool, and you can also install mmcv to process data.
+ ```
+ pip install Cython
- ```
- pip install Cython
+ pip install pycocotools
- pip install pycocotools
+ pip install opencv-python
- ```
- And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:
+ ```
+ 1. If coco dataset is used. **Select dataset to coco when run script.**
+
+ Change the `coco_root` and other settings you need in `src/config.py`. The directory structure is as follows:
```
.
- └─cocodataset
+ └─coco_dataset
├─annotations
├─instance_train2017.json
└─instance_val2017.json
@@ -72,7 +75,27 @@ Dataset used: [COCO2017]()
```
- 2. If your own dataset is used. **Select dataset to other when run script.**
+ 2. If VOC dataset is used. **Select dataset to voc when run script.**
+ Change `classes`, `num_classes`, `voc_json` and `voc_root` in `src/config.py`. `voc_json` is the path of json file with coco format for evalution, `voc_root` is the path of VOC dataset, the directory structure is as follows:
+ ```
+ .
+ └─voc_dataset
+ └─train
+ ├─0001.jpg
+ └─0001.xml
+ ...
+ ├─xxxx.jpg
+ └─xxxx.xml
+ └─eval
+ ├─0001.jpg
+ └─0001.xml
+ ...
+ ├─xxxx.jpg
+ └─xxxx.xml
+
+ ```
+
+ 3. If your own dataset is used. **Select dataset to other when run script.**
Organize the dataset infomation into a TXT file, each row in the file is as follows:
```
@@ -80,7 +103,7 @@ Dataset used: [COCO2017]()
```
- Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`.
+ Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `image_dir`(dataset directory) and the relative path in `anno_path`(the TXT file path), `image_dir` and `anno_path` are setting in `src/config.py`.
# [Quick Start](#contents)
@@ -103,7 +126,19 @@ sh run_distribute_train_gpu.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET]
sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
```
-# [Script Description](#contents)
+- runing on CPU(support Windows and Ubuntu)
+
+**CPU is usually used for fine-tuning, which needs pre_trained checkpoint.**
+
+```
+# training on CPU
+python train.py --run_platform=CPU --lr=[LR] --dataset=[DATASET] --epoch_size=[EPOCH_SIZE] --batch_size=[BATCH_SIZE] --pre_trained=[PRETRAINED_CKPT] --filter_weight=True --save_checkpoint_epochs=1
+
+# run eval on GPU
+python eval.py --run_platform=CPU --dataset=[DATASET] --checkpoint_path=[PRETRAINED_CKPT]
+```
+
+# [Script Description](#contents)
## [Script and Sample Code](#contents)
@@ -111,24 +146,25 @@ sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
.
└─ cv
└─ ssd
- ├─ README.md ## descriptions about SSD
+ ├─ README.md # descriptions about SSD
├─ scripts
- ├─ run_distribute_train.sh ## shell script for distributed on ascend
- ├─ run_distribute_train_gpu.sh ## shell script for distributed on gpu
- ├─ run_eval.sh ## shell script for eval on ascend
- └─ run_eval_gpu.sh ## shell script for eval on gpu
+ ├─ run_distribute_train.sh # shell script for distributed on ascend
+ ├─ run_distribute_train_gpu.sh # shell script for distributed on gpu
+ ├─ run_eval.sh # shell script for eval on ascend
+ └─ run_eval_gpu.sh # shell script for eval on gpu
├─ src
- ├─ __init__.py ## init file
- ├─ box_util.py ## bbox utils
- ├─ coco_eval.py ## coco metrics utils
- ├─ config.py ## total config
- ├─ dataset.py ## create dataset and process dataset
- ├─ init_params.py ## parameters utils
- ├─ lr_schedule.py ## learning ratio generator
- └─ ssd.py ## ssd architecture
- ├─ eval.py ## eval scripts
- ├─ train.py ## train scripts
- └─ mindspore_hub_conf.py ## mindspore hub interface
+ ├─ __init__.py # init file
+ ├─ box_utils.py # bbox utils
+ ├─ eval_utils.py # metrics utils
+ ├─ config.py # total config
+ ├─ dataset.py # create dataset and process dataset
+ ├─ init_params.py # parameters utils
+ ├─ lr_schedule.py # learning ratio generator
+ └─ ssd.py # ssd architecture
+ ├─ eval.py # eval scripts
+ ├─ train.py # train scripts
+ ├─ export.py # export mindir script
+ └─ mindspore_hub_conf.py # mindspore hub interface
```
## [Script Parameters](#contents)
@@ -136,30 +172,33 @@ sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
```
Major parameters in train.py and config.py as follows:
- "device_num": 1 # Use device nums
- "lr": 0.05 # Learning rate init value
- "dataset": coco # Dataset name
- "epoch_size": 500 # Epoch size
- "batch_size": 32 # Batch size of input tensor
- "pre_trained": None # Pretrained checkpoint file path
- "pre_trained_epoch_size": 0 # Pretrained epoch size
- "save_checkpoint_epochs": 10 # The epoch interval between two checkpoints. By default, the checkpoint will be saved per 10 epochs
- "loss_scale": 1024 # Loss scale
-
- "class_num": 81 # Dataset class number
- "image_shape": [300, 300] # Image height and width used as input to the model
- "mindrecord_dir": "/data/MindRecord_COCO" # MindRecord path
- "coco_root": "/data/coco2017" # COCO2017 dataset path
- "voc_root": "" # VOC original dataset path
- "image_dir": "" # Other dataset image path, if coco or voc used, it will be useless
- "anno_path": "" # Other dataset annotation path, if coco or voc used, it will be useless
+ "device_num": 1 # Use device nums
+ "lr": 0.05 # Learning rate init value
+ "dataset": coco # Dataset name
+ "epoch_size": 500 # Epoch size
+ "batch_size": 32 # Batch size of input tensor
+ "pre_trained": None # Pretrained checkpoint file path
+ "pre_trained_epoch_size": 0 # Pretrained epoch size
+ "save_checkpoint_epochs": 10 # The epoch interval between two checkpoints. By default, the checkpoint will be saved per 10 epochs
+ "loss_scale": 1024 # Loss scale
+ "filter_weight": False # Load paramters in head layer or not. If the class numbers of train dataset is different from the class numbers in pre_trained checkpoint, please set True.
+ "freeze_layer": "none" # Freeze the backbone paramters or not, support none and backbone.
+
+ "class_num": 81 # Dataset class number
+ "image_shape": [300, 300] # Image height and width used as input to the model
+ "mindrecord_dir": "/data/MindRecord_COCO" # MindRecord path
+ "coco_root": "/data/coco2017" # COCO2017 dataset path
+ "voc_root": "/data/voc_dataset" # VOC original dataset path
+ "voc_json": "annotations/voc_instances_val.json" # is the path of json file with coco format for evalution
+ "image_dir": "" # Other dataset image path, if coco or voc used, it will be useless
+ "anno_path": "" # Other dataset annotation path, if coco or voc used, it will be useless
```
## [Training Process](#contents)
-To train the model, run `train.py`. If the `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files by `coco_root`(coco dataset) or `iamge_dir` and `anno_path`(own dataset). **Note if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.**
+To train the model, run `train.py`. If the `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files by `coco_root`(coco dataset), `voc_root`(voc dataset) or `image_dir` and `anno_path`(own dataset). **Note if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.**
### Training on Ascend
@@ -292,6 +331,14 @@ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.686
mAP: 0.2244936111705981
```
+## [Export MindIR](#contents)
+
+Change the export mode and export file in `src/config.py`, and run `export.py`.
+
+```
+python export.py --run_platform [PLATFORM] --checkpoint_path [CKPT_PATH]
+```
+
# [Model Description](#contents)
## [Performance](#contents)
diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py
index d40d4a2ec8..f98b98926f 100644
--- a/model_zoo/official/cv/ssd/eval.py
+++ b/model_zoo/official/cv/ssd/eval.py
@@ -22,14 +22,15 @@ import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, ssd_mobilenet_v2
-from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
+from src.dataset import create_ssd_dataset, create_mindrecord
from src.config import config
-from src.coco_eval import metrics
+from src.eval_utils import metrics
-def ssd_eval(dataset_path, ckpt_path):
+def ssd_eval(dataset_path, ckpt_path, anno_json):
"""SSD evaluation."""
batch_size = 1
- ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False)
+ ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1,
+ is_training=False, use_multiprocessing=False)
net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
print("Load Checkpoint!")
param_dict = load_checkpoint(ckpt_path)
@@ -61,51 +62,31 @@ def ssd_eval(dataset_path, ckpt_path):
i += batch_size
cost_time = int((time.time() - start) * 1000)
print(f' 100% [{total}/{total}] cost {cost_time} ms')
- mAP = metrics(pred_data)
+ mAP = metrics(pred_data, anno_json)
print("\n========================================\n")
print(f"mAP: {mAP}")
-
-if __name__ == '__main__':
+def get_eval_args():
parser = argparse.ArgumentParser(description='SSD evaluation')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
- parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"),
- help="run platform, only support Ascend and GPU.")
- args_opt = parser.parse_args()
+ parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
+ help="run platform, support Ascend ,GPU and CPU.")
+ return parser.parse_args()
+
+if __name__ == '__main__':
+ args_opt = get_eval_args()
+ if args_opt.dataset == "coco":
+ json_path = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type))
+ elif args_opt.dataset == "voc":
+ json_path = os.path.join(config.voc_root, config.voc_json)
+ else:
+ raise ValueError('SSD eval only supprt dataset mode is coco and voc!')
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
- prefix = "ssd_eval.mindrecord"
- mindrecord_dir = config.mindrecord_dir
- mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
- if args_opt.dataset == "voc":
- config.coco_root = config.voc_root
- if not os.path.exists(mindrecord_file):
- if not os.path.isdir(mindrecord_dir):
- os.makedirs(mindrecord_dir)
- if args_opt.dataset == "coco":
- if os.path.isdir(config.coco_root):
- print("Create Mindrecord.")
- data_to_mindrecord_byte_image("coco", False, prefix)
- print("Create Mindrecord Done, at {}".format(mindrecord_dir))
- else:
- print("coco_root not exits.")
- elif args_opt.dataset == "voc":
- if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root):
- print("Create Mindrecord.")
- voc_data_to_mindrecord(mindrecord_dir, False, prefix)
- print("Create Mindrecord Done, at {}".format(mindrecord_dir))
- else:
- print("voc_root or voc_dir not exits.")
- else:
- if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
- print("Create Mindrecord.")
- data_to_mindrecord_byte_image("other", False, prefix)
- print("Create Mindrecord Done, at {}".format(mindrecord_dir))
- else:
- print("IMAGE_DIR or ANNO_PATH not exits.")
+ mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False)
print("Start Eval!")
- ssd_eval(mindrecord_file, args_opt.checkpoint_path)
+ ssd_eval(mindrecord_file, args_opt.checkpoint_path, json_path)
diff --git a/model_zoo/official/cv/ssd/export.py b/model_zoo/official/cv/ssd/export.py
new file mode 100644
index 0000000000..1d5f0087e6
--- /dev/null
+++ b/model_zoo/official/cv/ssd/export.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+ssd export mindir.
+"""
+import argparse
+import numpy as np
+from mindspore import context, Tensor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
+from src.ssd import SSD300, ssd_mobilenet_v2
+from src.config import config
+
+def get_export_args():
+ parser = argparse.ArgumentParser(description='SSD export')
+ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
+ parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
+ help="run platform, support Ascend, GPU and CPU.")
+ return parser.parse_args()
+
+if __name__ == '__main__':
+ args_opt = get_export_args()
+ context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform)
+ net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
+
+ param_dict = load_checkpoint(args_opt.checkpoint_path)
+ load_param_into_net(net, param_dict)
+ input_shp = [1, 3] + config.img_shape
+ input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
+ export(net, input_array, file_name=config.export_file, file_format=config.export_format)
diff --git a/model_zoo/official/cv/ssd/src/box_utils.py b/model_zoo/official/cv/ssd/src/box_utils.py
index 34b655d1f5..dfb2e7a03e 100644
--- a/model_zoo/official/cv/ssd/src/box_utils.py
+++ b/model_zoo/official/cv/ssd/src/box_utils.py
@@ -25,7 +25,7 @@ class GeneratDefaultBoxes():
"""
Generate Default boxes for SSD, follows the order of (W, H, archor_sizes).
`self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w].
- `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
+ `self.default_boxes_tlbr` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
"""
def __init__(self):
fk = config.img_shape[0] / np.array(config.steps)
@@ -54,17 +54,17 @@ class GeneratDefaultBoxes():
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
self.default_boxes.append([cy, cx, h, w])
- def to_ltrb(cy, cx, h, w):
+ def to_tlbr(cy, cx, h, w):
return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2
# For IoU calculation
- self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
+ self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32')
self.default_boxes = np.array(self.default_boxes, dtype='float32')
-default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
+default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr
default_boxes = GeneratDefaultBoxes().default_boxes
-y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
+y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.match_threshold
@@ -115,7 +115,7 @@ def ssd_bboxes_encode(boxes):
index = np.nonzero(t_label)
- # Transform to ltrb.
+ # Transform to tlbr.
bboxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32)
bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
diff --git a/model_zoo/official/cv/ssd/src/config.py b/model_zoo/official/cv/ssd/src/config.py
index d2d3ddcef9..a41831c0be 100644
--- a/model_zoo/official/cv/ssd/src/config.py
+++ b/model_zoo/official/cv/ssd/src/config.py
@@ -27,7 +27,6 @@ config = ed({
"max_boxes": 100,
# learing rate settings
- "global_step": 0,
"lr_init": 0.001,
"lr_end_rate": 0.001,
"warmup_epochs": 2,
@@ -55,27 +54,29 @@ config = ed({
"train_data_type": "train2017",
"val_data_type": "val2017",
"instances_set": "annotations/instances_{}.json",
- "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
- 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
- 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
- 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
- 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
- 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
- 'kite', 'baseball bat', 'baseball glove', 'skateboard',
- 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
- 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
- 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
- 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
- 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
- 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
- 'refrigerator', 'book', 'clock', 'vase', 'scissors',
- 'teddy bear', 'hair drier', 'toothbrush'),
+ "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
+ 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
+ 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard',
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
+ 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+ 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
+ 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
+ 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
+ 'refrigerator', 'book', 'clock', 'vase', 'scissors',
+ 'teddy bear', 'hair drier', 'toothbrush'),
"num_classes": 81,
# The annotation.json position of voc validation dataset.
- "voc_root": "",
+ "voc_json": "annotations/voc_instances_val.json",
# voc original dataset.
- "voc_dir": "",
+ "voc_root": "/data/voc_dataset",
# if coco or voc used, `image_dir` and `anno_path` are useless.
"image_dir": "",
"anno_path": "",
+ "export_format": "MINDIR",
+ "export_file": "ssd.mindir"
})
diff --git a/model_zoo/official/cv/ssd/src/dataset.py b/model_zoo/official/cv/ssd/src/dataset.py
index ae02edce2a..74f6344a47 100644
--- a/model_zoo/official/cv/ssd/src/dataset.py
+++ b/model_zoo/official/cv/ssd/src/dataset.py
@@ -159,10 +159,10 @@ def preprocess_fn(img_id, image, box, is_training):
def create_voc_label(is_training):
"""Get image path and annotation from VOC."""
- voc_dir = config.voc_dir
- cls_map = {name: i for i, name in enumerate(config.coco_classes)}
+ voc_root = config.voc_root
+ cls_map = {name: i for i, name in enumerate(config.classes)}
sub_dir = 'train' if is_training else 'eval'
- voc_dir = os.path.join(voc_dir, sub_dir)
+ voc_dir = os.path.join(voc_root, sub_dir)
if not os.path.isdir(voc_dir):
raise ValueError(f'Cannot find {sub_dir} dataset path.')
@@ -173,8 +173,7 @@ def create_voc_label(is_training):
anno_dir = os.path.join(voc_dir, 'Annotations')
if not is_training:
- data_dir = config.voc_root
- json_file = os.path.join(data_dir, config.instances_set.format(sub_dir))
+ json_file = os.path.join(config.voc_root, config.voc_json)
file_dir = os.path.split(json_file)[0]
if not os.path.isdir(file_dir):
os.makedirs(file_dir)
@@ -203,7 +202,7 @@ def create_voc_label(is_training):
for obj in root_node.iter('object'):
cls_name = obj.find('name').text
if cls_name not in cls_map:
- print(f'Label "{cls_name}" not in "{config.coco_classes}"')
+ print(f'Label "{cls_name}" not in "{config.classes}"')
continue
bnd_box = obj.find('bndbox')
x_min = int(bnd_box.find('xmin').text) - 1
@@ -258,7 +257,7 @@ def create_coco_label(is_training):
data_type = config.train_data_type
# Classes need to train or test.
- train_cls = config.coco_classes
+ train_cls = config.classes
train_cls_dict = {}
for i, cls in enumerate(train_cls):
train_cls_dict[cls] = i
@@ -390,7 +389,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
- is_training=True, num_parallel_workers=4):
+ is_training=True, num_parallel_workers=4, use_multiprocessing=True):
"""Creatr SSD dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
@@ -409,10 +408,45 @@ def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num
trans = [normalize_op, change_swap_op]
ds = ds.map(operations=compose_map_func, input_columns=["img_id", "image", "annotation"],
output_columns=output_columns, column_order=output_columns,
- python_multiprocessing=is_training,
+ python_multiprocessing=use_multiprocessing,
num_parallel_workers=num_parallel_workers)
- ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=is_training,
+ ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=use_multiprocessing,
num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds
+
+
+def create_mindrecord(dataset="coco", prefix="ssd.mindrecord", is_training=True):
+ print("Start create dataset!")
+
+ # It will generate mindrecord file in config.mindrecord_dir,
+ # and the file name is ssd.mindrecord0, 1, ... file_num.
+
+ mindrecord_dir = config.mindrecord_dir
+ mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
+ if not os.path.exists(mindrecord_file):
+ if not os.path.isdir(mindrecord_dir):
+ os.makedirs(mindrecord_dir)
+ if dataset == "coco":
+ if os.path.isdir(config.coco_root):
+ print("Create Mindrecord.")
+ data_to_mindrecord_byte_image("coco", is_training, prefix)
+ print("Create Mindrecord Done, at {}".format(mindrecord_dir))
+ else:
+ print("coco_root not exits.")
+ elif dataset == "voc":
+ if os.path.isdir(config.voc_root):
+ print("Create Mindrecord.")
+ voc_data_to_mindrecord(mindrecord_dir, is_training, prefix)
+ print("Create Mindrecord Done, at {}".format(mindrecord_dir))
+ else:
+ print("voc_root not exits.")
+ else:
+ if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
+ print("Create Mindrecord.")
+ data_to_mindrecord_byte_image("other", is_training, prefix)
+ print("Create Mindrecord Done, at {}".format(mindrecord_dir))
+ else:
+ print("image_dir or anno_path not exits.")
+ return mindrecord_file
diff --git a/model_zoo/official/cv/ssd/src/coco_eval.py b/model_zoo/official/cv/ssd/src/eval_utils.py
similarity index 94%
rename from model_zoo/official/cv/ssd/src/coco_eval.py
rename to model_zoo/official/cv/ssd/src/eval_utils.py
index 4c190bc5ef..180069d185 100644
--- a/model_zoo/official/cv/ssd/src/coco_eval.py
+++ b/model_zoo/official/cv/ssd/src/eval_utils.py
@@ -14,7 +14,6 @@
# ============================================================================
"""Coco metrics utils"""
-import os
import json
import numpy as np
from .config import config
@@ -56,22 +55,17 @@ def apply_nms(all_boxes, all_scores, thres, max_boxes):
return keep
-def metrics(pred_data):
+def metrics(pred_data, anno_json):
"""Calculate mAP of predicted bboxes."""
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
num_classes = config.num_classes
- coco_root = config.coco_root
- data_type = config.val_data_type
-
#Classes need to train or test.
- val_cls = config.coco_classes
+ val_cls = config.classes
val_cls_dict = {}
for i, cls in enumerate(val_cls):
val_cls_dict[i] = cls
-
- anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
coco_gt = COCO(anno_json)
classs_dict = {}
cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py
index c18dc72f77..2094077e25 100644
--- a/model_zoo/official/cv/ssd/train.py
+++ b/model_zoo/official/cv/ssd/train.py
@@ -15,7 +15,6 @@
"""Train SSD and get checkpoint files."""
-import os
import argparse
import ast
import mindspore.nn as nn
@@ -28,14 +27,16 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed, dtype
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from src.config import config
-from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
+from src.dataset import create_ssd_dataset, create_mindrecord
from src.lr_schedule import get_lr
from src.init_params import init_net_param, filter_checkpoint_parameter
set_seed(1)
-def main():
+def get_args():
parser = argparse.ArgumentParser(description="SSD training")
+ parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
+ help="run platform, support Ascend, GPU and CPU.")
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False,
help="If set it true, only create Mindrecord, default is False.")
parser.add_argument("--distribute", type=ast.literal_eval, default=False,
@@ -52,77 +53,39 @@ def main():
parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 10.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
- help="Filter weight parameters, default is False.")
- parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"),
- help="run platform, only support Ascend and GPU.")
+ help="Filter head weight parameters, default is False.")
+ parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"],
+ help="freeze the weights of network, support freeze the backbone's weights, "
+ "default is not freezing.")
args_opt = parser.parse_args()
+ return args_opt
- if args_opt.run_platform == "Ascend":
- context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
+def main():
+ args_opt = get_args()
+ rank = 0
+ device_num = 1
+ if args_opt.run_platform == "CPU":
+ context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+ else:
+ context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
if args_opt.distribute:
device_num = args_opt.device_num
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
init()
- rank = args_opt.device_id % device_num
- else:
- rank = 0
- device_num = 1
- elif args_opt.run_platform == "GPU":
- context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id)
- init()
- if args_opt.distribute:
- device_num = args_opt.device_num
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
- device_num=device_num)
rank = get_rank()
- else:
- rank = 0
- device_num = 1
- else:
- raise ValueError("Unsupported platform.")
-
- print("Start create dataset!")
-
- # It will generate mindrecord file in args_opt.mindrecord_dir,
- # and the file name is ssd.mindrecord0, 1, ... file_num.
-
- prefix = "ssd.mindrecord"
- mindrecord_dir = config.mindrecord_dir
- mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
- if not os.path.exists(mindrecord_file):
- if not os.path.isdir(mindrecord_dir):
- os.makedirs(mindrecord_dir)
- if args_opt.dataset == "coco":
- if os.path.isdir(config.coco_root):
- print("Create Mindrecord.")
- data_to_mindrecord_byte_image("coco", True, prefix)
- print("Create Mindrecord Done, at {}".format(mindrecord_dir))
- else:
- print("coco_root not exits.")
- elif args_opt.dataset == "voc":
- if os.path.isdir(config.voc_dir):
- print("Create Mindrecord.")
- voc_data_to_mindrecord(mindrecord_dir, True, prefix)
- print("Create Mindrecord Done, at {}".format(mindrecord_dir))
- else:
- print("voc_dir not exits.")
- else:
- if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
- print("Create Mindrecord.")
- data_to_mindrecord_byte_image("other", True, prefix)
- print("Create Mindrecord Done, at {}".format(mindrecord_dir))
- else:
- print("image_dir or anno_path not exits.")
+ mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True)
if not args_opt.only_create_dataset:
loss_scale = float(args_opt.loss_scale)
+ if args_opt.run_platform == "CPU":
+ loss_scale = 1.0
# When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0.
- dataset = create_ssd_dataset(mindrecord_file, repeat_num=1,
- batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
+ use_multiprocessing = (args_opt.run_platform != "CPU")
+ dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size,
+ device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
@@ -140,27 +103,30 @@ def main():
ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config)
if args_opt.pre_trained:
- if args_opt.pre_trained_epoch_size <= 0:
- raise KeyError("pre_trained_epoch_size must be greater than 0.")
param_dict = load_checkpoint(args_opt.pre_trained)
if args_opt.filter_weight:
filter_checkpoint_parameter(param_dict)
load_param_into_net(net, param_dict)
- lr = Tensor(get_lr(global_step=config.global_step,
+ if args_opt.freeze_layer == "backbone":
+ for param in backbone.feature_1.trainable_params():
+ param.requires_grad = False
+
+ lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size,
lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
warmup_epochs=config.warmup_epochs,
total_epochs=args_opt.epoch_size,
steps_per_epoch=dataset_size))
+
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
config.momentum, config.weight_decay, loss_scale)
+
net = TrainingWrapper(net, opt, loss_scale)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
-
model = Model(net)
dataset_sink_mode = False
- if args_opt.mode == "sink":
+ if args_opt.mode == "sink" and args_opt.run_platform != "CPU":
print("In sink mode, one epoch return a loss.")
dataset_sink_mode = True
print("Start train SSD, the first epoch will be slower because of the graph compilation.")