diff --git a/model_zoo/official/cv/ssd/README.md b/model_zoo/official/cv/ssd/README.md index 0d9ee3c707..1105b42fc1 100644 --- a/model_zoo/official/cv/ssd/README.md +++ b/model_zoo/official/cv/ssd/README.md @@ -12,6 +12,7 @@ - [Training](#training) - [Evaluation Process](#evaluation-process) - [Evaluation](#evaluation) + - [Export MindIR](#export-mindir) - [Model Description](#model-description) - [Performance](#performance) - [Evaluation Performance](#evaluation-performance) @@ -49,21 +50,23 @@ Dataset used: [COCO2017]() - Download the dataset COCO2017. - We use COCO2017 as training dataset in this example by default, and you can also use your own datasets. + First, install Cython ,pycocotool and opencv to process data and to get evaluation result. - 1. If coco dataset is used. **Select dataset to coco when run script.** - Install Cython and pycocotool, and you can also install mmcv to process data. + ``` + pip install Cython - ``` - pip install Cython + pip install pycocotools - pip install pycocotools + pip install opencv-python - ``` - And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: + ``` + 1. If coco dataset is used. **Select dataset to coco when run script.** + + Change the `coco_root` and other settings you need in `src/config.py`. The directory structure is as follows: ``` . - └─cocodataset + └─coco_dataset ├─annotations ├─instance_train2017.json └─instance_val2017.json @@ -72,7 +75,27 @@ Dataset used: [COCO2017]() ``` - 2. If your own dataset is used. **Select dataset to other when run script.** + 2. If VOC dataset is used. **Select dataset to voc when run script.** + Change `classes`, `num_classes`, `voc_json` and `voc_root` in `src/config.py`. `voc_json` is the path of json file with coco format for evalution, `voc_root` is the path of VOC dataset, the directory structure is as follows: + ``` + . + └─voc_dataset + └─train + ├─0001.jpg + └─0001.xml + ... + ├─xxxx.jpg + └─xxxx.xml + └─eval + ├─0001.jpg + └─0001.xml + ... + ├─xxxx.jpg + └─xxxx.xml + + ``` + + 3. If your own dataset is used. **Select dataset to other when run script.** Organize the dataset infomation into a TXT file, each row in the file is as follows: ``` @@ -80,7 +103,7 @@ Dataset used: [COCO2017]() ``` - Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`. + Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `image_dir`(dataset directory) and the relative path in `anno_path`(the TXT file path), `image_dir` and `anno_path` are setting in `src/config.py`. # [Quick Start](#contents) @@ -103,7 +126,19 @@ sh run_distribute_train_gpu.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID] ``` -# [Script Description](#contents) +- runing on CPU(support Windows and Ubuntu) + +**CPU is usually used for fine-tuning, which needs pre_trained checkpoint.** + +``` +# training on CPU +python train.py --run_platform=CPU --lr=[LR] --dataset=[DATASET] --epoch_size=[EPOCH_SIZE] --batch_size=[BATCH_SIZE] --pre_trained=[PRETRAINED_CKPT] --filter_weight=True --save_checkpoint_epochs=1 + +# run eval on GPU +python eval.py --run_platform=CPU --dataset=[DATASET] --checkpoint_path=[PRETRAINED_CKPT] +``` + +# [Script Description](#contents) ## [Script and Sample Code](#contents) @@ -111,24 +146,25 @@ sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID] . └─ cv └─ ssd - ├─ README.md ## descriptions about SSD + ├─ README.md # descriptions about SSD ├─ scripts - ├─ run_distribute_train.sh ## shell script for distributed on ascend - ├─ run_distribute_train_gpu.sh ## shell script for distributed on gpu - ├─ run_eval.sh ## shell script for eval on ascend - └─ run_eval_gpu.sh ## shell script for eval on gpu + ├─ run_distribute_train.sh # shell script for distributed on ascend + ├─ run_distribute_train_gpu.sh # shell script for distributed on gpu + ├─ run_eval.sh # shell script for eval on ascend + └─ run_eval_gpu.sh # shell script for eval on gpu ├─ src - ├─ __init__.py ## init file - ├─ box_util.py ## bbox utils - ├─ coco_eval.py ## coco metrics utils - ├─ config.py ## total config - ├─ dataset.py ## create dataset and process dataset - ├─ init_params.py ## parameters utils - ├─ lr_schedule.py ## learning ratio generator - └─ ssd.py ## ssd architecture - ├─ eval.py ## eval scripts - ├─ train.py ## train scripts - └─ mindspore_hub_conf.py ## mindspore hub interface + ├─ __init__.py # init file + ├─ box_utils.py # bbox utils + ├─ eval_utils.py # metrics utils + ├─ config.py # total config + ├─ dataset.py # create dataset and process dataset + ├─ init_params.py # parameters utils + ├─ lr_schedule.py # learning ratio generator + └─ ssd.py # ssd architecture + ├─ eval.py # eval scripts + ├─ train.py # train scripts + ├─ export.py # export mindir script + └─ mindspore_hub_conf.py # mindspore hub interface ``` ## [Script Parameters](#contents) @@ -136,30 +172,33 @@ sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID] ``` Major parameters in train.py and config.py as follows: - "device_num": 1 # Use device nums - "lr": 0.05 # Learning rate init value - "dataset": coco # Dataset name - "epoch_size": 500 # Epoch size - "batch_size": 32 # Batch size of input tensor - "pre_trained": None # Pretrained checkpoint file path - "pre_trained_epoch_size": 0 # Pretrained epoch size - "save_checkpoint_epochs": 10 # The epoch interval between two checkpoints. By default, the checkpoint will be saved per 10 epochs - "loss_scale": 1024 # Loss scale - - "class_num": 81 # Dataset class number - "image_shape": [300, 300] # Image height and width used as input to the model - "mindrecord_dir": "/data/MindRecord_COCO" # MindRecord path - "coco_root": "/data/coco2017" # COCO2017 dataset path - "voc_root": "" # VOC original dataset path - "image_dir": "" # Other dataset image path, if coco or voc used, it will be useless - "anno_path": "" # Other dataset annotation path, if coco or voc used, it will be useless + "device_num": 1 # Use device nums + "lr": 0.05 # Learning rate init value + "dataset": coco # Dataset name + "epoch_size": 500 # Epoch size + "batch_size": 32 # Batch size of input tensor + "pre_trained": None # Pretrained checkpoint file path + "pre_trained_epoch_size": 0 # Pretrained epoch size + "save_checkpoint_epochs": 10 # The epoch interval between two checkpoints. By default, the checkpoint will be saved per 10 epochs + "loss_scale": 1024 # Loss scale + "filter_weight": False # Load paramters in head layer or not. If the class numbers of train dataset is different from the class numbers in pre_trained checkpoint, please set True. + "freeze_layer": "none" # Freeze the backbone paramters or not, support none and backbone. + + "class_num": 81 # Dataset class number + "image_shape": [300, 300] # Image height and width used as input to the model + "mindrecord_dir": "/data/MindRecord_COCO" # MindRecord path + "coco_root": "/data/coco2017" # COCO2017 dataset path + "voc_root": "/data/voc_dataset" # VOC original dataset path + "voc_json": "annotations/voc_instances_val.json" # is the path of json file with coco format for evalution + "image_dir": "" # Other dataset image path, if coco or voc used, it will be useless + "anno_path": "" # Other dataset annotation path, if coco or voc used, it will be useless ``` ## [Training Process](#contents) -To train the model, run `train.py`. If the `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files by `coco_root`(coco dataset) or `iamge_dir` and `anno_path`(own dataset). **Note if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.** +To train the model, run `train.py`. If the `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files by `coco_root`(coco dataset), `voc_root`(voc dataset) or `image_dir` and `anno_path`(own dataset). **Note if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.** ### Training on Ascend @@ -292,6 +331,14 @@ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.686 mAP: 0.2244936111705981 ``` +## [Export MindIR](#contents) + +Change the export mode and export file in `src/config.py`, and run `export.py`. + +``` +python export.py --run_platform [PLATFORM] --checkpoint_path [CKPT_PATH] +``` + # [Model Description](#contents) ## [Performance](#contents) diff --git a/model_zoo/official/cv/ssd/eval.py b/model_zoo/official/cv/ssd/eval.py index d40d4a2ec8..f98b98926f 100644 --- a/model_zoo/official/cv/ssd/eval.py +++ b/model_zoo/official/cv/ssd/eval.py @@ -22,14 +22,15 @@ import numpy as np from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.ssd import SSD300, ssd_mobilenet_v2 -from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord +from src.dataset import create_ssd_dataset, create_mindrecord from src.config import config -from src.coco_eval import metrics +from src.eval_utils import metrics -def ssd_eval(dataset_path, ckpt_path): +def ssd_eval(dataset_path, ckpt_path, anno_json): """SSD evaluation.""" batch_size = 1 - ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False) + ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, + is_training=False, use_multiprocessing=False) net = SSD300(ssd_mobilenet_v2(), config, is_training=False) print("Load Checkpoint!") param_dict = load_checkpoint(ckpt_path) @@ -61,51 +62,31 @@ def ssd_eval(dataset_path, ckpt_path): i += batch_size cost_time = int((time.time() - start) * 1000) print(f' 100% [{total}/{total}] cost {cost_time} ms') - mAP = metrics(pred_data) + mAP = metrics(pred_data, anno_json) print("\n========================================\n") print(f"mAP: {mAP}") - -if __name__ == '__main__': +def get_eval_args(): parser = argparse.ArgumentParser(description='SSD evaluation') parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") - parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"), - help="run platform, only support Ascend and GPU.") - args_opt = parser.parse_args() + parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), + help="run platform, support Ascend ,GPU and CPU.") + return parser.parse_args() + +if __name__ == '__main__': + args_opt = get_eval_args() + if args_opt.dataset == "coco": + json_path = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type)) + elif args_opt.dataset == "voc": + json_path = os.path.join(config.voc_root, config.voc_json) + else: + raise ValueError('SSD eval only supprt dataset mode is coco and voc!') context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id) - prefix = "ssd_eval.mindrecord" - mindrecord_dir = config.mindrecord_dir - mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") - if args_opt.dataset == "voc": - config.coco_root = config.voc_root - if not os.path.exists(mindrecord_file): - if not os.path.isdir(mindrecord_dir): - os.makedirs(mindrecord_dir) - if args_opt.dataset == "coco": - if os.path.isdir(config.coco_root): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("coco", False, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("coco_root not exits.") - elif args_opt.dataset == "voc": - if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root): - print("Create Mindrecord.") - voc_data_to_mindrecord(mindrecord_dir, False, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("voc_root or voc_dir not exits.") - else: - if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("other", False, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("IMAGE_DIR or ANNO_PATH not exits.") + mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False) print("Start Eval!") - ssd_eval(mindrecord_file, args_opt.checkpoint_path) + ssd_eval(mindrecord_file, args_opt.checkpoint_path, json_path) diff --git a/model_zoo/official/cv/ssd/export.py b/model_zoo/official/cv/ssd/export.py new file mode 100644 index 0000000000..1d5f0087e6 --- /dev/null +++ b/model_zoo/official/cv/ssd/export.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +ssd export mindir. +""" +import argparse +import numpy as np +from mindspore import context, Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export +from src.ssd import SSD300, ssd_mobilenet_v2 +from src.config import config + +def get_export_args(): + parser = argparse.ArgumentParser(description='SSD export') + parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") + parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), + help="run platform, support Ascend, GPU and CPU.") + return parser.parse_args() + +if __name__ == '__main__': + args_opt = get_export_args() + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform) + net = SSD300(ssd_mobilenet_v2(), config, is_training=False) + + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(net, param_dict) + input_shp = [1, 3] + config.img_shape + input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32)) + export(net, input_array, file_name=config.export_file, file_format=config.export_format) diff --git a/model_zoo/official/cv/ssd/src/box_utils.py b/model_zoo/official/cv/ssd/src/box_utils.py index 34b655d1f5..dfb2e7a03e 100644 --- a/model_zoo/official/cv/ssd/src/box_utils.py +++ b/model_zoo/official/cv/ssd/src/box_utils.py @@ -25,7 +25,7 @@ class GeneratDefaultBoxes(): """ Generate Default boxes for SSD, follows the order of (W, H, archor_sizes). `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w]. - `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2]. + `self.default_boxes_tlbr` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2]. """ def __init__(self): fk = config.img_shape[0] / np.array(config.steps) @@ -54,17 +54,17 @@ class GeneratDefaultBoxes(): cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] self.default_boxes.append([cy, cx, h, w]) - def to_ltrb(cy, cx, h, w): + def to_tlbr(cy, cx, h, w): return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2 # For IoU calculation - self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32') + self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32') self.default_boxes = np.array(self.default_boxes, dtype='float32') -default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb +default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr default_boxes = GeneratDefaultBoxes().default_boxes -y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1) +y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1) vol_anchors = (x2 - x1) * (y2 - y1) matching_threshold = config.match_threshold @@ -115,7 +115,7 @@ def ssd_bboxes_encode(boxes): index = np.nonzero(t_label) - # Transform to ltrb. + # Transform to tlbr. bboxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32) bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2 bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]] diff --git a/model_zoo/official/cv/ssd/src/config.py b/model_zoo/official/cv/ssd/src/config.py index d2d3ddcef9..a41831c0be 100644 --- a/model_zoo/official/cv/ssd/src/config.py +++ b/model_zoo/official/cv/ssd/src/config.py @@ -27,7 +27,6 @@ config = ed({ "max_boxes": 100, # learing rate settings - "global_step": 0, "lr_init": 0.001, "lr_end_rate": 0.001, "warmup_epochs": 2, @@ -55,27 +54,29 @@ config = ed({ "train_data_type": "train2017", "val_data_type": "val2017", "instances_set": "annotations/instances_{}.json", - "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', - 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', - 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', - 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', - 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', - 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', - 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', - 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', - 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', - 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', - 'teddy bear', 'hair drier', 'toothbrush'), + "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'), "num_classes": 81, # The annotation.json position of voc validation dataset. - "voc_root": "", + "voc_json": "annotations/voc_instances_val.json", # voc original dataset. - "voc_dir": "", + "voc_root": "/data/voc_dataset", # if coco or voc used, `image_dir` and `anno_path` are useless. "image_dir": "", "anno_path": "", + "export_format": "MINDIR", + "export_file": "ssd.mindir" }) diff --git a/model_zoo/official/cv/ssd/src/dataset.py b/model_zoo/official/cv/ssd/src/dataset.py index ae02edce2a..74f6344a47 100644 --- a/model_zoo/official/cv/ssd/src/dataset.py +++ b/model_zoo/official/cv/ssd/src/dataset.py @@ -159,10 +159,10 @@ def preprocess_fn(img_id, image, box, is_training): def create_voc_label(is_training): """Get image path and annotation from VOC.""" - voc_dir = config.voc_dir - cls_map = {name: i for i, name in enumerate(config.coco_classes)} + voc_root = config.voc_root + cls_map = {name: i for i, name in enumerate(config.classes)} sub_dir = 'train' if is_training else 'eval' - voc_dir = os.path.join(voc_dir, sub_dir) + voc_dir = os.path.join(voc_root, sub_dir) if not os.path.isdir(voc_dir): raise ValueError(f'Cannot find {sub_dir} dataset path.') @@ -173,8 +173,7 @@ def create_voc_label(is_training): anno_dir = os.path.join(voc_dir, 'Annotations') if not is_training: - data_dir = config.voc_root - json_file = os.path.join(data_dir, config.instances_set.format(sub_dir)) + json_file = os.path.join(config.voc_root, config.voc_json) file_dir = os.path.split(json_file)[0] if not os.path.isdir(file_dir): os.makedirs(file_dir) @@ -203,7 +202,7 @@ def create_voc_label(is_training): for obj in root_node.iter('object'): cls_name = obj.find('name').text if cls_name not in cls_map: - print(f'Label "{cls_name}" not in "{config.coco_classes}"') + print(f'Label "{cls_name}" not in "{config.classes}"') continue bnd_box = obj.find('bndbox') x_min = int(bnd_box.find('xmin').text) - 1 @@ -258,7 +257,7 @@ def create_coco_label(is_training): data_type = config.train_data_type # Classes need to train or test. - train_cls = config.coco_classes + train_cls = config.classes train_cls_dict = {} for i, cls in enumerate(train_cls): train_cls_dict[cls] = i @@ -390,7 +389,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd. def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, - is_training=True, num_parallel_workers=4): + is_training=True, num_parallel_workers=4, use_multiprocessing=True): """Creatr SSD dataset with MindDataset.""" ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) @@ -409,10 +408,45 @@ def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num trans = [normalize_op, change_swap_op] ds = ds.map(operations=compose_map_func, input_columns=["img_id", "image", "annotation"], output_columns=output_columns, column_order=output_columns, - python_multiprocessing=is_training, + python_multiprocessing=use_multiprocessing, num_parallel_workers=num_parallel_workers) - ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=is_training, + ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=use_multiprocessing, num_parallel_workers=num_parallel_workers) ds = ds.batch(batch_size, drop_remainder=True) ds = ds.repeat(repeat_num) return ds + + +def create_mindrecord(dataset="coco", prefix="ssd.mindrecord", is_training=True): + print("Start create dataset!") + + # It will generate mindrecord file in config.mindrecord_dir, + # and the file name is ssd.mindrecord0, 1, ... file_num. + + mindrecord_dir = config.mindrecord_dir + mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") + if not os.path.exists(mindrecord_file): + if not os.path.isdir(mindrecord_dir): + os.makedirs(mindrecord_dir) + if dataset == "coco": + if os.path.isdir(config.coco_root): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("coco", is_training, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("coco_root not exits.") + elif dataset == "voc": + if os.path.isdir(config.voc_root): + print("Create Mindrecord.") + voc_data_to_mindrecord(mindrecord_dir, is_training, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("voc_root not exits.") + else: + if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): + print("Create Mindrecord.") + data_to_mindrecord_byte_image("other", is_training, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("image_dir or anno_path not exits.") + return mindrecord_file diff --git a/model_zoo/official/cv/ssd/src/coco_eval.py b/model_zoo/official/cv/ssd/src/eval_utils.py similarity index 94% rename from model_zoo/official/cv/ssd/src/coco_eval.py rename to model_zoo/official/cv/ssd/src/eval_utils.py index 4c190bc5ef..180069d185 100644 --- a/model_zoo/official/cv/ssd/src/coco_eval.py +++ b/model_zoo/official/cv/ssd/src/eval_utils.py @@ -14,7 +14,6 @@ # ============================================================================ """Coco metrics utils""" -import os import json import numpy as np from .config import config @@ -56,22 +55,17 @@ def apply_nms(all_boxes, all_scores, thres, max_boxes): return keep -def metrics(pred_data): +def metrics(pred_data, anno_json): """Calculate mAP of predicted bboxes.""" from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval num_classes = config.num_classes - coco_root = config.coco_root - data_type = config.val_data_type - #Classes need to train or test. - val_cls = config.coco_classes + val_cls = config.classes val_cls_dict = {} for i, cls in enumerate(val_cls): val_cls_dict[i] = cls - - anno_json = os.path.join(coco_root, config.instances_set.format(data_type)) coco_gt = COCO(anno_json) classs_dict = {} cat_ids = coco_gt.loadCats(coco_gt.getCatIds()) diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py index c18dc72f77..2094077e25 100644 --- a/model_zoo/official/cv/ssd/train.py +++ b/model_zoo/official/cv/ssd/train.py @@ -15,7 +15,6 @@ """Train SSD and get checkpoint files.""" -import os import argparse import ast import mindspore.nn as nn @@ -28,14 +27,16 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed, dtype from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 from src.config import config -from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord +from src.dataset import create_ssd_dataset, create_mindrecord from src.lr_schedule import get_lr from src.init_params import init_net_param, filter_checkpoint_parameter set_seed(1) -def main(): +def get_args(): parser = argparse.ArgumentParser(description="SSD training") + parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), + help="run platform, support Ascend, GPU and CPU.") parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, help="If set it true, only create Mindrecord, default is False.") parser.add_argument("--distribute", type=ast.literal_eval, default=False, @@ -52,77 +53,39 @@ def main(): parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 10.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, - help="Filter weight parameters, default is False.") - parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"), - help="run platform, only support Ascend and GPU.") + help="Filter head weight parameters, default is False.") + parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"], + help="freeze the weights of network, support freeze the backbone's weights, " + "default is not freezing.") args_opt = parser.parse_args() + return args_opt - if args_opt.run_platform == "Ascend": - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) +def main(): + args_opt = get_args() + rank = 0 + device_num = 1 + if args_opt.run_platform == "CPU": + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + else: + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id) if args_opt.distribute: device_num = args_opt.device_num context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num) init() - rank = args_opt.device_id % device_num - else: - rank = 0 - device_num = 1 - elif args_opt.run_platform == "GPU": - context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id) - init() - if args_opt.distribute: - device_num = args_opt.device_num - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, - device_num=device_num) rank = get_rank() - else: - rank = 0 - device_num = 1 - else: - raise ValueError("Unsupported platform.") - - print("Start create dataset!") - - # It will generate mindrecord file in args_opt.mindrecord_dir, - # and the file name is ssd.mindrecord0, 1, ... file_num. - - prefix = "ssd.mindrecord" - mindrecord_dir = config.mindrecord_dir - mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") - if not os.path.exists(mindrecord_file): - if not os.path.isdir(mindrecord_dir): - os.makedirs(mindrecord_dir) - if args_opt.dataset == "coco": - if os.path.isdir(config.coco_root): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("coco", True, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("coco_root not exits.") - elif args_opt.dataset == "voc": - if os.path.isdir(config.voc_dir): - print("Create Mindrecord.") - voc_data_to_mindrecord(mindrecord_dir, True, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("voc_dir not exits.") - else: - if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): - print("Create Mindrecord.") - data_to_mindrecord_byte_image("other", True, prefix) - print("Create Mindrecord Done, at {}".format(mindrecord_dir)) - else: - print("image_dir or anno_path not exits.") + mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True) if not args_opt.only_create_dataset: loss_scale = float(args_opt.loss_scale) + if args_opt.run_platform == "CPU": + loss_scale = 1.0 # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. - dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, - batch_size=args_opt.batch_size, device_num=device_num, rank=rank) + use_multiprocessing = (args_opt.run_platform != "CPU") + dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size, + device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) dataset_size = dataset.get_dataset_size() print("Create dataset done!") @@ -140,27 +103,30 @@ def main(): ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config) if args_opt.pre_trained: - if args_opt.pre_trained_epoch_size <= 0: - raise KeyError("pre_trained_epoch_size must be greater than 0.") param_dict = load_checkpoint(args_opt.pre_trained) if args_opt.filter_weight: filter_checkpoint_parameter(param_dict) load_param_into_net(net, param_dict) - lr = Tensor(get_lr(global_step=config.global_step, + if args_opt.freeze_layer == "backbone": + for param in backbone.feature_1.trainable_params(): + param.requires_grad = False + + lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, warmup_epochs=config.warmup_epochs, total_epochs=args_opt.epoch_size, steps_per_epoch=dataset_size)) + opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, loss_scale) + net = TrainingWrapper(net, opt, loss_scale) callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] - model = Model(net) dataset_sink_mode = False - if args_opt.mode == "sink": + if args_opt.mode == "sink" and args_opt.run_platform != "CPU": print("In sink mode, one epoch return a loss.") dataset_sink_mode = True print("Start train SSD, the first epoch will be slower because of the graph compilation.")