You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Face Quality Assessment eval."""
  16. import os
  17. import warnings
  18. import argparse
  19. import numpy as np
  20. from tqdm import tqdm
  21. import cv2
  22. import mindspore.nn as nn
  23. from mindspore import Tensor
  24. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  25. from mindspore.ops import operations as P
  26. from mindspore import context
  27. from src.face_qa import FaceQABackbone
  28. warnings.filterwarnings('ignore')
  29. devid = int(os.getenv('DEVICE_ID'))
  30. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
  31. def softmax(x):
  32. """Compute softmax values for each sets of scores in x."""
  33. return np.exp(x) / np.sum(np.exp(x), axis=1)
  34. def get_md_output(out):
  35. '''get md output'''
  36. out_eul = out[0].asnumpy().astype(np.float32)[0]
  37. heatmap = out[1].asnumpy().astype(np.float32)[0]
  38. eulers = out_eul * 90
  39. kps_score_sum = 0
  40. kp_scores = list()
  41. kp_coord_ori = list()
  42. for i, _ in enumerate(heatmap):
  43. map_1 = heatmap[i].reshape(1, 48*48)
  44. map_1 = softmax(map_1)
  45. kp_coor = map_1.argmax()
  46. max_response = map_1.max()
  47. kp_scores.append(max_response)
  48. kps_score_sum += min(max_response, 0.25)
  49. kp_coor = int((kp_coor % 48) * 2.0), int((kp_coor / 48) * 2.0)
  50. kp_coord_ori.append(kp_coor)
  51. return kp_scores, kps_score_sum, kp_coord_ori, eulers, 1
  52. def read_gt(txt_path, x_length, y_length):
  53. '''read gt'''
  54. txt_line = open(txt_path).readline()
  55. eulers_txt = txt_line.strip().split(" ")[:3]
  56. kp_list = [[-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]]
  57. box_cur = txt_line.strip().split(" ")[3:]
  58. bndbox = []
  59. for index in range(len(box_cur) // 2):
  60. bndbox.append([box_cur[index * 2], box_cur[index * 2 + 1]])
  61. kp_id = -1
  62. for box in bndbox:
  63. kp_id = kp_id + 1
  64. x_coord = float(box[0])
  65. y_coord = float(box[1])
  66. if x_coord < 0 or y_coord < 0:
  67. continue
  68. kp_list[kp_id][0] = int(float(x_coord) / x_length * 96)
  69. kp_list[kp_id][1] = int(float(y_coord) / y_length * 96)
  70. return eulers_txt, kp_list
  71. def read_img(img_path):
  72. img_ori = cv2.imread(img_path)
  73. img = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB)
  74. img = cv2.resize(img, (96, 96))
  75. img = img.transpose(2, 0, 1)
  76. img = np.array([img]).astype(np.float32)/255.
  77. img = Tensor(img)
  78. return img, img_ori
  79. blur_soft = nn.Softmax(0)
  80. kps_soft = nn.Softmax(-1)
  81. reshape = P.Reshape()
  82. argmax = P.ArgMaxWithValue()
  83. def test_trains(args):
  84. '''test trains'''
  85. print('----eval----begin----')
  86. model_path = args.pretrained
  87. result_file = model_path.replace('.ckpt', '.txt')
  88. if os.path.exists(result_file):
  89. os.remove(result_file)
  90. epoch_result = open(result_file, 'a')
  91. epoch_result.write(model_path + '\n')
  92. network = FaceQABackbone()
  93. ckpt_path = model_path
  94. if os.path.isfile(ckpt_path):
  95. param_dict = load_checkpoint(ckpt_path)
  96. param_dict_new = {}
  97. for key, values in param_dict.items():
  98. if key.startswith('moments.'):
  99. continue
  100. elif key.startswith('network.'):
  101. param_dict_new[key[8:]] = values
  102. else:
  103. param_dict_new[key] = values
  104. load_param_into_net(network, param_dict_new)
  105. else:
  106. print('wrong model path')
  107. return 1
  108. path = args.eval_dir
  109. kp_error_all = [[], [], [], [], []]
  110. eulers_error_all = [[], [], []]
  111. kp_ipn = []
  112. file_list = os.listdir(path)
  113. for file_name in tqdm(file_list):
  114. if file_name.endswith('jpg'):
  115. img_path = os.path.join(path, file_name)
  116. img, img_ori = read_img(img_path)
  117. txt_path = img_path.replace('jpg', 'txt')
  118. if os.path.exists(txt_path):
  119. euler_kps_do = True
  120. x_length = img_ori.shape[1]
  121. y_length = img_ori.shape[0]
  122. eulers_gt, kp_list = read_gt(txt_path, x_length, y_length)
  123. else:
  124. euler_kps_do = False
  125. continue
  126. out = network(img)
  127. _, _, kp_coord_ori, eulers_ori, _ = get_md_output(out)
  128. if euler_kps_do:
  129. eulgt = list(eulers_gt)
  130. for euler_id, _ in enumerate(eulers_ori):
  131. eulori = eulers_ori[euler_id]
  132. eulers_error_all[euler_id].append(abs(eulori-float(eulgt[euler_id])))
  133. eye01 = kp_list[0]
  134. eye02 = kp_list[1]
  135. eye_dis = 1
  136. cur_flag = True
  137. if eye01[0] < 0 or eye01[1] < 0 or eye02[0] < 0 or eye02[1] < 0:
  138. cur_flag = False
  139. else:
  140. eye_dis = np.sqrt(np.square(abs(eye01[0]-eye02[0]))+np.square(abs(eye01[1]-eye02[1])))
  141. cur_error_list = []
  142. for i in range(5):
  143. kp_coord_gt = kp_list[i]
  144. kp_coord_model = kp_coord_ori[i]
  145. if kp_coord_gt[0] != -1:
  146. dis = np.sqrt(np.square(
  147. kp_coord_gt[0] - kp_coord_model[0]) + np.square(kp_coord_gt[1] - kp_coord_model[1]))
  148. kp_error_all[i].append(dis)
  149. cur_error_list.append(dis)
  150. if cur_flag:
  151. kp_ipn.append(sum(cur_error_list)/len(cur_error_list)/eye_dis)
  152. kp_ave_error = []
  153. for kps, _ in enumerate(kp_error_all):
  154. kp_ave_error.append("%.3f" % (sum(kp_error_all[kps])/len(kp_error_all[kps])))
  155. euler_ave_error = []
  156. elur_mae = []
  157. for eulers, _ in enumerate(eulers_error_all):
  158. euler_ave_error.append("%.3f" % (sum(eulers_error_all[eulers])/len(eulers_error_all[eulers])))
  159. elur_mae.append((sum(eulers_error_all[eulers])/len(eulers_error_all[eulers])))
  160. print(r'5 keypoints average err:'+str(kp_ave_error))
  161. print(r'3 eulers average err:'+str(euler_ave_error))
  162. print('IPN of 5 keypoints:'+str(sum(kp_ipn)/len(kp_ipn)*100))
  163. print('MAE of elur:'+str(sum(elur_mae)/len(elur_mae)))
  164. epoch_result.write(str(sum(kp_ipn)/len(kp_ipn)*100)+'\t'+str(sum(elur_mae)/len(elur_mae))+'\t'
  165. + str(kp_ave_error)+'\t'+str(euler_ave_error)+'\n')
  166. print('----eval----end----')
  167. return 0
  168. if __name__ == "__main__":
  169. parser = argparse.ArgumentParser(description='Face Quality Assessment')
  170. parser.add_argument('--eval_dir', type=str, default='', help='eval image dir, e.g. /home/test')
  171. parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
  172. arg = parser.parse_args()
  173. test_trains(arg)