Merge pull request !7948 from zhaoting/ssd
@@ -12,6 +12,7 @@
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Export MindIR](#export-mindir)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
@@ -49,21 +50,23 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
- Download the dataset COCO2017.
- We use COCO2017 as the training dataset in this example by default, and you can also use your own datasets.

First, install Cython, pycocotools and opencv-python to process the data and to get the evaluation result.

```
pip install Cython
pip install pycocotools
pip install opencv-python
```

1. If coco dataset is used. **Select dataset to coco when running the script.**
    Change the `coco_root` and other settings you need in `src/config.py`. The directory structure is as follows:
    ```
    .
    └─coco_dataset
      ├─annotations
        ├─instance_train2017.json
        └─instance_val2017.json
@@ -72,7 +75,27 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
    ```
2. If VOC dataset is used. **Select dataset to voc when running the script.**
    Change `classes`, `num_classes`, `voc_json` and `voc_root` in `src/config.py`. `voc_json` is the path of the json file in coco format used for evaluation, and `voc_root` is the path of the VOC dataset; the directory structure is as follows:
    ```
    .
    └─voc_dataset
      └─train
        ├─0001.jpg
        └─0001.xml
        ...
        ├─xxxx.jpg
        └─xxxx.xml
      └─eval
        ├─0001.jpg
        └─0001.xml
        ...
        ├─xxxx.jpg
        └─xxxx.xml
    ```
3. If your own dataset is used. **Select dataset to other when running the script.**
    Organize the dataset information into a TXT file; each row in the file is as follows:
    ```
@@ -80,7 +103,7 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
    ```
    Each row is an image annotation split by spaces: the first column is a relative path to an image, and the others are boxes and class information in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path obtained by joining `image_dir` (the dataset directory) with the relative path in `anno_path` (the TXT file path); `image_dir` and `anno_path` are set in `src/config.py`. An illustrative parsing sketch is given below.
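    For illustration only, a hypothetical annotation row and a minimal parsing sketch are shown below. The file name and box values are made up, and the comma-separated layout of each box field is an assumption; see `src/dataset.py` for the actual parsing logic.
    ```
    # Hypothetical row: a relative image path followed by space-separated boxes,
    # each box written as xmin,ymin,xmax,ymax,class (values invented for illustration):
    #   images/0001.jpg 100,50,300,200,1 20,30,120,180,15

    def parse_row(row):
        """Split one annotation row into an image path and [xmin, ymin, xmax, ymax, class] boxes."""
        parts = row.strip().split(" ")
        image_rel_path = parts[0]
        boxes = [[int(v) for v in field.split(",")] for field in parts[1:]]
        return image_rel_path, boxes
    ```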
# [Quick Start](#contents)
@@ -103,7 +126,19 @@ sh run_distribute_train_gpu.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET]
sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
```
- running on CPU (supports Windows and Ubuntu)

**CPU is usually used for fine-tuning, which needs a pre_trained checkpoint.**
```
# training on CPU
python train.py --run_platform=CPU --lr=[LR] --dataset=[DATASET] --epoch_size=[EPOCH_SIZE] --batch_size=[BATCH_SIZE] --pre_trained=[PRETRAINED_CKPT] --filter_weight=True --save_checkpoint_epochs=1

# run eval on CPU
python eval.py --run_platform=CPU --dataset=[DATASET] --checkpoint_path=[PRETRAINED_CKPT]
```
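The `--filter_weight=True` option matters when the fine-tuning dataset has a different number of classes than the pre_trained checkpoint: the checkpoint's detection-head weights are dropped before loading so that the head is re-initialized for the new `num_classes`. A rough sketch of the idea is shown below; the real logic lives in `src/init_params.py` (`filter_checkpoint_parameter`), and the parameter-name keywords used here are hypothetical.

```
# Sketch only: drop head parameters from a loaded checkpoint dict so they are
# re-initialized when num_classes changes. The keyword substrings are invented
# for illustration; see src/init_params.py for the real implementation.
def filter_head_params(param_dict, head_keywords=("multi_loc_layers", "multi_cls_layers")):
    for name in list(param_dict.keys()):
        if any(keyword in name for keyword in head_keywords):
            del param_dict[name]
    return param_dict
```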
# [Script Description](#contents)
## [Script and Sample Code](#contents)
@@ -111,24 +146,25 @@ sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
.
└─ cv
  └─ ssd
    ├─ README.md                      # descriptions about SSD
    ├─ scripts
      ├─ run_distribute_train.sh      # shell script for distributed training on ascend
      ├─ run_distribute_train_gpu.sh  # shell script for distributed training on gpu
      ├─ run_eval.sh                  # shell script for eval on ascend
      └─ run_eval_gpu.sh              # shell script for eval on gpu
    ├─ src
      ├─ __init__.py                  # init file
      ├─ box_utils.py                 # bbox utils
      ├─ eval_utils.py                # metrics utils
      ├─ config.py                    # total config
      ├─ dataset.py                   # create dataset and process dataset
      ├─ init_params.py               # parameters utils
      ├─ lr_schedule.py               # learning rate generator
      └─ ssd.py                       # ssd architecture
    ├─ eval.py                        # eval scripts
    ├─ train.py                       # train scripts
    ├─ export.py                      # export mindir script
    └─ mindspore_hub_conf.py          # mindspore hub interface
```
## [Script Parameters](#contents)
@@ -136,30 +172,33 @@ sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
```
Major parameters in train.py and config.py are as follows:

"device_num": 1                                  # Number of devices to use
"lr": 0.05                                       # Initial learning rate
"dataset": coco                                  # Dataset name
"epoch_size": 500                                # Epoch size
"batch_size": 32                                 # Batch size of input tensor
"pre_trained": None                              # Pretrained checkpoint file path
"pre_trained_epoch_size": 0                      # Pretrained epoch size
"save_checkpoint_epochs": 10                     # The epoch interval between two checkpoints. By default, the checkpoint is saved every 10 epochs
"loss_scale": 1024                               # Loss scale
"filter_weight": False                           # Whether to load the parameters of the head layer. If the number of classes in the training dataset differs from that in the pre_trained checkpoint, set it to True
"freeze_layer": "none"                           # Whether to freeze the backbone parameters; supports "none" and "backbone"
"class_num": 81                                  # Dataset class number
"image_shape": [300, 300]                        # Image height and width used as input to the model
"mindrecord_dir": "/data/MindRecord_COCO"        # MindRecord path
"coco_root": "/data/coco2017"                    # COCO2017 dataset path
"voc_root": "/data/voc_dataset"                  # VOC original dataset path
"voc_json": "annotations/voc_instances_val.json" # Path of the json file in coco format used for evaluation
"image_dir": ""                                  # Other dataset image path; useless if coco or voc is used
"anno_path": ""                                  # Other dataset annotation path; useless if coco or voc is used
```
## [Training Process](#contents)

To train the model, run `train.py`. If `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files from `coco_root` (coco dataset), `voc_root` (voc dataset) or `image_dir` and `anno_path` (own dataset). **Note that if `mindrecord_dir` isn't empty, it will use `mindrecord_dir` instead of the raw images.**
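If you only want to prepare the MindRecord files once and reuse them in later runs, the `--only_create_dataset` flag of `train.py` can be used; a rough example is shown below (the argument values are placeholders, and the exact argument set accepted by your version of the script can be checked with `python train.py -h`):

```
# Generate MindRecord files without starting training.
python train.py --only_create_dataset=True --dataset=coco
```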
### Training on Ascend
@@ -292,6 +331,14 @@ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.686
mAP: 0.2244936111705981
```
## [Export MindIR](#contents)

Change the export format and export file name (`export_format` and `export_file`) in `src/config.py`, and run `export.py`.

```
python export.py --run_platform [PLATFORM] --checkpoint_path [CKPT_PATH]
```
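For reference, the relevant keys added to `src/config.py` in this change are shown below; `MINDIR` is the default format, and other formats accepted by `mindspore.train.serialization.export` may also work depending on your MindSpore version.

```
"export_format": "MINDIR",   # file_format passed to export()
"export_file": "ssd.mindir"  # output file name
```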
# [Model Description](#contents)

## [Performance](#contents)
@@ -22,14 +22,15 @@ import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, ssd_mobilenet_v2
from src.dataset import create_ssd_dataset, create_mindrecord
from src.config import config
from src.eval_utils import metrics

def ssd_eval(dataset_path, ckpt_path, anno_json):
    """SSD evaluation."""
    batch_size = 1
    ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1,
                            is_training=False, use_multiprocessing=False)
    net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
    print("Load Checkpoint!")
    param_dict = load_checkpoint(ckpt_path)
@@ -61,51 +62,31 @@ def ssd_eval(dataset_path, ckpt_path, anno_json):
        i += batch_size
    cost_time = int((time.time() - start) * 1000)
    print(f'    100% [{total}/{total}] cost {cost_time} ms')
    mAP = metrics(pred_data, anno_json)
    print("\n========================================\n")
    print(f"mAP: {mAP}")

def get_eval_args():
    parser = argparse.ArgumentParser(description='SSD evaluation')
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
                        help="run platform, support Ascend, GPU and CPU.")
    return parser.parse_args()

if __name__ == '__main__':
    args_opt = get_eval_args()
    if args_opt.dataset == "coco":
        json_path = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type))
    elif args_opt.dataset == "voc":
        json_path = os.path.join(config.voc_root, config.voc_json)
    else:
        raise ValueError('SSD eval only supports coco and voc datasets!')

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
    mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False)

    print("Start Eval!")
    ssd_eval(mindrecord_file, args_opt.checkpoint_path, json_path)
@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ssd export mindir.
"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.ssd import SSD300, ssd_mobilenet_v2
from src.config import config

def get_export_args():
    parser = argparse.ArgumentParser(description='SSD export')
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
                        help="run platform, support Ascend, GPU and CPU.")
    return parser.parse_args()

if __name__ == '__main__':
    args_opt = get_export_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform)
    net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    input_shp = [1, 3] + config.img_shape
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(net, input_array, file_name=config.export_file, file_format=config.export_format)
@@ -25,7 +25,7 @@ class GeneratDefaultBoxes():
    """
    Generate default boxes for SSD, following the order of (W, H, anchor_sizes).
    `self.default_boxes` has a shape of [anchor_sizes, H, W, 4], the last dimension is [y, x, h, w].
    `self.default_boxes_tlbr` has the same shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
    """
    def __init__(self):
        fk = config.img_shape[0] / np.array(config.steps)
@@ -54,17 +54,17 @@ class GeneratDefaultBoxes():
                    cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
                    self.default_boxes.append([cy, cx, h, w])

        def to_tlbr(cy, cx, h, w):
            return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2

        # For IoU calculation
        self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32')
        self.default_boxes = np.array(self.default_boxes, dtype='float32')

default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr
default_boxes = GeneratDefaultBoxes().default_boxes

y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.match_threshold
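As a standalone illustration (not part of the diff) of the conversion that `to_tlbr` performs, a default box in center form (cy, cx, h, w) = (0.5, 0.5, 0.2, 0.4) becomes the corner form (y1, x1, y2, x2) = (0.4, 0.3, 0.6, 0.7):

```
# Standalone worked example of the center-form -> corner-form conversion above.
def to_tlbr(cy, cx, h, w):
    return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2

print(to_tlbr(0.5, 0.5, 0.2, 0.4))  # (0.4, 0.3, 0.6, 0.7)
```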
@@ -115,7 +115,7 @@ def ssd_bboxes_encode(boxes):
    index = np.nonzero(t_label)

    # Transform to tlbr.
    bboxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32)
    bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
    bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
@@ -27,7 +27,6 @@ config = ed({
    "max_boxes": 100,

    # learning rate settings
    "lr_init": 0.001,
    "lr_end_rate": 0.001,
    "warmup_epochs": 2,
@@ -55,27 +54,29 @@ config = ed({
    "train_data_type": "train2017",
    "val_data_type": "val2017",
    "instances_set": "annotations/instances_{}.json",
    "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush'),
    "num_classes": 81,
    # The annotation.json position of voc validation dataset.
    "voc_json": "annotations/voc_instances_val.json",
    # voc original dataset.
    "voc_root": "/data/voc_dataset",
    # if coco or voc used, `image_dir` and `anno_path` are useless.
    "image_dir": "",
    "anno_path": "",
    "export_format": "MINDIR",
    "export_file": "ssd.mindir"
})
@@ -159,10 +159,10 @@ def preprocess_fn(img_id, image, box, is_training):
def create_voc_label(is_training):
    """Get image path and annotation from VOC."""
    voc_root = config.voc_root
    cls_map = {name: i for i, name in enumerate(config.classes)}
    sub_dir = 'train' if is_training else 'eval'
    voc_dir = os.path.join(voc_root, sub_dir)
    if not os.path.isdir(voc_dir):
        raise ValueError(f'Cannot find {sub_dir} dataset path.')
@@ -173,8 +173,7 @@ def create_voc_label(is_training):
    anno_dir = os.path.join(voc_dir, 'Annotations')

    if not is_training:
        json_file = os.path.join(config.voc_root, config.voc_json)
        file_dir = os.path.split(json_file)[0]
        if not os.path.isdir(file_dir):
            os.makedirs(file_dir)
@@ -203,7 +202,7 @@ def create_voc_label(is_training):
        for obj in root_node.iter('object'):
            cls_name = obj.find('name').text
            if cls_name not in cls_map:
                print(f'Label "{cls_name}" not in "{config.classes}"')
                continue
            bnd_box = obj.find('bndbox')
            x_min = int(bnd_box.find('xmin').text) - 1
@@ -258,7 +257,7 @@ def create_coco_label(is_training):
        data_type = config.train_data_type

    # Classes need to train or test.
    train_cls = config.classes
    train_cls_dict = {}
    for i, cls in enumerate(train_cls):
        train_cls_dict[cls] = i
@@ -390,7 +389,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
                       is_training=True, num_parallel_workers=4, use_multiprocessing=True):
    """Create SSD dataset with MindDataset."""
    ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
                        shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
@@ -409,10 +408,45 @@ def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num
        trans = [normalize_op, change_swap_op]
    ds = ds.map(operations=compose_map_func, input_columns=["img_id", "image", "annotation"],
                output_columns=output_columns, column_order=output_columns,
                python_multiprocessing=use_multiprocessing,
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=use_multiprocessing,
                num_parallel_workers=num_parallel_workers)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)
    return ds

def create_mindrecord(dataset="coco", prefix="ssd.mindrecord", is_training=True):
    print("Start create dataset!")
    # It will generate mindrecord files in config.mindrecord_dir,
    # and the file names are ssd.mindrecord0, 1, ... file_num.
    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if dataset == "coco":
            if os.path.isdir(config.coco_root):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("coco", is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("coco_root does not exist.")
        elif dataset == "voc":
            if os.path.isdir(config.voc_root):
                print("Create Mindrecord.")
                voc_data_to_mindrecord(mindrecord_dir, is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("voc_root does not exist.")
        else:
            if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("other", is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("image_dir or anno_path does not exist.")
    return mindrecord_file
@@ -14,7 +14,6 @@
# ============================================================================
"""Coco metrics utils"""
import json
import numpy as np
from .config import config
@@ -56,22 +55,17 @@ def apply_nms(all_boxes, all_scores, thres, max_boxes):
    return keep

def metrics(pred_data, anno_json):
    """Calculate mAP of predicted bboxes."""
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    num_classes = config.num_classes

    # Classes need to train or test.
    val_cls = config.classes
    val_cls_dict = {}
    for i, cls in enumerate(val_cls):
        val_cls_dict[i] = cls

    coco_gt = COCO(anno_json)
    classs_dict = {}
    cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
@@ -15,7 +15,6 @@
"""Train SSD and get checkpoint files."""

import argparse
import ast
import mindspore.nn as nn
@@ -28,14 +27,16 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed, dtype
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from src.config import config
from src.dataset import create_ssd_dataset, create_mindrecord
from src.lr_schedule import get_lr
from src.init_params import init_net_param, filter_checkpoint_parameter

set_seed(1)
def get_args():
    parser = argparse.ArgumentParser(description="SSD training")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
                        help="run platform, support Ascend, GPU and CPU.")
    parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False,
                        help="If set it true, only create Mindrecord, default is False.")
    parser.add_argument("--distribute", type=ast.literal_eval, default=False,
@@ -52,77 +53,39 @@ def main():
    parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 10.")
    parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
    parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                        help="Filter head weight parameters, default is False.")
    parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"],
                        help="freeze the weights of network, support freeze the backbone's weights, "
                             "default is not freezing.")
    args_opt = parser.parse_args()
    return args_opt

def main():
    args_opt = get_args()
    rank = 0
    device_num = 1
    if args_opt.run_platform == "CPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    else:
        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
        if args_opt.distribute:
            device_num = args_opt.device_num
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
            init()
            rank = get_rank()

    mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True)
    if not args_opt.only_create_dataset:
        loss_scale = float(args_opt.loss_scale)
        if args_opt.run_platform == "CPU":
            loss_scale = 1.0

        # When creating MindDataset, use the first mindrecord file, such as ssd.mindrecord0.
        use_multiprocessing = (args_opt.run_platform != "CPU")
        dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size,
                                     device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing)

        dataset_size = dataset.get_dataset_size()
        print("Create dataset done!")
@@ -140,27 +103,30 @@ def main():
        ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config)

        if args_opt.pre_trained:
            param_dict = load_checkpoint(args_opt.pre_trained)
            if args_opt.filter_weight:
                filter_checkpoint_parameter(param_dict)
            load_param_into_net(net, param_dict)

        if args_opt.freeze_layer == "backbone":
            for param in backbone.feature_1.trainable_params():
                param.requires_grad = False

        lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size,
                           lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
                           warmup_epochs=config.warmup_epochs,
                           total_epochs=args_opt.epoch_size,
                           steps_per_epoch=dataset_size))

        opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
                          config.momentum, config.weight_decay, loss_scale)
        net = TrainingWrapper(net, opt, loss_scale)

        callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
        model = Model(net)
        dataset_sink_mode = False
        if args_opt.mode == "sink" and args_opt.run_platform != "CPU":
            print("In sink mode, one epoch return a loss.")
            dataset_sink_mode = True
        print("Start train SSD, the first epoch will be slower because of the graph compilation.")