# LeNet Quantization Aware Training
## Description
Training LeNet on the MNIST dataset in MindSpore with quantization aware training.
This is a simple, basic tutorial for constructing a network in MindSpore with quantization awareness.
In this tutorial, you will:
1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
2. Fine-tune the fusion model with the quantization aware training auto network converter API `convert_quant_network`, and export a quantization aware model checkpoint file after the network converges.
3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend.
4. See the accuracy persist in the inference backend with a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples.
## Train fusion model
### Install
Install MindSpore for the Ascend or GPU device from [MindSpore](https://www.mindspore.cn/install/en).
```bash
pip uninstall -y mindspore-ascend
pip uninstall -y mindspore-gpu
pip install mindspore-ascend.whl
```
You will then see output like the following:
```bash
>>> Found existing installation: mindspore-ascend
>>> Uninstalling mindspore-ascend:
>>> Successfully uninstalled mindspore-ascend.
```
### Prepare Dataset
Download the MNIST dataset; the directory structure is as follows:
```
└─MNIST_Data
    ├─test
    │   t10k-images.idx3-ubyte
    │   t10k-labels.idx1-ubyte
    └─train
        train-images.idx3-ubyte
        train-labels.idx1-ubyte
```
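If you still need to fetch the files, a minimal download sketch follows (the mirror URL and the final file names are assumptions; rename the extracted files to match the layout above if your setup expects the dotted form):
```python
import gzip
import os
import shutil
import urllib.request

# Hypothetical download helper; the mirror URL is an assumption.
BASE_URL = "http://yann.lecun.com/exdb/mnist/"
FILES = {
    "train": ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz"],
    "test": ["t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"],
}

for split, names in FILES.items():
    out_dir = os.path.join("MNIST_Data", split)
    os.makedirs(out_dir, exist_ok=True)
    for name in names:
        gz_path = os.path.join(out_dir, name)
        urllib.request.urlretrieve(BASE_URL + name, gz_path)
        # decompress *.gz to the raw idx file and drop the archive
        with gzip.open(gz_path, "rb") as src, open(gz_path[:-3], "wb") as dst:
            shutil.copyfileobj(src, dst)
        os.remove(gz_path)
```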
### Define fusion model
Define a MindSpore fusion model using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
```python
import mindspore.nn as nn

class LeNet5(nn.Cell):
    """
    Define the LeNet fusion model.
    """
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        # change `nn.Conv2d` to `nn.Conv2dBnAct`;
        # pad_mode='valid' keeps the classic LeNet feature-map sizes,
        # so fc1 receives 16 * 5 * 5 features from a 32x32 input
        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        # change `nn.Dense` to `nn.DenseBnAct`
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        self.fc3 = nn.DenseBnAct(84, self.num_class)
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
```
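As a quick sanity check (not part of the original scripts), you can push a random MNIST-sized batch through the fusion network; PYNATIVE_MODE is assumed here so the cell can be called eagerly:
```python
import numpy as np
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

# Random 32x32 single-channel batch, the input size LeNet expects.
net = LeNet5(num_class=10, channel=1)
x = Tensor(np.random.rand(1, 1, 32, 32).astype(np.float32))
print(net(x).shape)  # expected: (1, 10)
```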
Load the MNIST training dataset.
```python
ds_train = create_dataset(os.path.join(args.data_path, "train"),
                          cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size()
```
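`create_dataset` is defined in the repository's dataset helper; for reference, a minimal sketch might look like the following (the import paths and the resize/rescale values vary across MindSpore versions and are assumptions here):
```python
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore import dtype as mstype

def create_dataset(data_path, batch_size=32, repeat_size=1):
    """Build an MNIST pipeline: resize to 32x32, rescale, NCHW, batch."""
    mnist_ds = ds.MnistDataset(data_path)
    type_cast_op = C.TypeCast(mstype.int32)
    resize_op = CV.Resize((32, 32))            # LeNet expects 32x32 inputs
    rescale_op = CV.Rescale(1.0 / 255.0, 0.0)  # bytes -> [0, 1] floats
    hwc2chw_op = CV.HWC2CHW()                  # channels-first for Conv2d
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label")
    mnist_ds = mnist_ds.map(operations=[resize_op, rescale_op, hwc2chw_op],
                            input_columns="image")
    mnist_ds = mnist_ds.shuffle(buffer_size=10000)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    return mnist_ds.repeat(repeat_size)
```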
### Train model
Load the LeNet fusion network, then build the training network using the `nn.SoftmaxCrossEntropyWithLogits` loss and the `nn.Momentum` optimizer.
```python
# Define the network
network = LeNet5Fusion(cfg.num_classes)
# Define the loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# Define the optimizer
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# Define the model using the loss and the optimizer, plus training callbacks.
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
```
Now we can start training.
```python
model.train(cfg['epoch_size'], ds_train,
            callbacks=[time_cb, ckpoint_cb, LossMonitor()],
            dataset_sink_mode=args.dataset_sink_mode)
```
Training prints the loss value of each step:
```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ...
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```
Alternatively, you can just run this command instead:
```bash
python train.py --data_path MNIST_Data --device_target Ascend
```
### Evaluate fusion model
After training completes, a fusion model checkpoint file such as `checkpoint_lenet.ckpt` is saved. We can then evaluate the fusion model:
```bash
python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
```
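Internally, `eval.py` does roughly the following (a sketch; `network`, `model`, `cfg`, and `args` are the same names used in the training section):
```python
import os
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Load the trained fusion checkpoint and measure accuracy on the test split.
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("Accuracy:", acc)
```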
The top-1 accuracy is printed to the shell:
```bash
>>> Accuracy: 98.53.
```
## Train quantization aware model
### Define quantization aware model
Apply quantization aware training to the whole model: "fake quant" ops are inserted into every layer, so all layers are now prepared for simulated quantization.
Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8).
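Conceptually, each fake quant op quantizes a float32 tensor to `num_bits` integers and immediately dequantizes it, so the network trains against quantization error while staying in float32. A schematic sketch of the math (not MindSpore's exact implementation):
```python
import numpy as np

def fake_quant(x, x_min, x_max, num_bits=8):
    """Simulated (asymmetric) quantization: quantize, then dequantize.

    Schematic only; MindSpore's fake quant ops also track and update
    the min/max statistics during training.
    """
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = round(qmin - x_min / scale)
    q = np.clip(np.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale  # back to float32
```
In MindSpore, the conversion is handled by the API: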
```python
# define the fusion network
network = LeNet5Fusion(cfg.num_classes)
# load the fusion model checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
# convert the fusion network to a quantization aware network
network = quant.convert_quant_network(network)
```
### Load checkpoint
After converting to the quantization aware network, set up checkpoint saving and rebuild the model:
```python
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
```
### Train quantization aware model
Fine-tune the quantization aware model with this command:
```bash
python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
```
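Internally, `train_quant.py` fine-tunes with the same training loop as the fusion model (a sketch, assuming the `model`, callbacks, and `ds_train` built above):
```python
# Fine-tune the quantization aware network; the fake quant ops calibrate
# their ranges while the float32 weights continue training.
model.train(cfg['epoch_size'], ds_train,
            callbacks=[time_cb, ckpoint_cb, LossMonitor()],
            dataset_sink_mode=args.dataset_sink_mode)
```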
Training prints the loss value of each step:
```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ...
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```
### Evaluate quantization aware model
The evaluation procedure for the quantization aware model differs from the normal one. Because the checkpoint was created by the quantization aware model, we need to load the fusion model checkpoint before converting the fusion model into the quantization aware model.
```python
# define the fusion network
network = LeNet5Fusion(cfg.num_classes)
# load the quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
# convert the fusion network to a quantization aware network
network = quant.convert_quant_network(network)
```
Alternatively, you can just run this command instead:
```bash
python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
```
The top-1 accuracy is printed to the shell:
```bash
>>> Accuracy: 98.54.
```
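To obtain the actually quantized model for the Ascend inference backend (step 3 in the overview), the trained network can then be exported. A minimal sketch, assuming MindSpore's standard `export` API; the supported `file_format` values depend on your MindSpore version:
```python
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export

# Assumption: a 1x1x32x32 float32 tensor matches the LeNet input shape.
# Pick the file_format ('GEIR', 'AIR', or 'MINDIR') that your MindSpore
# version and Ascend toolchain support.
inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
export(network, inputs, file_name="lenet_quant", file_format="MINDIR")
```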
## Note
Here are some optional parameters:
```bash
--device_target {Ascend,GPU}
    device where the code will be implemented (default: Ascend)
--data_path DATA_PATH
    path where the dataset is saved
--dataset_sink_mode DATASET_SINK_MODE
    whether dataset sink mode is enabled (True or False)
```
You can run `python train.py -h` or `python eval.py -h` to get more information.
We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.