# Contents

- [CycleGAN Description](#cyclegan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
    - [Knowledge Distillation Process](#knowledge-distillation-process)
    - [Prediction Process](#prediction-process)
    - [Evaluation with cityscape dataset](#evaluation-with-cityscape-dataset)
    - [Export MindIR](#export-mindir)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CycleGAN Description](#contents)

Generative Adversarial Network (GAN) is an unsupervised learning method in which two neural networks learn by competing against each other. CycleGAN is a kind of GAN that consists of two generator networks and two discriminator networks. It learns to translate one type of image into another from unpaired images, which can be used for style transfer.

[Paper](https://arxiv.org/abs/1703.10593): Zhu J Y, Park T, Isola P, et al. Unpaired Image-to-Image Translation Using Cycle-Consistent Adversarial Networks. 2017.
# [Model Architecture](#contents)

CycleGAN contains two generator networks and two discriminator networks. We support two architectures for the generator networks: resnet and unet. The resnet architecture contains three convolutions, several residual blocks, two fractionally-strided convolutions with stride 1/2, and one convolution that maps features to RGB. The unet architecture contains several unet blocks that downsample and then upsample the features, and one convolution that maps features to RGB. For the discriminator networks we use 70 × 70 PatchGANs, which aim to classify whether 70 × 70 overlapping image patches are real or fake.
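As a quick sanity check on the 70 × 70 figure, the receptive field of one PatchGAN output unit can be computed from the layer hyperparameters. The sketch below assumes the standard PatchGAN layout from the CycleGAN paper (five 4 × 4 convolutions with strides 2, 2, 2, 1, 1); the exact layer list in `src/models/networks.py` may differ.

```python
# Receptive-field arithmetic for a PatchGAN discriminator.
# Assumed layout (from the CycleGAN paper, not read from networks.py):
# five 4x4 convolutions with strides 2, 2, 2, 1, 1.
def receptive_field(layers):
    """layers: list of (kernel_size, stride) tuples, first layer first."""
    rf = 1
    for kernel, stride in reversed(layers):
        # Standard backward recurrence: rf_in = (rf_out - 1) * stride + kernel
        rf = (rf - 1) * stride + kernel
    return rf

patchgan = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan))  # 70 -> each output unit judges a 70x70 patch
```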
# [Dataset](#contents)

Note that you can run the scripts based on the dataset mentioned in the original paper or one widely used in the relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.

Dataset used: [Cityscapes](<https://cityscapes-dataset.com>)

Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them. We provide `src/utils/prepare_cityscapes_dataset.py` to process the images. gtFine contains the semantic segmentations; use `--gtFine_dir` to specify the path to the unzipped gtFine_trainvaltest directory. leftImg8bit contains the dashcam photographs; use `--leftImg8bit_dir` to specify the path to the unzipped leftImg8bit_trainvaltest directory.

The processed images will be placed at `--output_dir`.
Example usage:

```bash
python src/utils/prepare_cityscapes_dataset.py --gtFine_dir ./cityscapes/gtFine/ --leftImg8bit_dir ./cityscapes/leftImg8bit --output_dir ./cityscapes/
```
The directory structure is as follows:

```path
.
└─cityscapes
  ├─trainA
  ├─trainB
  ├─testA
  └─testB
```
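Because CycleGAN trains on unpaired images, the loader only needs these four folders; no pixel-level alignment between A and B is required. The sketch below shows roughly how an unpaired dataset such as the `UnalignedDataset` class in `src/dataset/datasets.py` can be organized; it is an illustration, not the repository's actual implementation.

```python
# Minimal sketch of an unpaired image dataset (illustrative only; see
# src/dataset/datasets.py for the real UnalignedDataset implementation).
import os
import random

class UnpairedDataset:
    """Yields (image_A_path, image_B_path) pairs drawn independently."""

    def __init__(self, dataroot, phase="train"):
        self.a_paths = sorted(os.listdir(os.path.join(dataroot, phase + "A")))
        self.b_paths = sorted(os.listdir(os.path.join(dataroot, phase + "B")))

    def __len__(self):
        # One epoch is defined by the larger domain.
        return max(len(self.a_paths), len(self.b_paths))

    def __getitem__(self, index):
        a = self.a_paths[index % len(self.a_paths)]
        # B is sampled randomly, so A/B pairs are never aligned.
        b = self.b_paths[random.randint(0, len(self.b_paths) - 1)]
        return a, b
```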
# [Environment Requirements](#contents)

- Hardware (GPU)
    - Prepare hardware environment with GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```path
.
└─ cv
  └─ cyclegan
    ├─ src
      ├─ __init__.py                     # init file
      ├─ dataset
        ├─ __init__.py                   # init file
        ├─ cyclegan_dataset.py           # create cyclegan dataset
        ├─ datasets.py                   # UnalignedDataset and ImageFolderDataset class and some image utils
        └─ distributed_sampler.py        # iterator of dataset
      ├─ models
        ├─ __init__.py                   # init file
        ├─ cycle_gan.py                  # cyclegan model define
        ├─ losses.py                     # cyclegan losses function define
        ├─ networks.py                   # cyclegan sub networks define
        ├─ resnet.py                     # resnet generator network
        └─ unet.py                       # unet generator network
      └─ utils
        ├─ __init__.py                   # init file
        ├─ args.py                       # parse args
        ├─ prepare_cityscapes_dataset.py # prepare cityscapes dataset to cyclegan format
        ├─ cityscapes_utils.py           # cityscapes dataset evaluation utils
        ├─ reporter.py                   # Reporter class
        └─ tools.py                      # utils for cyclegan
    ├─ cityscape_eval.py                 # cityscape dataset eval script
    ├─ predict.py                        # generate images from A->B and B->A
    ├─ train.py                          # train script
    ├─ export.py                         # export mindir script
    ├─ README.md                         # descriptions about CycleGAN
    └─ mindspore_hub_conf.py             # mindspore hub interface
```
## [Script Parameters](#contents)

```python
Major parameters in train.py and config.py are as follows:

"model": "resnet" # generator model, should be in [resnet, unet].
"platform": "GPU" # run platform, support GPU, CPU and Ascend.
"device_id": 0 # device id, default is 0.
"lr": 0.0002 # initial learning rate, default is 0.0002.
"pool_size": 50 # the size of the image buffer that stores previously generated images, default is 50.
"lr_policy": "linear" # learning rate policy, default is linear.
"image_size": 256 # input image_size, default is 256.
"batch_size": 1 # batch_size, default is 1.
"max_epoch": 200 # epoch size for training, default is 200.
"n_epochs": 100 # number of epochs with the initial learning rate, default is 100.
"beta1": 0.5 # Adam beta1, default is 0.5.
"init_type": "normal" # network initialization, default is normal.
"init_gain": 0.02 # scaling factor for normal, xavier and orthogonal, default is 0.02.
"in_planes": 3 # input channels, default is 3.
"ngf": 64 # generator model filter numbers, default is 64.
"gl_num": 9 # generator model residual block numbers, default is 9.
"ndf": 64 # discriminator model filter numbers, default is 64.
"dl_num": 3 # discriminator model residual block numbers, default is 3.
"slope": 0.2 # leakyrelu slope, default is 0.2.
"norm_mode": "instance" # norm mode, should be in [batch, instance], default is instance.
"lambda_A": 10 # weight for cycle loss (A -> B -> A), default is 10.
"lambda_B": 10 # weight for cycle loss (B -> A -> B), default is 10.
"lambda_idt": 0.5 # if lambda_idt > 0, use identity mapping.
"gan_mode": "lsgan" # the type of GAN loss, should be in [lsgan, vanilla], default is lsgan.
"pad_mode": "REFLECT" # the type of Pad, should be in [CONSTANT, REFLECT, SYMMETRIC], default is REFLECT.
"need_dropout": True # whether dropout is needed, default is True.
"kd": False # whether to use knowledge distillation learning, default is False.
"t_ngf": 64 # teacher network generator model filter numbers when `kd` is True, default is 64.
"t_gl_num": 9 # teacher network generator model residual block numbers when `kd` is True, default is 9.
"t_slope": 0.2 # teacher network leakyrelu slope when `kd` is True, default is 0.2.
"t_norm_mode": "instance" # teacher network norm mode when `kd` is True, default is instance.
"print_iter": 100 # log print iter, default is 100.
"outputs_dir": "outputs" # models are saved here, default is ./outputs.
"dataroot": None # path of images (should have subfolders trainA, trainB, testA, testB, etc).
"save_imgs": True # whether to save images at the end of each epoch; if True, result images are generated in `outputs_dir/imgs`, default is True.
"GT_A_ckpt": None # teacher network pretrained checkpoint file path of G_A when `kd` is True.
"GT_B_ckpt": None # teacher network pretrained checkpoint file path of G_B when `kd` is True.
"G_A_ckpt": None # pretrained checkpoint file path of G_A.
"G_B_ckpt": None # pretrained checkpoint file path of G_B.
"D_A_ckpt": None # pretrained checkpoint file path of D_A.
"D_B_ckpt": None # pretrained checkpoint file path of D_B.
```
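To see how `lr`, `lr_policy`, `n_epochs` and `max_epoch` interact, the sketch below assumes the standard CycleGAN "linear" schedule (constant learning rate for `n_epochs`, then a linear decay to zero by `max_epoch`); the repository's exact rule may differ in detail.

```python
# Illustrative sketch of the "linear" lr_policy (assumed to follow the
# standard CycleGAN schedule): keep the initial lr for n_epochs, then
# decay linearly to 0 by max_epoch.
def linear_lr(base_lr, epoch, n_epochs=100, max_epoch=200):
    if epoch < n_epochs:
        return base_lr
    return base_lr * (max_epoch - epoch) / float(max_epoch - n_epochs)

# With the defaults (lr=0.0002, n_epochs=100, max_epoch=200), the lr is
# constant for the first 100 epochs and reaches 0 at epoch 200.
print(linear_lr(0.0002, 50))   # 0.0002
print(linear_lr(0.0002, 150))  # 0.0001
```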
## [Training Process](#contents)

```bash
python train.py --platform [PLATFORM] --dataroot [DATA_PATH]
```

**Note: pad_mode should be CONSTANT when using Ascend or CPU. When using unet as the generator network, gl_num should be less than 7.**
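To relate the `lambda_A`, `lambda_B`, `lambda_idt` and `gan_mode` parameters above to the training objective: following the CycleGAN paper, the generator loss combines adversarial terms (mean-square error when `gan_mode` is lsgan), two L1 cycle-consistency terms weighted by `lambda_A` and `lambda_B`, and optional L1 identity terms scaled by `lambda_idt`. The sketch below is a plain-Python illustration of this weighting, not the code in `src/models/losses.py`.

```python
import numpy as np

def generator_loss(d_out_fake_b, d_out_fake_a,
                   real_a, rec_a, real_b, rec_b,
                   idt_a, idt_b,
                   lambda_A=10.0, lambda_B=10.0, lambda_idt=0.5):
    """Illustrative weighting of the CycleGAN generator objective."""
    mse = lambda x, target: float(np.mean((x - target) ** 2))  # gan_mode = lsgan
    l1 = lambda x, y: float(np.mean(np.abs(x - y)))
    # Adversarial terms: fool both discriminators into predicting "real" (1).
    adv = mse(d_out_fake_b, 1.0) + mse(d_out_fake_a, 1.0)
    # Cycle consistency: A -> B -> A (weight lambda_A) and B -> A -> B (lambda_B).
    cyc = lambda_A * l1(rec_a, real_a) + lambda_B * l1(rec_b, real_b)
    # Identity terms (used when lambda_idt > 0): idt_b = G_A(real_b) should
    # stay close to real_b, and idt_a = G_B(real_a) close to real_a.
    idt = lambda_idt * (lambda_B * l1(idt_b, real_b) + lambda_A * l1(idt_a, real_a))
    return adv + cyc + idt
```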
## [Knowledge Distillation Process](#contents)

```bash
python train.py --platform [PLATFORM] --dataroot [DATA_PATH] --ngf [NGF] --kd True --GT_A_ckpt [G_A_CKPT] --GT_B_ckpt [G_B_CKPT]
```

**Note: the student network's ngf should be 1/2 or 1/4 of the teacher network's ngf. If you changed the default arguments when training the teacher generator networks, change the corresponding `t_*` arguments (e.g. `t_ngf`, `t_gl_num`) in the knowledge distillation process accordingly.**
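The README does not spell out the distillation objective. A common formulation, shown below purely as an assumption and not taken from `src/models/losses.py`, adds an L1 term pulling the small student generator's output toward the frozen teacher's output for the same input.

```python
import numpy as np

# Hypothetical distillation term (an assumption, not the repository's exact
# loss): the student generator (e.g. ngf=32 or 16) is trained to match the
# frozen teacher generator (loaded from GT_A_ckpt / GT_B_ckpt) on each input.
def distillation_loss(student_fake, teacher_fake, weight=1.0):
    return weight * float(np.mean(np.abs(student_fake - teacher_fake)))
```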
## [Prediction Process](#contents)

```bash
python predict.py --platform [PLATFORM] --dataroot [DATA_PATH] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT]
```

**Note: the results will be saved at `outputs_dir/predict`.**
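Conceptually, prediction just runs each generator over the corresponding test folder: G_A translates testA images to domain B, and G_B translates testB images to domain A. The sketch below illustrates that flow; `load_image` and `save_image` are hypothetical helpers standing in for the repository's actual image I/O, and `predict.py` remains the real entry point.

```python
import os

def run_direction(generator, src_dir, dst_dir, load_image, save_image):
    """Apply one generator (e.g. G_A for A->B) to every image in src_dir."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in sorted(os.listdir(src_dir)):
        img = load_image(os.path.join(src_dir, name))   # e.g. HWC -> NCHW, [-1, 1]
        fake = generator(img)                           # forward pass only
        save_image(fake, os.path.join(dst_dir, name))   # e.g. outputs/predict/...
```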
## [Evaluation with cityscape dataset](#contents)

```bash
python cityscape_eval.py --cityscapes_dir [LABEL_PATH] --result_dir [FAKEB_PATH]
```

**Note: Please run cityscape_eval.py after the prediction process.**
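For orientation, the three numbers this evaluation reports (see the Inference Performance table below) are standard segmentation metrics derived from a per-class confusion matrix. A minimal sketch, assuming the usual definitions; the script's own implementation lives in `src/utils/cityscapes_utils.py`:

```python
import numpy as np

def segmentation_metrics(hist):
    """hist: num_classes x num_classes confusion matrix (rows = ground truth)."""
    tp = np.diag(hist).astype(float)
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_pixel_acc = tp.sum() / hist.sum()
        mean_class_acc = np.nanmean(tp / hist.sum(axis=1))  # per-class recall
        mean_class_iou = np.nanmean(tp / (hist.sum(axis=1) + hist.sum(axis=0) - tp))
    return mean_pixel_acc, mean_class_acc, mean_class_iou
```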
## [Export MindIR](#contents)

```bash
python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```

**Note: The file_name parameter is the prefix; the final files will be named [FILE_NAME]_AtoB.[FILE_FORMAT] and [FILE_NAME]_BtoA.[FILE_FORMAT].**
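A hedged sketch of what the export step amounts to for each direction (an assumption about `export.py`, which may differ): build the generator, load its checkpoint, and export it against a dummy input of the training resolution.

```python
import numpy as np
import mindspore as ms

# Assumed flow, not the actual export.py: one MindIR file per direction.
def export_generator(net, ckpt_path, file_name, image_size=256, in_planes=3):
    ms.load_checkpoint(ckpt_path, net)  # load G_A or G_B weights into the cell
    dummy = ms.Tensor(np.zeros((1, in_planes, image_size, image_size), np.float32))
    ms.export(net, dummy, file_name=file_name, file_format="MINDIR")

# Called once per generator, this yields e.g. [FILE_NAME]_AtoB.mindir
# and [FILE_NAME]_BtoA.mindir, matching the note above.
```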
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | GPU                                                          |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | CycleGAN                                                     |
| Resource                   | NV SMX2 V100-32G                                             |
| Uploaded Date              | 12/10/2020 (month/day/year)                                  |
| MindSpore Version          | 1.1.0                                                        |
| Dataset                    | Cityscapes                                                   |
| Training Parameters        | epoch=200, steps=2975, batch_size=1, lr=0.0002               |
| Optimizer                  | Adam                                                         |
| Loss Function              | Mean Square Loss & L1 Loss                                   |
| outputs                    | probability                                                  |
| Speed                      | 1pc: 264 ms/step                                             |
| Total time                 | 1pc: 43.6h                                                   |
| Parameters (M)             | 11.378 M                                                     |
| Checkpoint for Fine tuning | 44M (.ckpt file)                                             |
| Scripts                    | [CycleGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/cycle_gan) |
### Inference Performance

| Parameters        | GPU                         |
| ----------------- | --------------------------- |
| Model Version     | CycleGAN                    |
| Resource          | GPU                         |
| Uploaded Date     | 12/10/2020 (month/day/year) |
| MindSpore Version | 1.1.0                       |
| Dataset           | Cityscapes                  |
| batch_size        | 1                           |
| outputs           | probability                 |
| Accuracy          | mean_pixel_acc: 54.8, mean_class_acc: 21.3, mean_class_iou: 16.1 |
# [Description of Random Situation](#contents)

If you set --use_random=False, there is no randomness during training.

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).