Merge pull request !7948 from zhaoting/ssdtags/v1.1.0
| @@ -12,6 +12,7 @@ | |||
| - [Training](#training) | |||
| - [Evaluation Process](#evaluation-process) | |||
| - [Evaluation](#evaluation) | |||
| - [Export MindIR](#export-mindir) | |||
| - [Model Description](#model-description) | |||
| - [Performance](#performance) | |||
| - [Evaluation Performance](#evaluation-performance) | |||
| @@ -49,21 +50,23 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>) | |||
| - Download the dataset COCO2017. | |||
| - We use COCO2017 as training dataset in this example by default, and you can also use your own datasets. | |||
| First, install Cython ,pycocotool and opencv to process data and to get evaluation result. | |||
| 1. If coco dataset is used. **Select dataset to coco when run script.** | |||
| Install Cython and pycocotool, and you can also install mmcv to process data. | |||
| ``` | |||
| pip install Cython | |||
| ``` | |||
| pip install Cython | |||
| pip install pycocotools | |||
| pip install pycocotools | |||
| pip install opencv-python | |||
| ``` | |||
| And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows: | |||
| ``` | |||
| 1. If coco dataset is used. **Select dataset to coco when run script.** | |||
| Change the `coco_root` and other settings you need in `src/config.py`. The directory structure is as follows: | |||
| ``` | |||
| . | |||
| └─cocodataset | |||
| └─coco_dataset | |||
| ├─annotations | |||
| ├─instance_train2017.json | |||
| └─instance_val2017.json | |||
| @@ -72,7 +75,27 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>) | |||
| ``` | |||
| 2. If your own dataset is used. **Select dataset to other when run script.** | |||
| 2. If VOC dataset is used. **Select dataset to voc when run script.** | |||
| Change `classes`, `num_classes`, `voc_json` and `voc_root` in `src/config.py`. `voc_json` is the path of json file with coco format for evalution, `voc_root` is the path of VOC dataset, the directory structure is as follows: | |||
| ``` | |||
| . | |||
| └─voc_dataset | |||
| └─train | |||
| ├─0001.jpg | |||
| └─0001.xml | |||
| ... | |||
| ├─xxxx.jpg | |||
| └─xxxx.xml | |||
| └─eval | |||
| ├─0001.jpg | |||
| └─0001.xml | |||
| ... | |||
| ├─xxxx.jpg | |||
| └─xxxx.xml | |||
| ``` | |||
| 3. If your own dataset is used. **Select dataset to other when run script.** | |||
| Organize the dataset infomation into a TXT file, each row in the file is as follows: | |||
| ``` | |||
| @@ -80,7 +103,7 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>) | |||
| ``` | |||
| Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`. | |||
| Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `image_dir`(dataset directory) and the relative path in `anno_path`(the TXT file path), `image_dir` and `anno_path` are setting in `src/config.py`. | |||
| # [Quick Start](#contents) | |||
| @@ -103,7 +126,19 @@ sh run_distribute_train_gpu.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] | |||
| sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID] | |||
| ``` | |||
| # [Script Description](#contents) | |||
| - runing on CPU(support Windows and Ubuntu) | |||
| **CPU is usually used for fine-tuning, which needs pre_trained checkpoint.** | |||
| ``` | |||
| # training on CPU | |||
| python train.py --run_platform=CPU --lr=[LR] --dataset=[DATASET] --epoch_size=[EPOCH_SIZE] --batch_size=[BATCH_SIZE] --pre_trained=[PRETRAINED_CKPT] --filter_weight=True --save_checkpoint_epochs=1 | |||
| # run eval on GPU | |||
| python eval.py --run_platform=CPU --dataset=[DATASET] --checkpoint_path=[PRETRAINED_CKPT] | |||
| ``` | |||
| # [Script Description](#contents) | |||
| ## [Script and Sample Code](#contents) | |||
| @@ -111,24 +146,25 @@ sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID] | |||
| . | |||
| └─ cv | |||
| └─ ssd | |||
| ├─ README.md ## descriptions about SSD | |||
| ├─ README.md # descriptions about SSD | |||
| ├─ scripts | |||
| ├─ run_distribute_train.sh ## shell script for distributed on ascend | |||
| ├─ run_distribute_train_gpu.sh ## shell script for distributed on gpu | |||
| ├─ run_eval.sh ## shell script for eval on ascend | |||
| └─ run_eval_gpu.sh ## shell script for eval on gpu | |||
| ├─ run_distribute_train.sh # shell script for distributed on ascend | |||
| ├─ run_distribute_train_gpu.sh # shell script for distributed on gpu | |||
| ├─ run_eval.sh # shell script for eval on ascend | |||
| └─ run_eval_gpu.sh # shell script for eval on gpu | |||
| ├─ src | |||
| ├─ __init__.py ## init file | |||
| ├─ box_util.py ## bbox utils | |||
| ├─ coco_eval.py ## coco metrics utils | |||
| ├─ config.py ## total config | |||
| ├─ dataset.py ## create dataset and process dataset | |||
| ├─ init_params.py ## parameters utils | |||
| ├─ lr_schedule.py ## learning ratio generator | |||
| └─ ssd.py ## ssd architecture | |||
| ├─ eval.py ## eval scripts | |||
| ├─ train.py ## train scripts | |||
| └─ mindspore_hub_conf.py ## mindspore hub interface | |||
| ├─ __init__.py # init file | |||
| ├─ box_utils.py # bbox utils | |||
| ├─ eval_utils.py # metrics utils | |||
| ├─ config.py # total config | |||
| ├─ dataset.py # create dataset and process dataset | |||
| ├─ init_params.py # parameters utils | |||
| ├─ lr_schedule.py # learning ratio generator | |||
| └─ ssd.py # ssd architecture | |||
| ├─ eval.py # eval scripts | |||
| ├─ train.py # train scripts | |||
| ├─ export.py # export mindir script | |||
| └─ mindspore_hub_conf.py # mindspore hub interface | |||
| ``` | |||
| ## [Script Parameters](#contents) | |||
| @@ -136,30 +172,33 @@ sh run_eval_gpu.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID] | |||
| ``` | |||
| Major parameters in train.py and config.py as follows: | |||
| "device_num": 1 # Use device nums | |||
| "lr": 0.05 # Learning rate init value | |||
| "dataset": coco # Dataset name | |||
| "epoch_size": 500 # Epoch size | |||
| "batch_size": 32 # Batch size of input tensor | |||
| "pre_trained": None # Pretrained checkpoint file path | |||
| "pre_trained_epoch_size": 0 # Pretrained epoch size | |||
| "save_checkpoint_epochs": 10 # The epoch interval between two checkpoints. By default, the checkpoint will be saved per 10 epochs | |||
| "loss_scale": 1024 # Loss scale | |||
| "class_num": 81 # Dataset class number | |||
| "image_shape": [300, 300] # Image height and width used as input to the model | |||
| "mindrecord_dir": "/data/MindRecord_COCO" # MindRecord path | |||
| "coco_root": "/data/coco2017" # COCO2017 dataset path | |||
| "voc_root": "" # VOC original dataset path | |||
| "image_dir": "" # Other dataset image path, if coco or voc used, it will be useless | |||
| "anno_path": "" # Other dataset annotation path, if coco or voc used, it will be useless | |||
| "device_num": 1 # Use device nums | |||
| "lr": 0.05 # Learning rate init value | |||
| "dataset": coco # Dataset name | |||
| "epoch_size": 500 # Epoch size | |||
| "batch_size": 32 # Batch size of input tensor | |||
| "pre_trained": None # Pretrained checkpoint file path | |||
| "pre_trained_epoch_size": 0 # Pretrained epoch size | |||
| "save_checkpoint_epochs": 10 # The epoch interval between two checkpoints. By default, the checkpoint will be saved per 10 epochs | |||
| "loss_scale": 1024 # Loss scale | |||
| "filter_weight": False # Load paramters in head layer or not. If the class numbers of train dataset is different from the class numbers in pre_trained checkpoint, please set True. | |||
| "freeze_layer": "none" # Freeze the backbone paramters or not, support none and backbone. | |||
| "class_num": 81 # Dataset class number | |||
| "image_shape": [300, 300] # Image height and width used as input to the model | |||
| "mindrecord_dir": "/data/MindRecord_COCO" # MindRecord path | |||
| "coco_root": "/data/coco2017" # COCO2017 dataset path | |||
| "voc_root": "/data/voc_dataset" # VOC original dataset path | |||
| "voc_json": "annotations/voc_instances_val.json" # is the path of json file with coco format for evalution | |||
| "image_dir": "" # Other dataset image path, if coco or voc used, it will be useless | |||
| "anno_path": "" # Other dataset annotation path, if coco or voc used, it will be useless | |||
| ``` | |||
| ## [Training Process](#contents) | |||
| To train the model, run `train.py`. If the `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files by `coco_root`(coco dataset) or `iamge_dir` and `anno_path`(own dataset). **Note if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.** | |||
| To train the model, run `train.py`. If the `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files by `coco_root`(coco dataset), `voc_root`(voc dataset) or `image_dir` and `anno_path`(own dataset). **Note if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.** | |||
| ### Training on Ascend | |||
| @@ -292,6 +331,14 @@ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.686 | |||
| mAP: 0.2244936111705981 | |||
| ``` | |||
| ## [Export MindIR](#contents) | |||
| Change the export mode and export file in `src/config.py`, and run `export.py`. | |||
| ``` | |||
| python export.py --run_platform [PLATFORM] --checkpoint_path [CKPT_PATH] | |||
| ``` | |||
| # [Model Description](#contents) | |||
| ## [Performance](#contents) | |||
| @@ -22,14 +22,15 @@ import numpy as np | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.ssd import SSD300, ssd_mobilenet_v2 | |||
| from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord | |||
| from src.dataset import create_ssd_dataset, create_mindrecord | |||
| from src.config import config | |||
| from src.coco_eval import metrics | |||
| from src.eval_utils import metrics | |||
| def ssd_eval(dataset_path, ckpt_path): | |||
| def ssd_eval(dataset_path, ckpt_path, anno_json): | |||
| """SSD evaluation.""" | |||
| batch_size = 1 | |||
| ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False) | |||
| ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, | |||
| is_training=False, use_multiprocessing=False) | |||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | |||
| print("Load Checkpoint!") | |||
| param_dict = load_checkpoint(ckpt_path) | |||
| @@ -61,51 +62,31 @@ def ssd_eval(dataset_path, ckpt_path): | |||
| i += batch_size | |||
| cost_time = int((time.time() - start) * 1000) | |||
| print(f' 100% [{total}/{total}] cost {cost_time} ms') | |||
| mAP = metrics(pred_data) | |||
| mAP = metrics(pred_data, anno_json) | |||
| print("\n========================================\n") | |||
| print(f"mAP: {mAP}") | |||
| if __name__ == '__main__': | |||
| def get_eval_args(): | |||
| parser = argparse.ArgumentParser(description='SSD evaluation') | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | |||
| parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") | |||
| parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") | |||
| parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"), | |||
| help="run platform, only support Ascend and GPU.") | |||
| args_opt = parser.parse_args() | |||
| parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), | |||
| help="run platform, support Ascend ,GPU and CPU.") | |||
| return parser.parse_args() | |||
| if __name__ == '__main__': | |||
| args_opt = get_eval_args() | |||
| if args_opt.dataset == "coco": | |||
| json_path = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type)) | |||
| elif args_opt.dataset == "voc": | |||
| json_path = os.path.join(config.voc_root, config.voc_json) | |||
| else: | |||
| raise ValueError('SSD eval only supprt dataset mode is coco and voc!') | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id) | |||
| prefix = "ssd_eval.mindrecord" | |||
| mindrecord_dir = config.mindrecord_dir | |||
| mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | |||
| if args_opt.dataset == "voc": | |||
| config.coco_root = config.voc_root | |||
| if not os.path.exists(mindrecord_file): | |||
| if not os.path.isdir(mindrecord_dir): | |||
| os.makedirs(mindrecord_dir) | |||
| if args_opt.dataset == "coco": | |||
| if os.path.isdir(config.coco_root): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("coco", False, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("coco_root not exits.") | |||
| elif args_opt.dataset == "voc": | |||
| if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root): | |||
| print("Create Mindrecord.") | |||
| voc_data_to_mindrecord(mindrecord_dir, False, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("voc_root or voc_dir not exits.") | |||
| else: | |||
| if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("other", False, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("IMAGE_DIR or ANNO_PATH not exits.") | |||
| mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False) | |||
| print("Start Eval!") | |||
| ssd_eval(mindrecord_file, args_opt.checkpoint_path) | |||
| ssd_eval(mindrecord_file, args_opt.checkpoint_path, json_path) | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| ssd export mindir. | |||
| """ | |||
| import argparse | |||
| import numpy as np | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||
| from src.ssd import SSD300, ssd_mobilenet_v2 | |||
| from src.config import config | |||
| def get_export_args(): | |||
| parser = argparse.ArgumentParser(description='SSD export') | |||
| parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.") | |||
| parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), | |||
| help="run platform, support Ascend, GPU and CPU.") | |||
| return parser.parse_args() | |||
| if __name__ == '__main__': | |||
| args_opt = get_export_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform) | |||
| net = SSD300(ssd_mobilenet_v2(), config, is_training=False) | |||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||
| load_param_into_net(net, param_dict) | |||
| input_shp = [1, 3] + config.img_shape | |||
| input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32)) | |||
| export(net, input_array, file_name=config.export_file, file_format=config.export_format) | |||
| @@ -25,7 +25,7 @@ class GeneratDefaultBoxes(): | |||
| """ | |||
| Generate Default boxes for SSD, follows the order of (W, H, archor_sizes). | |||
| `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w]. | |||
| `self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2]. | |||
| `self.default_boxes_tlbr` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2]. | |||
| """ | |||
| def __init__(self): | |||
| fk = config.img_shape[0] / np.array(config.steps) | |||
| @@ -54,17 +54,17 @@ class GeneratDefaultBoxes(): | |||
| cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] | |||
| self.default_boxes.append([cy, cx, h, w]) | |||
| def to_ltrb(cy, cx, h, w): | |||
| def to_tlbr(cy, cx, h, w): | |||
| return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2 | |||
| # For IoU calculation | |||
| self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32') | |||
| self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32') | |||
| self.default_boxes = np.array(self.default_boxes, dtype='float32') | |||
| default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb | |||
| default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr | |||
| default_boxes = GeneratDefaultBoxes().default_boxes | |||
| y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1) | |||
| y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1) | |||
| vol_anchors = (x2 - x1) * (y2 - y1) | |||
| matching_threshold = config.match_threshold | |||
| @@ -115,7 +115,7 @@ def ssd_bboxes_encode(boxes): | |||
| index = np.nonzero(t_label) | |||
| # Transform to ltrb. | |||
| # Transform to tlbr. | |||
| bboxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32) | |||
| bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2 | |||
| bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]] | |||
| @@ -27,7 +27,6 @@ config = ed({ | |||
| "max_boxes": 100, | |||
| # learing rate settings | |||
| "global_step": 0, | |||
| "lr_init": 0.001, | |||
| "lr_end_rate": 0.001, | |||
| "warmup_epochs": 2, | |||
| @@ -55,27 +54,29 @@ config = ed({ | |||
| "train_data_type": "train2017", | |||
| "val_data_type": "val2017", | |||
| "instances_set": "annotations/instances_{}.json", | |||
| "coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||
| 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||
| 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||
| 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||
| 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||
| 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||
| 'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||
| 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||
| 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
| 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||
| 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||
| 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||
| 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', | |||
| 'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||
| 'teddy bear', 'hair drier', 'toothbrush'), | |||
| "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||
| 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||
| 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||
| 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', | |||
| 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', | |||
| 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', | |||
| 'kite', 'baseball bat', 'baseball glove', 'skateboard', | |||
| 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', | |||
| 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
| 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', | |||
| 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', | |||
| 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', | |||
| 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', | |||
| 'refrigerator', 'book', 'clock', 'vase', 'scissors', | |||
| 'teddy bear', 'hair drier', 'toothbrush'), | |||
| "num_classes": 81, | |||
| # The annotation.json position of voc validation dataset. | |||
| "voc_root": "", | |||
| "voc_json": "annotations/voc_instances_val.json", | |||
| # voc original dataset. | |||
| "voc_dir": "", | |||
| "voc_root": "/data/voc_dataset", | |||
| # if coco or voc used, `image_dir` and `anno_path` are useless. | |||
| "image_dir": "", | |||
| "anno_path": "", | |||
| "export_format": "MINDIR", | |||
| "export_file": "ssd.mindir" | |||
| }) | |||
| @@ -159,10 +159,10 @@ def preprocess_fn(img_id, image, box, is_training): | |||
| def create_voc_label(is_training): | |||
| """Get image path and annotation from VOC.""" | |||
| voc_dir = config.voc_dir | |||
| cls_map = {name: i for i, name in enumerate(config.coco_classes)} | |||
| voc_root = config.voc_root | |||
| cls_map = {name: i for i, name in enumerate(config.classes)} | |||
| sub_dir = 'train' if is_training else 'eval' | |||
| voc_dir = os.path.join(voc_dir, sub_dir) | |||
| voc_dir = os.path.join(voc_root, sub_dir) | |||
| if not os.path.isdir(voc_dir): | |||
| raise ValueError(f'Cannot find {sub_dir} dataset path.') | |||
| @@ -173,8 +173,7 @@ def create_voc_label(is_training): | |||
| anno_dir = os.path.join(voc_dir, 'Annotations') | |||
| if not is_training: | |||
| data_dir = config.voc_root | |||
| json_file = os.path.join(data_dir, config.instances_set.format(sub_dir)) | |||
| json_file = os.path.join(config.voc_root, config.voc_json) | |||
| file_dir = os.path.split(json_file)[0] | |||
| if not os.path.isdir(file_dir): | |||
| os.makedirs(file_dir) | |||
| @@ -203,7 +202,7 @@ def create_voc_label(is_training): | |||
| for obj in root_node.iter('object'): | |||
| cls_name = obj.find('name').text | |||
| if cls_name not in cls_map: | |||
| print(f'Label "{cls_name}" not in "{config.coco_classes}"') | |||
| print(f'Label "{cls_name}" not in "{config.classes}"') | |||
| continue | |||
| bnd_box = obj.find('bndbox') | |||
| x_min = int(bnd_box.find('xmin').text) - 1 | |||
| @@ -258,7 +257,7 @@ def create_coco_label(is_training): | |||
| data_type = config.train_data_type | |||
| # Classes need to train or test. | |||
| train_cls = config.coco_classes | |||
| train_cls = config.classes | |||
| train_cls_dict = {} | |||
| for i, cls in enumerate(train_cls): | |||
| train_cls_dict[cls] = i | |||
| @@ -390,7 +389,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd. | |||
| def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0, | |||
| is_training=True, num_parallel_workers=4): | |||
| is_training=True, num_parallel_workers=4, use_multiprocessing=True): | |||
| """Creatr SSD dataset with MindDataset.""" | |||
| ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, | |||
| shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) | |||
| @@ -409,10 +408,45 @@ def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num | |||
| trans = [normalize_op, change_swap_op] | |||
| ds = ds.map(operations=compose_map_func, input_columns=["img_id", "image", "annotation"], | |||
| output_columns=output_columns, column_order=output_columns, | |||
| python_multiprocessing=is_training, | |||
| python_multiprocessing=use_multiprocessing, | |||
| num_parallel_workers=num_parallel_workers) | |||
| ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=is_training, | |||
| ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=use_multiprocessing, | |||
| num_parallel_workers=num_parallel_workers) | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_num) | |||
| return ds | |||
| def create_mindrecord(dataset="coco", prefix="ssd.mindrecord", is_training=True): | |||
| print("Start create dataset!") | |||
| # It will generate mindrecord file in config.mindrecord_dir, | |||
| # and the file name is ssd.mindrecord0, 1, ... file_num. | |||
| mindrecord_dir = config.mindrecord_dir | |||
| mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | |||
| if not os.path.exists(mindrecord_file): | |||
| if not os.path.isdir(mindrecord_dir): | |||
| os.makedirs(mindrecord_dir) | |||
| if dataset == "coco": | |||
| if os.path.isdir(config.coco_root): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("coco", is_training, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("coco_root not exits.") | |||
| elif dataset == "voc": | |||
| if os.path.isdir(config.voc_root): | |||
| print("Create Mindrecord.") | |||
| voc_data_to_mindrecord(mindrecord_dir, is_training, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("voc_root not exits.") | |||
| else: | |||
| if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("other", is_training, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("image_dir or anno_path not exits.") | |||
| return mindrecord_file | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| """Coco metrics utils""" | |||
| import os | |||
| import json | |||
| import numpy as np | |||
| from .config import config | |||
| @@ -56,22 +55,17 @@ def apply_nms(all_boxes, all_scores, thres, max_boxes): | |||
| return keep | |||
| def metrics(pred_data): | |||
| def metrics(pred_data, anno_json): | |||
| """Calculate mAP of predicted bboxes.""" | |||
| from pycocotools.coco import COCO | |||
| from pycocotools.cocoeval import COCOeval | |||
| num_classes = config.num_classes | |||
| coco_root = config.coco_root | |||
| data_type = config.val_data_type | |||
| #Classes need to train or test. | |||
| val_cls = config.coco_classes | |||
| val_cls = config.classes | |||
| val_cls_dict = {} | |||
| for i, cls in enumerate(val_cls): | |||
| val_cls_dict[i] = cls | |||
| anno_json = os.path.join(coco_root, config.instances_set.format(data_type)) | |||
| coco_gt = COCO(anno_json) | |||
| classs_dict = {} | |||
| cat_ids = coco_gt.loadCats(coco_gt.getCatIds()) | |||
| @@ -15,7 +15,6 @@ | |||
| """Train SSD and get checkpoint files.""" | |||
| import os | |||
| import argparse | |||
| import ast | |||
| import mindspore.nn as nn | |||
| @@ -28,14 +27,16 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.common import set_seed, dtype | |||
| from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 | |||
| from src.config import config | |||
| from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord | |||
| from src.dataset import create_ssd_dataset, create_mindrecord | |||
| from src.lr_schedule import get_lr | |||
| from src.init_params import init_net_param, filter_checkpoint_parameter | |||
| set_seed(1) | |||
| def main(): | |||
| def get_args(): | |||
| parser = argparse.ArgumentParser(description="SSD training") | |||
| parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), | |||
| help="run platform, support Ascend, GPU and CPU.") | |||
| parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, | |||
| help="If set it true, only create Mindrecord, default is False.") | |||
| parser.add_argument("--distribute", type=ast.literal_eval, default=False, | |||
| @@ -52,77 +53,39 @@ def main(): | |||
| parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 10.") | |||
| parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") | |||
| parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, | |||
| help="Filter weight parameters, default is False.") | |||
| parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "GPU"), | |||
| help="run platform, only support Ascend and GPU.") | |||
| help="Filter head weight parameters, default is False.") | |||
| parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"], | |||
| help="freeze the weights of network, support freeze the backbone's weights, " | |||
| "default is not freezing.") | |||
| args_opt = parser.parse_args() | |||
| return args_opt | |||
| if args_opt.run_platform == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||
| def main(): | |||
| args_opt = get_args() | |||
| rank = 0 | |||
| device_num = 1 | |||
| if args_opt.run_platform == "CPU": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| else: | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id) | |||
| if args_opt.distribute: | |||
| device_num = args_opt.device_num | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||
| device_num=device_num) | |||
| init() | |||
| rank = args_opt.device_id % device_num | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| elif args_opt.run_platform == "GPU": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id) | |||
| init() | |||
| if args_opt.distribute: | |||
| device_num = args_opt.device_num | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||
| device_num=device_num) | |||
| rank = get_rank() | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| else: | |||
| raise ValueError("Unsupported platform.") | |||
| print("Start create dataset!") | |||
| # It will generate mindrecord file in args_opt.mindrecord_dir, | |||
| # and the file name is ssd.mindrecord0, 1, ... file_num. | |||
| prefix = "ssd.mindrecord" | |||
| mindrecord_dir = config.mindrecord_dir | |||
| mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | |||
| if not os.path.exists(mindrecord_file): | |||
| if not os.path.isdir(mindrecord_dir): | |||
| os.makedirs(mindrecord_dir) | |||
| if args_opt.dataset == "coco": | |||
| if os.path.isdir(config.coco_root): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("coco", True, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("coco_root not exits.") | |||
| elif args_opt.dataset == "voc": | |||
| if os.path.isdir(config.voc_dir): | |||
| print("Create Mindrecord.") | |||
| voc_data_to_mindrecord(mindrecord_dir, True, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("voc_dir not exits.") | |||
| else: | |||
| if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("other", True, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("image_dir or anno_path not exits.") | |||
| mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True) | |||
| if not args_opt.only_create_dataset: | |||
| loss_scale = float(args_opt.loss_scale) | |||
| if args_opt.run_platform == "CPU": | |||
| loss_scale = 1.0 | |||
| # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0. | |||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, | |||
| batch_size=args_opt.batch_size, device_num=device_num, rank=rank) | |||
| use_multiprocessing = (args_opt.run_platform != "CPU") | |||
| dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size, | |||
| device_num=device_num, rank=rank, use_multiprocessing=use_multiprocessing) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print("Create dataset done!") | |||
| @@ -140,27 +103,30 @@ def main(): | |||
| ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config) | |||
| if args_opt.pre_trained: | |||
| if args_opt.pre_trained_epoch_size <= 0: | |||
| raise KeyError("pre_trained_epoch_size must be greater than 0.") | |||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||
| if args_opt.filter_weight: | |||
| filter_checkpoint_parameter(param_dict) | |||
| load_param_into_net(net, param_dict) | |||
| lr = Tensor(get_lr(global_step=config.global_step, | |||
| if args_opt.freeze_layer == "backbone": | |||
| for param in backbone.feature_1.trainable_params(): | |||
| param.requires_grad = False | |||
| lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, | |||
| lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, | |||
| warmup_epochs=config.warmup_epochs, | |||
| total_epochs=args_opt.epoch_size, | |||
| steps_per_epoch=dataset_size)) | |||
| opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, | |||
| config.momentum, config.weight_decay, loss_scale) | |||
| net = TrainingWrapper(net, opt, loss_scale) | |||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||
| model = Model(net) | |||
| dataset_sink_mode = False | |||
| if args_opt.mode == "sink": | |||
| if args_opt.mode == "sink" and args_opt.run_platform != "CPU": | |||
| print("In sink mode, one epoch return a loss.") | |||
| dataset_sink_mode = True | |||
| print("Start train SSD, the first epoch will be slower because of the graph compilation.") | |||