You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

train.py 11 kB

8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. from core.image_reader import TrainImageReader
  2. import datetime
  3. import os
  4. from core.models import PNet,RNet,ONet,LossFn
  5. import torch
  6. from torch.autograd import Variable
  7. import core.image_tools as image_tools
  8. def compute_accuracy(prob_cls, gt_cls):
  9. prob_cls = torch.squeeze(prob_cls)
  10. gt_cls = torch.squeeze(gt_cls)
  11. #we only need the detection which >= 0
  12. mask = torch.ge(gt_cls,0)
  13. #get valid element
  14. valid_gt_cls = torch.masked_select(gt_cls,mask)
  15. valid_prob_cls = torch.masked_select(prob_cls,mask)
  16. size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
  17. prob_ones = torch.ge(valid_prob_cls,0.6).float()
  18. right_ones = torch.eq(prob_ones,valid_gt_cls).float()
  19. return torch.div(torch.mul(torch.sum(right_ones),float(1.0)),float(size))
  20. def train_pnet(model_store_path, end_epoch,imdb,
  21. batch_size,frequent=50,base_lr=0.01,use_cuda=True):
  22. if not os.path.exists(model_store_path):
  23. os.makedirs(model_store_path)
  24. lossfn = LossFn()
  25. net = PNet(is_train=True, use_cuda=use_cuda)
  26. net.train()
  27. if use_cuda:
  28. net.cuda()
  29. optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
  30. train_data=TrainImageReader(imdb,12,batch_size,shuffle=True)
  31. for cur_epoch in range(1,end_epoch+1):
  32. train_data.reset()
  33. accuracy_list=[]
  34. cls_loss_list=[]
  35. bbox_loss_list=[]
  36. # landmark_loss_list=[]
  37. for batch_idx,(image,(gt_label,gt_bbox,gt_landmark))in enumerate(train_data):
  38. im_tensor = [ image_tools.convert_image_to_tensor(image[i,:,:,:]) for i in range(image.shape[0]) ]
  39. im_tensor = torch.stack(im_tensor)
  40. im_tensor = Variable(im_tensor)
  41. gt_label = Variable(torch.from_numpy(gt_label).float())
  42. gt_bbox = Variable(torch.from_numpy(gt_bbox).float())
  43. # gt_landmark = Variable(torch.from_numpy(gt_landmark).float())
  44. if use_cuda:
  45. im_tensor = im_tensor.cuda()
  46. gt_label = gt_label.cuda()
  47. gt_bbox = gt_bbox.cuda()
  48. # gt_landmark = gt_landmark.cuda()
  49. cls_pred, box_offset_pred = net(im_tensor)
  50. # all_loss, cls_loss, offset_loss = lossfn.loss(gt_label=label_y,gt_offset=bbox_y, pred_label=cls_pred, pred_offset=box_offset_pred)
  51. cls_loss = lossfn.cls_loss(gt_label,cls_pred)
  52. box_offset_loss = lossfn.box_loss(gt_label,gt_bbox,box_offset_pred)
  53. # landmark_loss = lossfn.landmark_loss(gt_label,gt_landmark,landmark_offset_pred)
  54. all_loss = cls_loss*1.0+box_offset_loss*0.5
  55. if batch_idx%frequent==0:
  56. accuracy=compute_accuracy(cls_pred,gt_label)
  57. show1 = accuracy.data.tolist()[0]
  58. show2 = cls_loss.data.tolist()[0]
  59. show3 = box_offset_loss.data.tolist()[0]
  60. show5 = all_loss.data.tolist()[0]
  61. print "%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, all_loss: %s, lr:%s "%(datetime.datetime.now(),cur_epoch,batch_idx, show1,show2,show3,show5,base_lr)
  62. accuracy_list.append(accuracy)
  63. cls_loss_list.append(cls_loss)
  64. bbox_loss_list.append(box_offset_loss)
  65. optimizer.zero_grad()
  66. all_loss.backward()
  67. optimizer.step()
  68. accuracy_avg = torch.mean(torch.cat(accuracy_list))
  69. cls_loss_avg = torch.mean(torch.cat(cls_loss_list))
  70. bbox_loss_avg = torch.mean(torch.cat(bbox_loss_list))
  71. # landmark_loss_avg = torch.mean(torch.cat(landmark_loss_list))
  72. show6 = accuracy_avg.data.tolist()[0]
  73. show7 = cls_loss_avg.data.tolist()[0]
  74. show8 = bbox_loss_avg.data.tolist()[0]
  75. print "Epoch: %d, accuracy: %s, cls loss: %s, bbox loss: %s" % (cur_epoch, show6, show7, show8)
  76. torch.save(net.state_dict(), os.path.join(model_store_path,"pnet_epoch_%d.pt" % cur_epoch))
  77. torch.save(net, os.path.join(model_store_path,"pnet_epoch_model_%d.pkl" % cur_epoch))
  78. def train_rnet(model_store_path, end_epoch,imdb,
  79. batch_size,frequent=50,base_lr=0.01,use_cuda=True):
  80. if not os.path.exists(model_store_path):
  81. os.makedirs(model_store_path)
  82. lossfn = LossFn()
  83. net = RNet(is_train=True, use_cuda=use_cuda)
  84. net.train()
  85. if use_cuda:
  86. net.cuda()
  87. optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
  88. train_data=TrainImageReader(imdb,24,batch_size,shuffle=True)
  89. for cur_epoch in range(1,end_epoch+1):
  90. train_data.reset()
  91. accuracy_list=[]
  92. cls_loss_list=[]
  93. bbox_loss_list=[]
  94. landmark_loss_list=[]
  95. for batch_idx,(image,(gt_label,gt_bbox,gt_landmark))in enumerate(train_data):
  96. im_tensor = [ image_tools.convert_image_to_tensor(image[i,:,:,:]) for i in range(image.shape[0]) ]
  97. im_tensor = torch.stack(im_tensor)
  98. im_tensor = Variable(im_tensor)
  99. gt_label = Variable(torch.from_numpy(gt_label).float())
  100. gt_bbox = Variable(torch.from_numpy(gt_bbox).float())
  101. gt_landmark = Variable(torch.from_numpy(gt_landmark).float())
  102. if use_cuda:
  103. im_tensor = im_tensor.cuda()
  104. gt_label = gt_label.cuda()
  105. gt_bbox = gt_bbox.cuda()
  106. gt_landmark = gt_landmark.cuda()
  107. cls_pred, box_offset_pred = net(im_tensor)
  108. # all_loss, cls_loss, offset_loss = lossfn.loss(gt_label=label_y,gt_offset=bbox_y, pred_label=cls_pred, pred_offset=box_offset_pred)
  109. cls_loss = lossfn.cls_loss(gt_label,cls_pred)
  110. box_offset_loss = lossfn.box_loss(gt_label,gt_bbox,box_offset_pred)
  111. # landmark_loss = lossfn.landmark_loss(gt_label,gt_landmark,landmark_offset_pred)
  112. all_loss = cls_loss*1.0+box_offset_loss*0.5
  113. if batch_idx%frequent==0:
  114. accuracy=compute_accuracy(cls_pred,gt_label)
  115. show1 = accuracy.data.tolist()[0]
  116. show2 = cls_loss.data.tolist()[0]
  117. show3 = box_offset_loss.data.tolist()[0]
  118. # show4 = landmark_loss.data.tolist()[0]
  119. show5 = all_loss.data.tolist()[0]
  120. print "%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, all_loss: %s, lr:%s "%(datetime.datetime.now(), cur_epoch, batch_idx, show1, show2, show3, show5, base_lr)
  121. accuracy_list.append(accuracy)
  122. cls_loss_list.append(cls_loss)
  123. bbox_loss_list.append(box_offset_loss)
  124. # landmark_loss_list.append(landmark_loss)
  125. optimizer.zero_grad()
  126. all_loss.backward()
  127. optimizer.step()
  128. accuracy_avg = torch.mean(torch.cat(accuracy_list))
  129. cls_loss_avg = torch.mean(torch.cat(cls_loss_list))
  130. bbox_loss_avg = torch.mean(torch.cat(bbox_loss_list))
  131. # landmark_loss_avg = torch.mean(torch.cat(landmark_loss_list))
  132. show6 = accuracy_avg.data.tolist()[0]
  133. show7 = cls_loss_avg.data.tolist()[0]
  134. show8 = bbox_loss_avg.data.tolist()[0]
  135. # show9 = landmark_loss_avg.data.tolist()[0]
  136. print "Epoch: %d, accuracy: %s, cls loss: %s, bbox loss: %s" % (cur_epoch, show6, show7, show8)
  137. torch.save(net.state_dict(), os.path.join(model_store_path,"rnet_epoch_%d.pt" % cur_epoch))
  138. torch.save(net, os.path.join(model_store_path,"rnet_epoch_model_%d.pkl" % cur_epoch))
  139. def train_onet(model_store_path, end_epoch,imdb,
  140. batch_size,frequent=50,base_lr=0.01,use_cuda=True):
  141. if not os.path.exists(model_store_path):
  142. os.makedirs(model_store_path)
  143. lossfn = LossFn()
  144. net = ONet(is_train=True)
  145. net.train()
  146. if use_cuda:
  147. net.cuda()
  148. optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
  149. train_data=TrainImageReader(imdb,48,batch_size,shuffle=True)
  150. for cur_epoch in range(1,end_epoch+1):
  151. train_data.reset()
  152. accuracy_list=[]
  153. cls_loss_list=[]
  154. bbox_loss_list=[]
  155. landmark_loss_list=[]
  156. for batch_idx,(image,(gt_label,gt_bbox,gt_landmark))in enumerate(train_data):
  157. im_tensor = [ image_tools.convert_image_to_tensor(image[i,:,:,:]) for i in range(image.shape[0]) ]
  158. im_tensor = torch.stack(im_tensor)
  159. im_tensor = Variable(im_tensor)
  160. gt_label = Variable(torch.from_numpy(gt_label).float())
  161. gt_bbox = Variable(torch.from_numpy(gt_bbox).float())
  162. gt_landmark = Variable(torch.from_numpy(gt_landmark).float())
  163. if use_cuda:
  164. im_tensor = im_tensor.cuda()
  165. gt_label = gt_label.cuda()
  166. gt_bbox = gt_bbox.cuda()
  167. gt_landmark = gt_landmark.cuda()
  168. cls_pred, box_offset_pred, landmark_offset_pred = net(im_tensor)
  169. # all_loss, cls_loss, offset_loss = lossfn.loss(gt_label=label_y,gt_offset=bbox_y, pred_label=cls_pred, pred_offset=box_offset_pred)
  170. cls_loss = lossfn.cls_loss(gt_label,cls_pred)
  171. box_offset_loss = lossfn.box_loss(gt_label,gt_bbox,box_offset_pred)
  172. landmark_loss = lossfn.landmark_loss(gt_label,gt_landmark,landmark_offset_pred)
  173. all_loss = cls_loss*0.8+box_offset_loss*0.6+landmark_loss*1.5
  174. if batch_idx%frequent==0:
  175. accuracy=compute_accuracy(cls_pred,gt_label)
  176. show1 = accuracy.data.tolist()[0]
  177. show2 = cls_loss.data.tolist()[0]
  178. show3 = box_offset_loss.data.tolist()[0]
  179. show4 = landmark_loss.data.tolist()[0]
  180. show5 = all_loss.data.tolist()[0]
  181. print "%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, landmark loss: %s, all_loss: %s, lr:%s "%(datetime.datetime.now(),cur_epoch,batch_idx, show1,show2,show3,show4,show5,base_lr)
  182. accuracy_list.append(accuracy)
  183. cls_loss_list.append(cls_loss)
  184. bbox_loss_list.append(box_offset_loss)
  185. landmark_loss_list.append(landmark_loss)
  186. optimizer.zero_grad()
  187. all_loss.backward()
  188. optimizer.step()
  189. accuracy_avg = torch.mean(torch.cat(accuracy_list))
  190. cls_loss_avg = torch.mean(torch.cat(cls_loss_list))
  191. bbox_loss_avg = torch.mean(torch.cat(bbox_loss_list))
  192. landmark_loss_avg = torch.mean(torch.cat(landmark_loss_list))
  193. show6 = accuracy_avg.data.tolist()[0]
  194. show7 = cls_loss_avg.data.tolist()[0]
  195. show8 = bbox_loss_avg.data.tolist()[0]
  196. show9 = landmark_loss_avg.data.tolist()[0]
  197. print "Epoch: %d, accuracy: %s, cls loss: %s, bbox loss: %s, landmark loss: %s " % (cur_epoch, show6, show7, show8, show9)
  198. torch.save(net.state_dict(), os.path.join(model_store_path,"onet_epoch_%d.pt" % cur_epoch))
  199. torch.save(net, os.path.join(model_store_path,"onet_epoch_model_%d.pkl" % cur_epoch))

开源的深度学习人脸检测和人脸识别系统。所有功能都采用 pytorch 框架开发。pytorch是一个由facebook开发的深度学习框架,它包含了一些比较有趣的高级特性,例如自动求导,动态构图等。DFace天然的继承了这些优点,使得它的训练过程可以更加简单方便,并且实现的代码可以更加清晰易懂

Contributors (1)