# Contents

[View this document in Chinese](./README_CN.md)

- [GoogleNet Description](#googlenet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
    - [How to use](#how-to-use)
        - [Inference](#inference)
        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
        - [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [GoogleNet Description](#contents)

GoogleNet, a 22-layer deep network, was proposed in 2014 and won first place in the ImageNet Large-Scale Visual Recognition Challenge 2014 (ILSVRC14). GoogleNet, also called Inception v1, is a significant improvement over ZFNet (the 2013 winner) and AlexNet (the 2012 winner), and has a relatively lower error rate than VGGNet. A deeper network typically means a larger number of parameters, which makes it more prone to overfitting; the increased network size also consumes more computational resources. To tackle these issues, GoogleNet adopts 1×1 convolutions in the middle of the network to reduce dimensionality and thus reduce computation. Global average pooling is used at the end of the network instead of fully connected layers. Another technique, the inception module, applies convolutions of different sizes to the same input and stacks all the outputs.

[Paper](https://arxiv.org/abs/1409.4842): Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. "Going deeper with convolutions." *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition*. 2015.
# [Model Architecture](#contents)

Specifically, GoogleNet contains numerous inception modules, which are connected together to go deeper. In general, an inception module with dimensionality reduction consists of a **1×1 conv**, a **3×3 conv**, a **5×5 conv**, and a **3×3 max pooling**, which are applied in parallel to the same input, and whose outputs are stacked together again at the output. In our model architecture, the kernel size used in the inception module is 3×3 instead of 5×5, as illustrated by the sketch below.
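As a minimal sketch of the module described above, written against the MindSpore 1.x API: the four branches run in parallel on the same input and their outputs are concatenated along the channel axis. The branch channel counts are constructor parameters here; the exact values used in `src/googlenet.py` may differ.

```python
import mindspore.nn as nn
from mindspore.ops import operations as P


class Inception(nn.Cell):
    """Inception module with dimensionality reduction (sketch)."""

    def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
        super(Inception, self).__init__()
        self.b1 = nn.SequentialCell([nn.Conv2d(in_channels, n1x1, 1), nn.ReLU()])
        # 1x1 reduction before the larger convolution keeps computation low
        self.b2 = nn.SequentialCell([nn.Conv2d(in_channels, n3x3red, 1), nn.ReLU(),
                                     nn.Conv2d(n3x3red, n3x3, 3, pad_mode="same"), nn.ReLU()])
        # this repository uses a 3x3 kernel here instead of the paper's 5x5
        self.b3 = nn.SequentialCell([nn.Conv2d(in_channels, n5x5red, 1), nn.ReLU(),
                                     nn.Conv2d(n5x5red, n5x5, 3, pad_mode="same"), nn.ReLU()])
        self.b4 = nn.SequentialCell([nn.MaxPool2d(3, stride=1, pad_mode="same"),
                                     nn.Conv2d(in_channels, pool_planes, 1), nn.ReLU()])
        self.concat = P.Concat(axis=1)

    def construct(self, x):
        # concatenate the four branch outputs along the channel axis
        return self.concat((self.b1(x), self.b2(x), self.b3(x), self.b4(x)))
```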
# [Dataset](#contents)

Note that you can run the scripts with the dataset mentioned in the original paper or with one widely used in the relevant domain/network architecture. The following sections introduce how to run the scripts using the datasets below.

Dataset used: [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar.html)

- Dataset size: 175M, 60,000 32×32 color images in 10 classes
    - Train: 146M, 50,000 images
    - Test: 29M, 10,000 images
- Data format: binary files
    - Note: Data will be processed in src/dataset.py

The ImageNet dataset referred to in the paper can also be used:

- Dataset size: 125G, 1,250k color images in 1,000 classes
    - Train: 120G, 1,200k images
    - Test: 5G, 50k images
- Data format: RGB images
    - Note: Data will be processed in src/dataset.py
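Since both notes above point at `src/dataset.py`, here is a hedged sketch of what a CIFAR-10 pipeline there typically looks like with the MindSpore 1.x dataset API. The transform list and the normalization constants (common CIFAR-10 statistics) are illustrative assumptions; the repository's actual pipeline may differ.

```python
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.common import dtype as mstype


def create_dataset(data_path, batch_size=128, training=True):
    """Build a CIFAR-10 pipeline: decode -> augment -> resize -> normalize."""
    cifar = ds.Cifar10Dataset(data_path, shuffle=training)
    trans = []
    if training:
        # random crop with 4-pixel padding and horizontal flip for augmentation
        trans += [CV.RandomCrop((32, 32), (4, 4, 4, 4)), CV.RandomHorizontalFlip()]
    trans += [CV.Resize((224, 224)),             # GoogleNet expects 224x224 inputs
              CV.Rescale(1.0 / 255.0, 0.0),
              CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
              CV.HWC2CHW()]
    cifar = cifar.map(operations=trans, input_columns="image")
    cifar = cifar.map(operations=C.TypeCast(mstype.int32), input_columns="label")
    return cifar.batch(batch_size, drop_remainder=training)
```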
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, while maintaining the network accuracy achieved by single-precision training. Mixed precision training accelerates computation, reduces memory usage, and enables larger models or batch sizes to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for "reduce precision".
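As an illustration, mixed precision is enabled through MindSpore's high-level `Model` API; the call below mirrors the one used in the training snippets later in this README. Here `net`, `loss`, and `opt` stand in for the network, loss function, and optimizer built in train.py.

```python
from mindspore import Model

# amp_level="O2" casts the network to FP16 for training;
# keep_batchnorm_fp32=False additionally runs BatchNorm in FP16,
# matching the Model(...) call used in this repository's snippets.
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
```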
# [Environment Requirements](#contents)

- Hardware (Ascend/GPU)
    - Prepare a hardware environment with an Ascend or GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:
- running on Ascend

```bash
# run training example
python train.py > train.log 2>&1 &

# run distributed training example
sh scripts/run_train.sh rank_table.json

# run evaluation example
python eval.py > eval.log 2>&1 &
# OR
sh run_eval.sh
```

For distributed training, an HCCL configuration file in JSON format needs to be created in advance.

Please follow the instructions in the link below:

<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
- running on GPU

For running on GPU, please change `device_target` from `Ascend` to `GPU` in the configuration file src/config.py.

```bash
# run training example
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &

# run distributed training example
sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7

# run evaluation example
python eval.py --checkpoint_path=[CHECKPOINT_PATH] > eval.log 2>&1 &
# OR
sh run_eval_gpu.sh [CHECKPOINT_PATH]
```

We use the CIFAR-10 dataset by default. You can also pass `$dataset_type` to the scripts to select a different dataset. For more details, please refer to the specific script.
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
├── model_zoo
    ├── README.md                    // descriptions about all the models
    ├── googlenet
        ├── README.md                // descriptions about googlenet
        ├── scripts
        │   ├── run_train.sh         // shell script for distributed training on Ascend
        │   ├── run_train_gpu.sh     // shell script for distributed training on GPU
        │   ├── run_eval.sh          // shell script for evaluation on Ascend
        │   ├── run_eval_gpu.sh      // shell script for evaluation on GPU
        ├── src
        │   ├── dataset.py           // creating dataset
        │   ├── googlenet.py         // googlenet architecture
        │   ├── config.py            // parameter configuration
        ├── train.py                 // training script
        ├── eval.py                  // evaluation script
        ├── export.py                // export checkpoint files into air/onnx
```
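Since `export.py` is only named in the tree above, here is a hedged sketch of what such an export script does with MindSpore's `export` API. The `cifar_cfg` import name is an assumption about how `src/config.py` exposes its settings; the input shape follows `image_height`/`image_width` from the configuration below.

```python
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net

from src.config import cifar_cfg as cfg      # assumed config object name
from src.googlenet import GoogleNet

# Rebuild the network and restore the trained weights
net = GoogleNet(num_classes=cfg.num_classes)
load_param_into_net(net, load_checkpoint(cfg.checkpoint_path))

# Export with a dummy input that fixes the graph's input shape (NCHW)
dummy_input = Tensor(np.zeros([1, 3, cfg.image_height, cfg.image_width], np.float32))
export(net, dummy_input, file_name=cfg.air_filename, file_format="AIR")
export(net, dummy_input, file_name=cfg.onnx_filename, file_format="ONNX")
```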
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for GoogleNet, CIFAR-10 dataset

```python
'pre_trained': 'False'    # whether training based on the pre-trained model
'num_classes': 10         # the number of classes in the dataset
'lr_init': 0.1            # initial learning rate
'batch_size': 128         # training batch size
'epoch_size': 125         # total training epochs
'momentum': 0.9           # momentum
'weight_decay': 5e-4      # weight decay value
'image_height': 224       # image height used as input to the model
'image_width': 224        # image width used as input to the model
'data_path': './cifar10'  # absolute full path to the train and evaluation datasets
'device_target': 'Ascend' # device running the program
'device_id': 0            # device ID used to train or evaluate the dataset. Ignore it when you use run_train.sh for distributed training
'keep_checkpoint_max': 10 # only keep the last keep_checkpoint_max checkpoints
'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt' # the absolute full path to save the checkpoint file
'onnx_filename': 'googlenet.onnx' # file name of the onnx model used in export.py
'air_filename': 'googlenet.air'   # file name of the air model used in export.py
```
- config for GoogleNet, ImageNet dataset

```python
'pre_trained': 'False'    # whether training based on the pre-trained model
'num_classes': 1000       # the number of classes in the dataset
'lr_init': 0.1            # initial learning rate
'batch_size': 256         # training batch size
'epoch_size': 300         # total training epochs
'momentum': 0.9           # momentum
'weight_decay': 1e-4      # weight decay value
'image_height': 224       # image height used as input to the model
'image_width': 224        # image width used as input to the model
'data_path': './ImageNet_Original/train/'    # absolute full path to the train dataset
'val_data_path': './ImageNet_Original/val/'  # absolute full path to the evaluation dataset
'device_target': 'Ascend' # device running the program
'device_id': 0            # device ID used to train or evaluate the dataset. Ignore it when you use run_train.sh for distributed training
'keep_checkpoint_max': 10 # only keep the last keep_checkpoint_max checkpoints
'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt' # the absolute full path to save the checkpoint file
'onnx_filename': 'googlenet.onnx' # file name of the onnx model used in export.py
'air_filename': 'googlenet.air'   # file name of the air model used in export.py
'lr_scheduler': 'exponential'     # learning rate scheduler
'lr_epochs': [70, 140, 210, 280]  # epochs at which the lr changes
'lr_gamma': 0.3           # decay factor of the exponential lr_scheduler
'eta_min': 0.0            # eta_min in the cosine_annealing scheduler
'T_max': 150              # T_max in the cosine_annealing scheduler
'warmup_epochs': 0        # warmup epochs
'is_dynamic_loss_scale': 0 # whether to use dynamic loss scaling
'loss_scale': 1024        # loss scale
'label_smooth_factor': 0.1 # label smoothing factor
'use_label_smooth': True  # whether to use label smoothing
```

For more configuration details, please refer to the script `config.py`.
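To make the scheduler keys above concrete, here is a hedged sketch (not the repository's actual implementation) of how an 'exponential' schedule with `lr_epochs` and `lr_gamma` produces a per-step learning-rate list that can be wrapped in a `Tensor` and passed to the optimizer:

```python
def exponential_lr(lr_init, lr_gamma, lr_epochs, total_epochs, steps_per_epoch):
    """Multiply lr by lr_gamma at each epoch listed in lr_epochs."""
    lr_each_step = []
    lr = lr_init
    for epoch in range(total_epochs):
        if epoch in lr_epochs:
            lr *= lr_gamma                # decay at epochs 70, 140, 210, 280
        lr_each_step.extend([lr] * steps_per_epoch)
    return lr_each_step

# With the config above: lr stays at 0.1 until epoch 70,
# then becomes 0.03, 0.009, 0.0027, ...
```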
## [Training Process](#contents)

### Training

- running on Ascend

```bash
python train.py > train.log 2>&1 &
```

The python command above will run in the background; you can view the results through the file `train.log`.

After training, you'll get some checkpoint files under the script folder by default. The loss value will be reported as follows:

```bash
# grep "loss is " train.log
epoch: 1 step: 390, loss is 1.4842823
epoch: 2 step: 390, loss is 1.0897788
...
```

The model checkpoint will be saved in the current directory.
- running on GPU

```bash
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &
```

The python command above will run in the background; you can view the results through the file `train.log`.

After training, you'll get some checkpoint files under the folder `./ckpt_0/` by default.
### Distributed Training

- running on Ascend

```bash
sh scripts/run_train.sh rank_table.json
```

The above shell script will run distributed training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be reported as follows:

```bash
# grep "result: " train_parallel*/log
train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931
train_parallel0/log:epoch: 2 step: 48, loss is 1.4023874
...
train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025
train_parallel1/log:epoch: 2 step: 48, loss is 1.3729336
...
...
```
- running on GPU

```bash
sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
```

The above shell script will run distributed training in the background. You can view the results through the file `train/train.log`.
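For context, here is a hedged sketch of the data-parallel setup such distributed scripts typically rely on in MindSpore 1.x; this is an assumption based on common model_zoo training scripts, not a quote of this repository's code:

```python
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

init()                                     # initialize HCCL (Ascend) or NCCL (GPU)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  device_num=get_group_size(),
                                  gradients_mean=True)   # average gradients across ranks
# each rank then reads its own shard of the dataset, e.g.
# ds.Cifar10Dataset(data_path, num_shards=get_group_size(), shard_id=get_rank())
```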
## [Evaluation Process](#contents)

### Evaluation

- evaluation on CIFAR-10 dataset when running on Ascend

Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to an absolute full path, e.g., "username/googlenet/train_googlenet_cifar10-125_390.ckpt".

```bash
python eval.py > eval.log 2>&1 &
# OR
sh scripts/run_eval.sh
```

The above python command will run in the background. You can view the results through the file "eval.log". The accuracy on the test dataset will be as follows:

```bash
# grep "accuracy: " eval.log
accuracy: {'acc': 0.934}
```

Note that for evaluation after distributed training, please set checkpoint_path to the last saved checkpoint file, such as "username/googlenet/train_parallel0/train_googlenet_cifar10-125_48.ckpt". The accuracy on the test dataset will be as follows:

```bash
# grep "accuracy: " eval.log
accuracy: {'acc': 0.9217}
```
- evaluation on CIFAR-10 dataset when running on GPU

Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to an absolute full path, e.g., "username/googlenet/train/ckpt_0/train_googlenet_cifar10-125_390.ckpt".

```bash
python eval.py --checkpoint_path=[CHECKPOINT_PATH] > eval.log 2>&1 &
```

The above python command will run in the background. You can view the results through the file "eval.log". The accuracy on the test dataset will be as follows:

```bash
# grep "accuracy: " eval.log
accuracy: {'acc': 0.930}
```

OR,

```bash
sh scripts/run_eval_gpu.sh [CHECKPOINT_PATH]
```

The above shell script will run in the background. You can view the results through the file "eval/eval.log". The accuracy on the test dataset will be as follows:

```bash
# grep "accuracy: " eval/eval.log
accuracy: {'acc': 0.930}
```
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

#### GoogleNet on CIFAR-10

| Parameters                 | Ascend                                                       | GPU                    |
| -------------------------- | ------------------------------------------------------------ | ---------------------- |
| Model Version              | Inception V1                                                 | Inception V1           |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G       |
| Uploaded Date              | 10/28/2020 (month/day/year)                                  | 10/28/2020 (month/day/year) |
| MindSpore Version          | 1.0.0                                                        | 1.0.0                  |
| Dataset                    | CIFAR-10                                                     | CIFAR-10               |
| Training Parameters        | epoch=125, steps=390, batch_size=128, lr=0.1                 | epoch=125, steps=390, batch_size=128, lr=0.1 |
| Optimizer                  | Momentum                                                     | Momentum               |
| Loss Function              | Softmax Cross Entropy                                        | Softmax Cross Entropy  |
| Outputs                    | probability                                                  | probability            |
| Loss                       | 0.0016                                                       | 0.0016                 |
| Speed                      | 1pc: 79 ms/step; 8pcs: 82 ms/step                            | 1pc: 150 ms/step; 8pcs: 164 ms/step |
| Total time                 | 1pc: 63.85 mins; 8pcs: 11.28 mins                            | 1pc: 126.87 mins; 8pcs: 21.65 mins |
| Parameters (M)             | 13.0                                                         | 13.0                   |
| Checkpoint for Fine tuning | 43.07M (.ckpt file)                                          | 43.07M (.ckpt file)    |
| Model for inference        | 21.50M (.onnx file), 21.60M (.air file)                      |                        |
| Scripts                    | [googlenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet) | [googlenet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet) |
#### GoogleNet on 1200k images

| Parameters                 | Ascend                                                      |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | Inception V1                                                |
| Resource                   | Ascend 910; CPU 2.60GHz, 56 cores; Memory 314G; OS Euler2.8 |
| Uploaded Date              | 10/28/2020 (month/day/year)                                 |
| MindSpore Version          | 1.0.0                                                       |
| Dataset                    | 1200k images                                                |
| Training Parameters        | epoch=300, steps=5000, batch_size=256, lr=0.1               |
| Optimizer                  | Momentum                                                    |
| Loss Function              | Softmax Cross Entropy                                       |
| Outputs                    | probability                                                 |
| Loss                       | 2.0                                                         |
| Speed                      | 1pc: 152 ms/step; 8pcs: 171 ms/step                         |
| Total time                 | 8pcs: 8.8 hours                                             |
| Parameters (M)             | 13.0                                                        |
| Checkpoint for Fine tuning | 52M (.ckpt file)                                            |
| Scripts                    | [googlenet script](https://gitee.com/mindspore/mindspore/tree/r0.7/model_zoo/official/cv/googlenet) |
### Inference Performance

#### GoogleNet on CIFAR-10

| Parameters          | Ascend                      | GPU                         |
| ------------------- | --------------------------- | --------------------------- |
| Model Version       | Inception V1                | Inception V1                |
| Resource            | Ascend 910; OS Euler2.8     | GPU                         |
| Uploaded Date       | 10/28/2020 (month/day/year) | 10/28/2020 (month/day/year) |
| MindSpore Version   | 1.0.0                       | 1.0.0                       |
| Dataset             | CIFAR-10, 10,000 images     | CIFAR-10, 10,000 images     |
| batch_size          | 128                         | 128                         |
| Outputs             | probability                 | probability                 |
| Accuracy            | 1pc: 93.4%; 8pcs: 92.17%    | 1pc: 93%; 8pcs: 92.89%      |
| Model for inference | 21.50M (.onnx file)         |                             |

#### GoogleNet on 1200k images

| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Model Version       | Inception V1                |
| Resource            | Ascend 910; OS Euler2.8     |
| Uploaded Date       | 10/28/2020 (month/day/year) |
| MindSpore Version   | 1.0.0                       |
| Dataset             | 1200k images                |
| batch_size          | 256                         |
| Outputs             | probability                 |
| Accuracy            | 8pcs: 71.81%                |
## [How to use](#contents)

### Inference

If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). The following is a simple example of the steps:

- Running on Ascend

```python
# Imports added for completeness (assumed, based on the MindSpore 1.0 API);
# GoogleNet and the dataset helper come from this repository's src package.
import mindspore.nn as nn
from mindspore import context
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Set context
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)

# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)

# Define model
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
               cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
- Running on GPU:

```python
# Imports added for completeness (assumed, based on the MindSpore 1.0 API);
# GoogleNet and the dataset helper come from this repository's src package.
import mindspore.nn as nn
from mindspore import context
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Set context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)

# Define model
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
               cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

# Load pre-trained model
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
### Continue Training on the Pretrained Model

- running on Ascend

```python
# Imports added for completeness (assumed, based on the MindSpore 1.0 API);
# GoogleNet, create_dataset and the lr_steps helper come from this repository.
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()

# Define model
net = GoogleNet(num_classes=cfg.num_classes)
# Continue training if pre_trained is set to True
if cfg.pre_trained:
    param_dict = load_checkpoint(cfg.checkpoint_path)
    load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
              steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
                             config=config_ck)
loss_cb = LossMonitor()

# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
- running on GPU

```python
# Imports added for completeness (assumed, based on the MindSpore 1.0 API);
# GoogleNet, create_dataset and the lr_steps helper come from this repository.
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.communication.management import get_rank
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()

# Define model
net = GoogleNet(num_classes=cfg.num_classes)
# Continue training if pre_trained is set to True
if cfg.pre_trained:
    param_dict = load_checkpoint(cfg.checkpoint_path)
    load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
              steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10",
                             directory="./ckpt_" + str(get_rank()) + "/",
                             config=config_ck)
loss_cb = LossMonitor()

# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
### Transfer Learning

To be added.

# [Description of Random Situation](#contents)

In dataset.py, we set the seed inside the `create_dataset` function. We also use a random seed in train.py.
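As a hedged sketch of fixing the random sources mentioned above (the exact seed values and call sites are implementation details of dataset.py and train.py):

```python
import random

import numpy as np
import mindspore.dataset as ds

random.seed(1)           # Python-level randomness
np.random.seed(1)        # NumPy-level randomness
ds.config.set_seed(1)    # makes dataset shuffling/augmentation reproducible
```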
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).