
Modify ssd_ghostnet network.

tags/v1.1.0
zhanghuiyao, 5 years ago
parent commit: 3891f802b6
7 changed files with 56 additions and 50 deletions
1. model_zoo/official/cv/openpose/README.md (+1, -1)
2. model_zoo/research/cv/ssd_ghostnet/README.md (+35, -27)
3. model_zoo/research/cv/ssd_ghostnet/eval.py (+4, -4)
4. model_zoo/research/cv/ssd_ghostnet/src/dataset.py (+6, -6)
5. model_zoo/research/cv/ssd_ghostnet/src/init_params.py (+4, -4)
6. model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py (+5, -7)
7. model_zoo/research/cv/ssd_ghostnet/train.py (+1, -1)

model_zoo/official/cv/openpose/README.md (+1, -1)

@@ -195,7 +195,7 @@ For more configuration details, please refer the script `config.py`.
```python
# grep "AP" eval.log

- {'AP': 0.39830956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242}
+ {'AP': 0.40250956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242}
```




model_zoo/research/cv/ssd_ghostnet/README.md (+35, -27)

@@ -1,5 +1,5 @@
# [SSD Description](#contents)

SSD discretizes the output space of bounding boxes into a set of default boxes over different aspect ratios and scales per feature map location. At prediction time, the network generates scores for the presence of each object category in each default box and produces adjustments to the box to better match the object shape. Additionally, the network combines predictions from multiple feature maps with different resolutions to naturally handle objects of various sizes.

[Paper](https://arxiv.org/abs/1512.02325): Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, Alexander C. Berg. European Conference on Computer Vision (ECCV), 2016 (In press).
@@ -9,24 +9,25 @@ SSD discretizes the output space of bounding boxes into a set of default boxes o
The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step to produce the final detections. The early network layers are based on a standard architecture used for high-quality image classification, called the base network; auxiliary structure is then added to the network to produce detections.

# [Dataset](#contents)

Dataset used: [COCO2017](<http://images.cocodataset.org/>)

- Dataset size: 19G
  - Train: 18G, 118000 images
  - Val: 1G, 5000 images
  - Annotations: 241M, instances, captions, person_keypoints, etc.
- Data format: image and json files
  - Note: Data will be processed in dataset.py

# [Environment Requirements](#contents)

- Hardware (Ascend/GPU)
  - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
  - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
  - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
  - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

- Install [MindSpore](https://www.mindspore.cn/install/en).


@@ -37,15 +38,16 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
1. If the COCO dataset is used, **select dataset to coco when running the script.**
   Install Cython and pycocotools; you can also install mmcv to process data.

- ```
+ ```bash
  pip install Cython
  pip install pycocotools
  ```

   And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:

- ```
+ ```python
  .
  └─cocodataset
    ├─annotations
@@ -59,18 +61,18 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
2. If your own dataset is used, **select dataset to other when running the script.**
   Organize the dataset information into a TXT file; each row in the file is as follows:

- ```
+ ```python
  train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
  ```

   Each row is an image annotation split by spaces: the first column is a relative path to an image, and the remaining columns are box and class information in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path obtained by joining `IMAGE_DIR` (the dataset directory) and the relative path in `ANNO_PATH` (the TXT file path); `IMAGE_DIR` and `ANNO_PATH` are set in `config.py`.
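   For illustration only, a minimal sketch of how such a row could be parsed; the function name and the `/data/coco` directory are hypothetical, not taken from the scripts:

   ```python
   import os
   import numpy as np

   def parse_annotation_row(row, image_dir):
       """Split one annotation row into an image path and an Nx5 box array."""
       fields = row.strip().split(' ')
       image_path = os.path.join(image_dir, fields[0])
       # every remaining field is "xmin,ymin,xmax,ymax,class"
       boxes = np.array([list(map(int, f.split(','))) for f in fields[1:]])
       return image_path, boxes

   path, boxes = parse_annotation_row(
       "train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2", "/data/coco")
   # path  -> "/data/coco/train2017/0000001.jpg"
   # boxes -> shape (2, 5), rows of [xmin, ymin, xmax, ymax, class]
   ```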


# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation on Ascend as follows:

- ```
+ ```bash
  # single npu training on Ascend
  python train.py


@@ -85,9 +87,9 @@ python eval.py --device_id 0 --dataset coco --checkpoint_path LOG4/ssd-500_458.c


## [Script and Sample Code](#contents)

- ```shell
+ ```python

  ├── ssd_ghostnet
    ├── README.md                            ## readme file of ssd_ghostnet
    ├── scripts
      └─ run_distribute_train_ghostnet.sh    ## shell script for distributed on ascend
@@ -106,7 +108,7 @@ python eval.py --device_id 0 --dataset coco --checkpoint_path LOG4/ssd-500_458.c


## [Script Parameters](#contents)

- ```
+ ```python
  Major parameters in train.py and config_ghostnet_13x.py are as follows:

  "device_num": 1                            # Use device nums
@@ -129,39 +131,46 @@ python eval.py --device_id 0 --dataset coco --checkpoint_path LOG4/ssd-500_458.c


  ```



## [Training Process](#contents)

### Training on Ascend

To train the model, run `train.py`. If `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files from `coco_root` (COCO dataset) or `image_dir` and `anno_path` (own dataset). **Note that if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.**

- Distribute mode

- ```
+ ```bash
  sh run_distribute_train_ghostnet.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] [RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)
  ```

This script needs five or seven parameters:

- `DEVICE_NUM`: the device number for distributed training.
- `EPOCH_SIZE`: the epoch count for distributed training.
- `LR`: the initial learning rate for distributed training.
- `DATASET`: the dataset mode for distributed training.
- `RANK_TABLE_FILE`: the path of [rank_table.json](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools); it is better to use an absolute path.
- `PRE_TRAINED`: the path of the pretrained checkpoint file; it is better to use an absolute path.
- `PRE_TRAINED_EPOCH_SIZE`: the epoch count of the pretrained checkpoint.
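For example, a possible 8-device invocation might look as follows; the epoch count, learning rate, and rank-table path are illustrative values only:

```bash
sh run_distribute_train_ghostnet.sh 8 500 0.05 coco /path/to/rank_table_8pcs.json
```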


Training results are stored in the current path, in a folder whose name begins with "LOG". You can find the checkpoint files together with results like the following in LOG4/log.txt.


## [Evaluation Process](#contents)

### Evaluation on Ascend

- ```
+ ```bash
  python eval.py --device_id 0 --dataset coco --checkpoint_path LOG4/ssd-500_458.ckpt
  ```


# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance
@@ -177,7 +186,6 @@ python eval.py --device_id 0 --dataset coco --checkpoint_path LOG4/ssd-500_458.c
| Loss Function | Sigmoid Cross Entropy, SmoothL1Loss |
| Total time | 8pcs: 12 hours |



### Inference Performance

| Parameters | Ascend |


model_zoo/research/cv/ssd_ghostnet/eval.py (+4, -4)

@@ -19,7 +19,7 @@ import os
import argparse
import time
import numpy as np
- from mindspore import context, Tensor
+ from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd_ghostnet import SSD300, ssd_ghostnet
from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
@@ -47,11 +47,11 @@ def ssd_eval(dataset_path, ckpt_path):
print("total images num: ", total) print("total images num: ", total)
print("Processing, please wait a moment.") print("Processing, please wait a moment.")
for data in ds.create_dict_iterator(): for data in ds.create_dict_iterator():
img_id = data['img_id']
img_id = data['img_id'].asnumpy()
img_np = data['image'] img_np = data['image']
image_shape = data['image_shape']
image_shape = data['image_shape'].asnumpy()


output = net(Tensor(img_np))
output = net(img_np)
for batch_idx in range(img_np.shape[0]): for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx], pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx], "box_scores": output[1].asnumpy()[batch_idx],


model_zoo/research/cv/ssd_ghostnet/src/dataset.py (+6, -6)

@@ -24,7 +24,7 @@ import numpy as np
import cv2

import mindspore.dataset as de
- import mindspore.dataset.transforms.vision.c_transforms as C
+ import mindspore.dataset.vision.c_transforms as C2
from mindspore.mindrecord import FileWriter
from .config_ghostnet_13x import config
from .box_utils import jaccard_numpy, ssd_bboxes_encode
@@ -397,12 +397,12 @@ def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num
"""Create SSD dataset with MindDataset.""" """Create SSD dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
decode = C.Decode()
decode = C2.Decode()
ds = ds.map(input_columns=["image"], operations=decode) ds = ds.map(input_columns=["image"], operations=decode)
change_swap_op = C.HWC2CHW()
normalize_op = C.Normalize(
change_swap_op = C2.HWC2CHW()
normalize_op = C2.Normalize(
mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
color_adjust_op = C.RandomColorAdjust(
color_adjust_op = C2.RandomColorAdjust(
brightness=0.4, contrast=0.4, saturation=0.4) brightness=0.4, contrast=0.4, saturation=0.4)
compose_map_func = (lambda img_id, image, annotation: preprocess_fn( compose_map_func = (lambda img_id, image, annotation: preprocess_fn(
img_id, image, annotation, is_training)) img_id, image, annotation, is_training))
@@ -413,7 +413,7 @@ def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num
        output_columns = ["img_id", "image", "image_shape"]
        trans = [normalize_op, change_swap_op]
    ds = ds.map(input_columns=["img_id", "image", "annotation"],
-                output_columns=output_columns, columns_order=output_columns,
+                output_columns=output_columns, column_order=output_columns,
                operations=compose_map_func, python_multiprocessing=is_training,
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(input_columns=["image"], operations=trans, python_multiprocessing=is_training,


model_zoo/research/cv/ssd_ghostnet/src/init_params.py (+4, -4)

@@ -25,16 +25,16 @@ def init_net_param(network, initialize_mode='TruncatedNormal'):
        if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
            np.random.seed(seed=1)
            if initialize_mode == 'TruncatedNormal':
-                p.set_parameter_data(initializer(
+                p.set_data(initializer(
                    TruncatedNormal(), p.data.shape, p.data.dtype))
            else:
-                p.set_parameter_data(
-                    initialize_mode, p.data.shape, p.data.dtype)
+                p.set_data(
+                    initialize_mode, p.data.shape)


def load_backbone_params(network, param_dict):
    """Init the parameters from a pre-trained model; the default is mobilenetv2."""
-    for _, param in net.parameters_and_names():
+    for _, param in network.parameters_and_names():
        param_name = param.name.replace('network.backbone.', '')
        name_split = param_name.split('.')
        if 'features_1' in param_name:
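`set_parameter_data` became `Parameter.set_data` in MindSpore 1.0. A minimal sketch of reinitializing parameters in place under the new name; the `nn.Dense` layer is a stand-in for the SSD network:

```python
import mindspore.nn as nn
from mindspore.common.initializer import initializer, TruncatedNormal

net = nn.Dense(4, 3)  # stand-in for the SSD network
for _, p in net.parameters_and_names():
    if 'bias' not in p.name:
        # rebuild the parameter data with the same shape/dtype, new init
        p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
```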


model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py (+5, -7)

@@ -118,13 +118,11 @@ class DepthwiseConv(nn.Cell):
                              stride=stride, pad_mode=pad_mode, pad=pad)
        self.bias_add = P.BiasAdd()
        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
-        self.weight = Parameter(initializer(
-            'ones', weight_shape))
+        self.weight = Parameter(initializer('ones', weight_shape), name="weight")

        if has_bias:
            bias_shape = [channel_multiplier * in_planes]
-            self.bias = Parameter(initializer(
-                'zeros', bias_shape))
+            self.bias = Parameter(initializer('zeros', bias_shape), name="bias")
        else:
            self.bias = None
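The one-line `Parameter` constructions also gain an explicit `name`, which keeps parameter and checkpoint keys stable. A minimal sketch under the same API; the shapes are illustrative (channel_multiplier=1, in_planes=8, 3x3 kernel):

```python
from mindspore import Parameter
from mindspore.common.initializer import initializer

weight = Parameter(initializer('ones', [1, 8, 3, 3]), name="weight")
bias = Parameter(initializer('zeros', [8]), name="bias")
print(weight.name, bias.name)  # -> weight bias
```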


@@ -624,15 +622,15 @@ class TrainingWrapper(nn.Cell):
        self.network = network
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
-        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
+        if self.parallel_mode in [context.ParallelMode.DATA_PARALLEL, context.ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("mirror_mean")
+            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
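In MindSpore 1.0, `GradOperation` dropped its leading name argument and is configured by keywords only. A minimal sketch of the `get_by_list`/`sens_param` form the wrapper uses, on a toy `nn.Dense` network standing in for SSD:

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C

class GradByList(nn.Cell):
    """Returns gradients w.r.t. the weight list, scaled by a sensitivity input."""
    def __init__(self, network):
        super(GradByList, self).__init__()
        self.network = network
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)

    def construct(self, x, sens):
        return self.grad(self.network, self.weights)(x, sens)

net = nn.Dense(3, 1)
grads = GradByList(net)(ms.Tensor(np.ones((2, 3), np.float32)),
                        ms.Tensor(np.ones((2, 1), np.float32)))
print([g.shape for g in grads])
```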


model_zoo/research/cv/ssd_ghostnet/train.py (+1, -1)

@@ -70,7 +70,7 @@ def main():
    if args_opt.distribute:
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
+        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        init()
        rank = args_opt.device_id % device_num
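`mirror_mean` was renamed to `gradients_mean` in `set_auto_parallel_context`; the flag still controls whether all-reduced gradients are averaged across devices. A minimal sketch of the distributed setup, assuming an Ascend environment with a valid rank table (the `init()` call fails outside a real multi-device job):

```python
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,  # formerly mirror_mean
                                  device_num=8)
init()  # requires an HCCL rank table / multi-device environment
```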

