
!12018 MindSpore community network model solicitation: DenseNet-121

From: @fireinthehole1024
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago, commit a36485fdb4
6 changed files with 402 additions and 155 deletions

1. model_zoo/official/cv/densenet121/README.md (+125, -88)
2. model_zoo/official/cv/densenet121/README_CN.md (+77, -22)
3. model_zoo/official/cv/densenet121/eval.py (+49, -39)
4. model_zoo/official/cv/densenet121/scripts/run_distribute_eval_gpu.sh (+63, -0)
5. model_zoo/official/cv/densenet121/scripts/run_distribute_train_gpu.sh (+70, -0)
6. model_zoo/official/cv/densenet121/train.py (+18, -6)

model_zoo/official/cv/densenet121/README.md (+125, -88)

@@ -6,7 +6,7 @@
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
@@ -22,7 +22,6 @@
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)



# [DenseNet121 Description](#contents)


DenseNet121 is a convolution-based neural network for the task of image classification. The paper describing the model can be found [here](https://arxiv.org/abs/1608.06993). Huawei's DenseNet121 is an implementation on [MindSpore](https://www.mindspore.cn/).
@@ -33,60 +32,55 @@ The repository also contains scripts to launch training and inference routines.


DenseNet121 is built from 4 densely connected blocks. In every dense block, each layer obtains additional inputs from all preceding layers and passes its own feature maps on to all subsequent layers via concatenation. Each layer thus receives the "collective knowledge" of all preceding layers.
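To make the connectivity pattern concrete, here is a minimal MindSpore sketch of a single dense layer. It is an illustration only, not the repository's `src/network/densenet.py`, and the class name is hypothetical.

```python
# Minimal sketch of dense connectivity (an illustration, not the repo's densenet.py).
import mindspore.nn as nn
import mindspore.ops as ops

class DenseLayer(nn.Cell):
    """One layer of a dense block: BN -> ReLU -> Conv, then concatenate."""
    def __init__(self, in_channels, growth_rate):
        super(DenseLayer, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3,
                              pad_mode='pad', padding=1)
        self.concat = ops.Concat(axis=1)

    def construct(self, x):
        new_features = self.conv(self.relu(self.bn(x)))
        # each layer passes its own feature maps, plus all it received, onward
        return self.concat((x, new_features))
```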




# [Dataset](#contents)


Dataset used: ImageNet

The default configuration of the dataset is as follows (a code sketch of these steps follows the list):
- Training Dataset preprocess:
    - Input size of images is 224\*224
    - Range (min, max) of the cropped area, as a fraction of the original image, is (0.08, 1.0)
    - Range (min, max) of the aspect ratio of the crop is (0.75, 1.333)
    - Probability of the image being flipped is set to 0.5
    - Randomly adjust the brightness, contrast, and saturation (0.4, 0.4, 0.4)
    - Normalize the input image with respect to mean and standard deviation

- Test Dataset preprocess:
    - Input size of images is 224\*224 (resize to 256\*256, then crop images at the center)
    - Normalize the input image with respect to mean and standard deviation
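A rough sketch of how this preprocess maps onto `mindspore.dataset` ops in MindSpore 1.x. The authoritative pipeline lives in `src/datasets`; the mean/std values below are the usual ImageNet ones and are an assumption.

```python
# Hedged sketch of the preprocess described above; see src/datasets for the
# actual pipeline. Mean/std are assumed ImageNet values (0-255 scale).
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C

mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

train_ops = [
    C.RandomCropDecodeResize(224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
    C.RandomHorizontalFlip(prob=0.5),
    C.RandomColorAdjust(0.4, 0.4, 0.4),
    C.Normalize(mean=mean, std=std),
    C.HWC2CHW(),
]
test_ops = [
    C.Decode(),
    C.Resize(256),
    C.CenterCrop(224),
    C.Normalize(mean=mean, std=std),
    C.HWC2CHW(),
]

train_data = ds.ImageFolderDataset("/PATH/TO/DATASET/train", shuffle=True)
train_data = train_data.map(operations=train_ops, input_columns="image")
```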


# [Features](#contents)


## Mixed Precision


The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, while maintaining the network accuracy achieved by single-precision training. Mixed precision training accelerates the computation process, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and searching for 'reduce precision'.
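For reference, a minimal sketch of how mixed precision is typically enabled in MindSpore through the `Model` wrapper's `amp_level` argument, which is also how this model's `train.py` selects `O3` on Ascend and `O0` on GPU. The tiny network and hyperparameters here are placeholders, not DenseNet.

```python
# Minimal mixed-precision sketch via amp_level (placeholder network, assumed setup).
import mindspore.nn as nn
from mindspore import Model

net = nn.Dense(16, 10)                     # stand-in for the real network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# amp_level="O3" casts the network to FP16; operators that cannot run in FP16
# fall back to reduced precision in the backend (see the INFO log note above).
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O3")
```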




# [Environment Requirements](#contents)


- Hardware (Ascend/GPU)
    - Prepare hardware environment with an Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)


# [Quick Start](#contents)


After installing MindSpore via the official website, you can start training and evaluation as follows:

- running on Ascend


```python
# run training example
python train.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/PRETRAINED_CKPT --is_distributed 0 > train.log 2>&1 &

# run distributed training example
sh scripts/run_distribute_train.sh 8 rank_table.json /PATH/TO/DATASET /PATH/TO/PRETRAINED_CKPT

# run evaluation example
python eval.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/CHECKPOINT > eval.log 2>&1 &
OR
sh scripts/run_distribute_eval.sh 8 rank_table.json /PATH/TO/DATASET /PATH/TO/CHECKPOINT
```
@@ -95,44 +89,60 @@ After installing MindSpore via the official website, you can start training and


Please follow the instructions in the link below:


<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.

- running on GPU

For running on GPU, please change `platform` from `Ascend` to `GPU`.


```python
# run training example
export CUDA_VISIBLE_DEVICES=0
python train.py --data_dir=[DATASET_PATH] --is_distributed=0 --device_target='GPU' > train.log 2>&1 &

# run distributed training example
sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 [DATASET_PATH]

# run evaluation example
python eval.py --data_dir=[DATASET_PATH] --device_target='GPU' --pretrained=[CHECKPOINT_PATH] > eval.log 2>&1 &
OR
sh run_distribute_eval_gpu.sh 1 0 [DATASET_PATH] [CHECKPOINT_PATH]
```


# [Script Description](#contents)


## [Script and Sample Code](#contents)


```text
├── model_zoo
    ├── README.md                          // descriptions about all the models
    ├── densenet121
        ├── README.md                      // descriptions about densenet121
        ├── scripts
        │   ├── run_distribute_train.sh        // shell script for distributed on Ascend
        │   ├── run_distribute_train_gpu.sh    // shell script for distributed on GPU
        │   ├── run_distribute_eval.sh         // shell script for evaluation on Ascend
        │   ├── run_distribute_eval_gpu.sh     // shell script for evaluation on GPU
        ├── src
        │   ├── datasets                   // dataset processing function
        │   ├── losses
        │   │   ├── crossentropy.py        // densenet loss function
        │   ├── lr_scheduler
        │   │   ├── lr_scheduler.py        // densenet learning rate schedule function
        │   ├── network
        │   │   ├── densenet.py            // densenet architecture
        │   ├── optimizers                 // densenet optimize function
        │   ├── utils
        │   │   ├── logging.py             // logging function
        │   │   ├── var_init.py            // densenet variable init function
        │   ├── config.py                  // network config
        ├── train.py                       // training script
        ├── eval.py                        // evaluation script
```


## [Script Parameters](#contents)


You can modify the training behaviour through the various flags in the `train.py` script. Flags in the `train.py` script are as follows:


```python
--data_dir              train data dir
--num_classes           num of classes in dataset(default:1000)
--image_size            image size of the dataset
@@ -150,7 +160,7 @@ You can modify the training behaviour through the various flags in the `train.py
--momentum              momentum(default: 0.9)
--label_smooth          whether to use label smooth in CE (see the sketch after this block)
--label_smooth_factor   smooth strength of original one-hot
--log_interval          logging interval(default:100)
--ckpt_path             path to save checkpoint
--ckpt_interval         the interval to save checkpoint
--is_save_on_master     save checkpoint on master or all rank
@@ -159,21 +169,19 @@ You can modify the training behaviour through the various flags in the `train.py
--group_size            world size of distributed(default: 1)
```
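The `--label_smooth` flag above enables label-smoothed cross entropy. A hedged sketch of that loss is shown below; the authoritative version is `src/losses/crossentropy.py`, and the class and parameter names here are illustrative.

```python
# Hedged sketch of label-smoothed cross entropy; see src/losses/crossentropy.py
# for the repository's implementation.
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import dtype as mstype

class CrossEntropy(nn.Cell):
    """Cross entropy over smoothed one-hot labels."""
    def __init__(self, smooth_factor=0.0, num_classes=1000):
        super(CrossEntropy, self).__init__()
        self.onehot = ops.OneHot()
        # the true class gets 1 - factor; the rest share factor / (num_classes - 1)
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = ops.ReduceMean(False)
        self.num_classes = num_classes

    def construct(self, logit, label):
        one_hot_label = self.onehot(label, self.num_classes, self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        return self.mean(loss, 0)
```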




## [Training Process](#contents)


### Training


- running on Ascend


```python
python train.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/PRETRAINED_CKPT --is_distributed 0 > train.log 2>&1 &
```

The python command above will run in the background. The log and model checkpoint will be generated in `output/202x-xx-xx_time_xx_xx_xx/`. The loss values will look as follows:

```shell
2020-08-22 16:58:56,617:INFO:epoch[0], iter[5003], loss:4.367, mean_fps:0.00 imgs/sec
2020-08-22 16:58:56,619:INFO:local passed
2020-08-22 17:02:19,920:INFO:epoch[1], iter[10007], loss:3.193, mean_fps:6301.11 imgs/sec
@@ -183,19 +191,28 @@ You can modify the training behaviour through the various flags in the `train.py
...
```


- running on GPU


```python
export CUDA_VISIBLE_DEVICES=0
python train.py --data_dir=[DATASET_PATH] --is_distributed=0 --device_target='GPU' > train.log 2>&1 &
```

The python command above will run in the background; you can view the results through the file `train.log`.

After training, you'll get some checkpoint files under the folder `./ckpt_0/` by default.


### Distributed Training


- running on Ascend


```bash
sh scripts/run_distribute_train.sh 8 rank_table.json /PATH/TO/DATASET /PATH/TO/PRETRAINED_CKPT
```

The above shell script will run distributed training in the background. You can view the result logs and model checkpoints under `train[X]/output/202x-xx-xx_time_xx_xx_xx/`. The loss values will look as follows:

```log
2020-08-22 16:58:54,556:INFO:epoch[0], iter[5003], loss:3.857, mean_fps:0.00 imgs/sec
2020-08-22 17:02:19,188:INFO:epoch[1], iter[10007], loss:3.18, mean_fps:6260.18 imgs/sec
2020-08-22 17:05:42,490:INFO:epoch[2], iter[15011], loss:2.621, mean_fps:6301.11 imgs/sec
@@ -206,7 +223,14 @@ You can modify the training behaviour through the various flags in the `train.py
...
```


- running on GPU

```bash
cd scripts
sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 [DATASET_PATH]
```


The above shell script will run distributed training in the background. You can view the results through the file `train/train.log`.


## [Evaluation Process](#contents)


@@ -214,59 +238,72 @@ You can modify the training behaviour through the various flags in the `train.py


- evaluation on Ascend


Run the command below for evaluation.

```python
python eval.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/CHECKPOINT > eval.log 2>&1 &
OR
sh scripts/run_distribute_eval.sh 8 rank_table.json /PATH/TO/DATASET /PATH/TO/CHECKPOINT
```

The above python command will run in the background. You can view the results through the file "output/202x-xx-xx_time_xx_xx_xx/202x_xxxx.log". The accuracy on the test dataset will be as follows:

```shell
2020-08-24 09:21:50,551:INFO:after allreduce eval: top1_correct=37657, tot=49920, acc=75.43%
2020-08-24 09:21:50,551:INFO:after allreduce eval: top5_correct=46224, tot=49920, acc=92.60%
```


- evaluation on GPU


Run the command below for evaluation.


```python
python eval.py --data_dir=[DATASET_PATH] --device_target='GPU' --pretrained=[CHECKPOINT_PATH] > eval.log 2>&1 &
OR
sh run_distribute_eval_gpu.sh 1 0 [DATASET_PATH] [CHECKPOINT_PATH]
```

The above python command will run in the background. You can view the results through the file "eval/eval.log". The accuracy of the test dataset will be as follows:

```shell
2021-02-04 14:20:50,551:INFO:after allreduce eval: top1_correct=37637, tot=49984, acc=75.30%
2021-02-04 14:20:50,551:INFO:after allreduce eval: top5_correct=46370, tot=49984, acc=92.77%
```


# [Model Description](#contents)

## [Performance](#contents)


### Training accuracy results


| Parameters        | Ascend                      | GPU                         |
| ----------------- | --------------------------- | --------------------------- |
| Model Version     | DenseNet-121                | DenseNet-121                |
| Resource          | Ascend 910                  | Tesla V100-PCIE             |
| Uploaded Date     | 09/15/2020 (month/day/year) | 01/27/2021 (month/day/year) |
| MindSpore Version | 1.0.0                       | 1.1.0                       |
| Dataset           | ImageNet                    | ImageNet                    |
| epochs            | 120                         | 120                         |
| outputs           | probability                 | probability                 |
| accuracy          | Top1: 75.13%; Top5: 92.57%  | Top1: 75.30%; Top5: 92.77%  |


### Training performance results


| Parameters        | Ascend                          | GPU                             |
| ----------------- | ------------------------------- | ------------------------------- |
| Model Version     | DenseNet-121                    | DenseNet-121                    |
| Resource          | Ascend 910                      | Tesla V100-PCIE                 |
| Uploaded Date     | 09/15/2020 (month/day/year)     | 02/04/2021 (month/day/year)     |
| MindSpore Version | 1.0.0                           | 1.1.1                           |
| Dataset           | ImageNet                        | ImageNet                        |
| batch_size        | 32                              | 32                              |
| outputs           | probability                     | probability                     |
| speed             | 1pc: 760 img/s; 8pc: 6000 img/s | 1pc: 161 img/s; 8pc: 1288 img/s |


# [Description of Random Situation](#contents)


In dataset.py, we set the seed inside the `create_dataset` function. We also use a random seed in train.py.


# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

model_zoo/official/cv/densenet121/README_CN.md (+77, -22)

@@ -63,8 +63,8 @@ DenseNet-121 is built from 4 densely connected blocks. In each dense block, every layer


# Environment Requirements


- Hardware (Ascend/GPU)
    - Prepare hardware environment with an Ascend or GPU processor. If you want to try the Ascend processor, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
@@ -75,6 +75,8 @@ DenseNet-121 is built from 4 densely connected blocks. In each dense block, every layer


After installing MindSpore via the official website, you can start training and evaluation as follows:


- running on Ascend

```python
# run training example
python train.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/PRETRAINED_CKPT --is_distributed 0 > train.log 2>&1 &
@@ -94,6 +96,22 @@ DenseNet-121 is built from 4 densely connected blocks. In each dense block, every layer


[Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)


- running on GPU

```python
# run training example
export CUDA_VISIBLE_DEVICES=0
python train.py --data_dir=[DATASET_PATH] --is_distributed=0 --device_target='GPU' > train.log 2>&1 &

# run distributed training example
sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 [DATASET_PATH]

# run evaluation example
python eval.py --data_dir=[DATASET_PATH] --device_target='GPU' --pretrained=[CHECKPOINT_PATH] > eval.log 2>&1 &
OR
sh run_distribute_eval_gpu.sh 1 0 [DATASET_PATH] [CHECKPOINT_PATH]
```

# Script Description


## Script and Sample Code
@@ -105,7 +123,9 @@ DenseNet-121 is built from 4 densely connected blocks. In each dense block, every layer
├── README.md                          // description of DenseNet-121
├── scripts
│   ├── run_distribute_train.sh        // shell script for distributed training on Ascend
│   ├── run_distribute_train_gpu.sh    // shell script for distributed training on GPU
│   ├── run_distribute_eval.sh         // shell script for evaluation on Ascend
│   ├── run_distribute_eval_gpu.sh     // shell script for evaluation on GPU
├── src
│   ├── datasets                       // dataset processing function
│   ├── losses
@@ -176,6 +196,15 @@ DenseNet-121 is built from 4 densely connected blocks. In each dense block, every layer
...
```


- running on GPU

```python
export CUDA_VISIBLE_DEVICES=0
python train.py --data_dir=[DATASET_PATH] --is_distributed=0 --device_target='GPU' > train.log 2>&1 &
```

The above python command will run in the background; the log and model checkpoint will be generated in `output/202x-xx-xx_time_xx_xx/`.

### Distributed Training


- running on Ascend
@@ -197,6 +226,15 @@ DenseNet-121 is built from 4 densely connected blocks. In each dense block, every layer
...
```


- running on GPU

```bash
cd scripts
sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 [DATASET_PATH]
```

The above shell script will run distributed training in the background. You can view the result logs and model checkpoints under `train[X]/output/202x-xx-xx_time_xx_xx_xx/`.

## Evaluation Process


### Evaluation
@@ -218,35 +256,52 @@ DenseNet-121 is built from 4 densely connected blocks. In each dense block, every layer
2020-08-24 09:21:50,551:INFO:after allreduce eval: top5_correct=46224, tot=49920, acc=92.60%
```


- running on GPU

Run the command below for evaluation.

```python
python eval.py --data_dir=[DATASET_PATH] --device_target='GPU' --pretrained=[CHECKPOINT_PATH] > eval.log 2>&1 &
OR
sh run_distribute_eval_gpu.sh 1 0 [DATASET_PATH] [CHECKPOINT_PATH]
```

The above python command will run in the background. You can view the results through the file "eval/eval.log". The accuracy on the test dataset will be as follows:

```log
2021-02-04 14:20:50,551:INFO:after allreduce eval: top1_correct=37637, tot=49984, acc=75.30%
2021-02-04 14:20:50,551:INFO:after allreduce eval: top5_correct=46370, tot=49984, acc=92.77%
```

# Model Description


## Performance


### Training accuracy results


| Parameters        | Ascend                     | GPU                        |
| ----------------- | -------------------------- | -------------------------- |
| Model Version     | DenseNet-121               | DenseNet-121               |
| Resource          | Ascend 910                 | Tesla V100-PCIE            |
| Uploaded Date     | 2020/9/15                  | 2021/2/4                   |
| MindSpore Version | 1.0.0                      | 1.1.1                      |
| Dataset           | ImageNet                   | ImageNet                   |
| epochs            | 120                        | 120                        |
| outputs           | probability                | probability                |
| accuracy          | Top1: 75.13%; Top5: 92.57% | Top1: 75.30%; Top5: 92.77% |


### Training performance results


| Parameters        | Ascend                          | GPU                             |
| ----------------- | ------------------------------- | ------------------------------- |
| Model Version     | DenseNet-121                    | DenseNet-121                    |
| Resource          | Ascend 910                      | Tesla V100-PCIE                 |
| Uploaded Date     | 2020/9/15                       | 2021/2/4                        |
| MindSpore Version | 1.0.0                           | 1.1.1                           |
| Dataset           | ImageNet                        | ImageNet                        |
| batch_size        | 32                              | 32                              |
| outputs           | probability                     | probability                     |
| speed             | 1pc: 760 img/s; 8pc: 6000 img/s | 1pc: 161 img/s; 8pc: 1288 img/s |


# Description of Random Situation




model_zoo/official/cv/densenet121/eval.py (+49, -39)

@@ -38,10 +38,6 @@ from src.datasets import classification_dataset
from src.network import DenseNet121
from src.config import config





class ParameterReduce(nn.Cell):
    """
@@ -83,6 +79,9 @@ def parse_args(cloud_args=None):
    # roma obs
    parser.add_argument('--train_url', type=str, default="", help='train url')


    # platform
    parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target')

    args, _ = parser.parse_known_args()
    args = merge_args(args, cloud_args)


@@ -114,6 +113,42 @@ def merge_args(args, cloud_args):
                args_dict[key] = val
    return args


def generate_results(model, rank, group_size, top1_correct, top5_correct, img_tot):
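    """Save this rank's top-1/top-5/total counts to .npy files, wait for all ranks
    to do the same, then sum the per-rank counts and return the aggregated results."""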
    model_md5 = model.replace('/', '')
    tmp_dir = '../cache'
    if not os.path.exists(tmp_dir):
        os.mkdir(tmp_dir)

    top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, rank, model_md5)
    top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, rank, model_md5)
    img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, rank, model_md5)
    np.save(top1_correct_npy, top1_correct)
    np.save(top5_correct_npy, top5_correct)
    np.save(img_tot_npy, img_tot)
    while True:
        rank_ok = True
        for other_rank in range(group_size):
            top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
            top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
            img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
            if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) \
                    or not os.path.exists(img_tot_npy):
                rank_ok = False
        if rank_ok:
            break

    top1_correct_all = 0
    top5_correct_all = 0
    img_tot_all = 0
    for other_rank in range(group_size):
        top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
        top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
        img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
        top1_correct_all += np.load(top1_correct_npy)
        top5_correct_all += np.load(top5_correct_npy)
        img_tot_all += np.load(img_tot_npy)
    return [[top1_correct_all], [top5_correct_all], [img_tot_all]]

def test(cloud_args=None):
    """
    network eval function. Get top1 and top5 ACC from classification.
@@ -121,6 +156,12 @@ def test(cloud_args=None):
""" """
args = parse_args(cloud_args) args = parse_args(cloud_args)


    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                        save_graphs=True)
    if args.device_target == 'Ascend':
        devid = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=devid)

    # init distributed
    if args.is_distributed:
        init()
@@ -164,7 +205,8 @@ def test(cloud_args=None):
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(model))


        if args.device_target == 'Ascend':
            network.add_flags_recursive(fp16=True)


        img_tot = 0
        top1_correct = 0
@@ -186,41 +228,9 @@ def test(cloud_args=None):
        results = [[top1_correct], [top5_correct], [img_tot]]
        args.logger.info('before results={}'.format(results))
        if args.is_distributed:
            results = generate_results(model, args.rank, args.group_size, top1_correct,
                                       top5_correct, img_tot)
            results = np.array(results)

        else:
            results = np.array(results)




model_zoo/official/cv/densenet121/scripts/run_distribute_eval_gpu.sh (+63, -0)

@@ -0,0 +1,63 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -lt 4 ]
then
    echo "Usage: sh run_distribute_eval_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
    echo "error: DEVICE_NUM=$1 is not in (1-8)"
    exit 1
fi

export DEVICE_NUM=$1
export RANK_SIZE=$1

# check checkpoint file
if [ ! -f $4 ]
then
    echo "error: CHECKPOINT_PATH=$4 is not a file"
    exit 1
fi

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH

if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit

export CUDA_VISIBLE_DEVICES="$2"

if [ $1 -gt 1 ]
then
    mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../eval.py \
        --data_dir=$3 \
        --device_target='GPU' \
        --pretrained=$4 > eval.log 2>&1 &
else
    python3 ${BASEPATH}/../eval.py \
        --data_dir=$3 \
        --device_target='GPU' \
        --pretrained=$4 > eval.log 2>&1 &
fi

model_zoo/official/cv/densenet121/scripts/run_distribute_train_gpu.sh (+70, -0)

@@ -0,0 +1,70 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -lt 3 ]
then
    echo "Usage: sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRE_TRAINED](optional)"
    exit 1
fi

if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
    echo "error: DEVICE_NUM=$1 is not in (1-8)"
    exit 1
fi

export DEVICE_NUM=$1
export RANK_SIZE=$1

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit

export CUDA_VISIBLE_DEVICES="$2"

if [ -f $4 ]  # pretrained ckpt
then
    if [ $1 -gt 1 ]
    then
        mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
            --data_dir=$3 \
            --device_target='GPU' \
            --pretrained=$4 > train.log 2>&1 &
    else
        python3 ${BASEPATH}/../train.py \
            --data_dir=$3 \
            --is_distributed=0 \
            --device_target='GPU' \
            --pretrained=$4 > train.log 2>&1 &
    fi
else
    if [ $1 -gt 1 ]
    then
        mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
            --data_dir=$3 \
            --device_target='GPU' > train.log 2>&1 &
    else
        python3 ${BASEPATH}/../train.py \
            --data_dir=$3 \
            --is_distributed=0 \
            --device_target='GPU' > train.log 2>&1 &
    fi
fi
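Note: both GPU scripts launch multi-device runs through `mpirun`, which assumes an OpenMPI installation with NCCL available, as MindSpore requires for multi-GPU data-parallel training.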

model_zoo/official/cv/densenet121/train.py (+18, -6)

@@ -39,10 +39,6 @@ from src.lr_scheduler import MultiStepLR, CosineAnnealingLR
from src.utils.logging import get_logger
from src.config import config



set_seed(1)


class BuildTrainNetwork(nn.Cell):
@@ -124,6 +120,9 @@ def parse_args(cloud_args=None):
    # roma obs
    parser.add_argument('--train_url', type=str, default="", help='train url')


    # platform
    parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target')

    args, _ = parser.parse_known_args()
    args = merge_args(args, cloud_args)
    args.image_size = config.image_size
@@ -172,6 +171,13 @@ def train(cloud_args=None):
"""training process""" """training process"""
args = parse_args(cloud_args) args = parse_args(cloud_args)


    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=False)

    if args.device_target == 'Ascend':
        devid = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=devid)

    # init distributed
    if args.is_distributed:
        init()
@@ -181,7 +187,7 @@ def train(cloud_args=None):
    if args.is_dynamic_loss_scale == 1:
        args.loss_scale = 1  # for dynamic loss scale can not set loss scale in momentum opt


    # select for master rank save ckpt or all rank save, compatible for model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
@@ -269,7 +275,13 @@ def train(cloud_args=None):


    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
                                      gradients_mean=True)
    if args.device_target == 'Ascend':
        model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3")
    elif args.device_target == 'GPU':
        model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O0")
    else:
        raise ValueError("Unsupported device target.")


    # checkpoint save
    progress_cb = ProgressMonitor(args)

