eval.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import json
import os
import argparse
import warnings
import sys

import numpy as np
import cv2
from scipy.ndimage.filters import gaussian_filter
from tqdm import tqdm
from pycocotools.coco import COCO as LoadAnn
from pycocotools.cocoeval import COCOeval as MapEval

from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import dtype as mstype

from src.config import params, JointType
from src.openposenet import OpenPoseNet
from src.dataset import valdata

warnings.filterwarnings("ignore")

devid = int(os.getenv('DEVICE_ID', '0'))  # fall back to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend", save_graphs=False, device_id=devid)

show_gt = 0  # set to 1 to dump intermediate PAF/heatmap images in detect()

parser = argparse.ArgumentParser('mindspore openpose_net test')
parser.add_argument('--model_path', type=str, default='./0-33_170000.ckpt', help='path of testing model')
parser.add_argument('--imgpath_val', type=str, default='./dataset/coco/val2017', help='path of testing imgs')
parser.add_argument('--ann', type=str, default='./dataset/coco/annotations/person_keypoints_val2017.json',
                    help='path of annotations')
parser.add_argument('--output_path', type=str, default='./output_img', help='path of output images')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args, _ = parser.parse_known_args()
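
# A typical invocation, assuming the default paths above actually exist on disk
# (adjust them to your own layout; this sketch is not from the original file):
#
#     DEVICE_ID=0 python eval.py --model_path ./0-33_170000.ckpt \
#         --imgpath_val ./dataset/coco/val2017 \
#         --ann ./dataset/coco/annotations/person_keypoints_val2017.json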

def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):
    """Run the COCO keypoint evaluation on a result file and return the named AP/AR stats."""
    class NullWriter():
        def write(self, arg):
            pass

    if silence:
        nullwrite = NullWriter()
        oldstdout = sys.stdout
        sys.stdout = nullwrite  # disable output

    Gt = LoadAnn(ann_file)
    Dt = Gt.loadRes(res_file)

    Eval = MapEval(Gt, Dt, ann_type)
    Eval.evaluate()
    Eval.accumulate()
    Eval.summarize()

    if silence:
        sys.stdout = oldstdout  # enable output

    stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)',
                   'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
    info_str = {}
    for ind, name in enumerate(stats_names):
        info_str[name] = Eval.stats[ind]
    return info_str
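
# Minimal usage sketch (hypothetical paths; assumes a result JSON produced by a
# previous eval run, as written by val() below):
#
#     res = evaluate_mAP('./output_img/eval_result.json',
#                        ann_file='./dataset/coco/annotations/person_keypoints_val2017.json')
#     print(res['AP'], res['AR'])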

def load_model(test_net, model_path):
    assert os.path.exists(model_path)
    param_dict = load_checkpoint(model_path)
    param_dict_new = {}
    for key, values in param_dict.items():
        if key.startswith('moment'):
            continue  # skip optimizer state saved in the training checkpoint
        elif key.startswith('network'):
            param_dict_new[key[8:]] = values  # strip the 'network.' prefix added by the train wrapper
    load_param_into_net(test_net, param_dict_new)

def preprocess(img):
    # scale pixels to [-0.5, 0.5] and convert HWC -> NCHW with a batch dimension
    x_data = img.astype('f')
    x_data /= 255
    x_data -= 0.5
    x_data = x_data.transpose(2, 0, 1)[None]
    return x_data

def getImgsPath(img_dir_path):
    filepaths = []
    dirpaths = []
    pathName = img_dir_path
    for root, dirs, files in os.walk(pathName):
        for file in files:
            file_path = os.path.join(root, file)
            filepaths.append(file_path)
        for d in dirs:
            dir_path = os.path.join(root, d)
            dirpaths.append(dir_path)
    return filepaths

def compute_optimal_size(orig_img, img_size, stride=8):
    """Scale the shorter side to img_size, rounding the longer side up to a multiple of stride."""
    orig_img_h, orig_img_w, _ = orig_img.shape
    aspect = orig_img_h / orig_img_w
    if orig_img_h < orig_img_w:
        img_h = img_size
        img_w = np.round(img_size / aspect).astype(int)
        surplus = img_w % stride
        if surplus != 0:
            img_w += stride - surplus
    else:
        img_w = img_size
        img_h = np.round(img_size * aspect).astype(int)
        surplus = img_h % stride
        if surplus != 0:
            img_h += stride - surplus
    return (img_w, img_h)
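
# Worked example (values chosen for illustration): a 480x640 (h x w) image with
# img_size=368 gives aspect = 0.75; since h < w, img_h = 368 and
# img_w = round(368 / 0.75) = 491, which is then padded up to a multiple of the
# stride: 491 % 8 = 3, so img_w = 491 + 5 = 496, and the function returns (496, 368).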

def compute_peaks_from_heatmaps(heatmaps):
    heatmaps = heatmaps[:-1]  # drop the background channel
    all_peaks = []
    peak_counter = 0
    for i, heatmap in enumerate(heatmaps):
        heatmap = gaussian_filter(heatmap, sigma=params['gaussian_sigma'])
        map_left = np.zeros(heatmap.shape)
        map_right = np.zeros(heatmap.shape)
        map_top = np.zeros(heatmap.shape)
        map_bottom = np.zeros(heatmap.shape)
        map_left[1:, :] = heatmap[:-1, :]
        map_right[:-1, :] = heatmap[1:, :]
        map_top[:, 1:] = heatmap[:, :-1]
        map_bottom[:, :-1] = heatmap[:, 1:]
        # a peak is a pixel above threshold that dominates its 4-neighborhood
        peaks_binary = np.logical_and.reduce((
            heatmap > params['heatmap_peak_thresh'],
            heatmap > map_left,
            heatmap > map_right,
            heatmap > map_top,
            heatmap > map_bottom,
        ))
        peaks = zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])  # (x, y) order
        peaks_with_score = [(i,) + peak_pos + (heatmap[peak_pos[1], peak_pos[0]],) for peak_pos in peaks]
        peaks_id = range(peak_counter, peak_counter + len(peaks_with_score))
        peaks_with_score_and_id = [peaks_with_score[i] + (peaks_id[i],) for i in range(len(peaks_id))]
        peak_counter += len(peaks_with_score_and_id)
        all_peaks.append(peaks_with_score_and_id)
    all_peaks = np.array([peak for peaks_each_category in all_peaks for peak in peaks_each_category])
    return all_peaks
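
# Each row of the returned array is (joint_type, x, y, score, peak_id); the
# peak_id column is how compute_connections and grouping_key_points refer back
# to individual peaks.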

def compute_candidate_connections(paf, cand_a, cand_b, img_len, params_):
    candidate_connections = []
    for joint_a in cand_a:
        for joint_b in cand_b:
            vector = joint_b[:2] - joint_a[:2]
            norm = np.linalg.norm(vector)
            if norm == 0:
                continue
            # sample the PAF at n_integ_points along the segment joining the two joints
            ys = np.linspace(joint_a[1], joint_b[1], num=params_['n_integ_points'])
            xs = np.linspace(joint_a[0], joint_b[0], num=params_['n_integ_points'])
            integ_points = np.stack([ys, xs]).T.round().astype('i')
            paf_in_edge = np.hstack([paf[0][np.hsplit(integ_points, 2)], paf[1][np.hsplit(integ_points, 2)]])
            unit_vector = vector / norm
            inner_products = np.dot(paf_in_edge, unit_vector)
            integ_value = inner_products.sum() / len(inner_products)
            # penalize connections that are long relative to the image size
            integ_value_with_dist_prior = integ_value + min(
                params_['limb_length_ratio'] * img_len / norm - params_['length_penalty_value'], 0)
            n_valid_points = sum(inner_products > params_['inner_product_thresh'])
            if n_valid_points > params_['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
                candidate_connections.append([int(joint_a[3]), int(joint_b[3]), integ_value_with_dist_prior])
    candidate_connections = sorted(candidate_connections, key=lambda x: x[2], reverse=True)
    return candidate_connections
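
# The score computed above is the discretized PAF line integral from the
# OpenPose paper: E = (1/N) * sum_u paf(p_u) . v_hat, where the p_u are the N
# sampled points along the candidate limb and v_hat is the unit vector from
# joint_a to joint_b, plus the non-positive distance prior
# min(limb_length_ratio * img_len / norm - length_penalty_value, 0).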

def compute_connections(pafs, all_peaks, img_len, params_):
    all_connections = []
    for i in range(len(params_['limbs_point'])):
        paf_index = [i * 2, i * 2 + 1]
        paf = pafs[paf_index]  # shape: (2, map_h, map_w), e.g. (2, 320, 320)
        limb_point = params_['limbs_point'][i]  # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
        cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:]
        cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:]
        if cand_a.shape[0] > 0 and cand_b.shape[0] > 0:
            candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, params_)
            # greedily keep the highest-scoring connections, using each peak at most once
            connections = np.zeros((0, 3))
            for index_a, index_b, score in candidate_connections:
                if index_a not in connections[:, 0] and index_b not in connections[:, 1]:
                    connections = np.vstack([connections, [index_a, index_b, score]])
                    if len(connections) >= min(len(cand_a), len(cand_b)):
                        break
            all_connections.append(connections)
        else:
            all_connections.append(np.zeros((0, 3)))
    return all_connections

def grouping_key_points(all_connections, candidate_peaks, params_):
    subsets = -1 * np.ones((0, 20))
    for l, connections in enumerate(all_connections):
        joint_a, joint_b = params_['limbs_point'][l]
        for ind_a, ind_b, score in connections[:, :3]:
            ind_a, ind_b = int(ind_a), int(ind_b)
            joint_found_cnt = 0
            joint_found_subset_index = [-1, -1]
            for subset_ind, subset in enumerate(subsets):
                if subset[joint_a] == ind_a or subset[joint_b] == ind_b:
                    joint_found_subset_index[joint_found_cnt] = subset_ind
                    joint_found_cnt += 1

            if joint_found_cnt == 1:
                # one endpoint already belongs to a person; attach the other endpoint
                found_subset = subsets[joint_found_subset_index[0]]
                if found_subset[joint_b] != ind_b:
                    found_subset[joint_b] = ind_b
                    found_subset[-1] += 1  # increment joint count
                    found_subset[-2] += candidate_peaks[ind_b, 3] + score

            elif joint_found_cnt == 2:
                # both endpoints already belong to subsets
                found_subset_1 = subsets[joint_found_subset_index[0]]
                found_subset_2 = subsets[joint_found_subset_index[1]]
                membership = ((found_subset_1 >= 0).astype(int) + (found_subset_2 >= 0).astype(int))[:-2]
                if not np.any(membership == 2):  # merge two subsets when no duplication
                    found_subset_1[:-2] += found_subset_2[:-2] + 1  # default is -1
                    found_subset_1[-2:] += found_subset_2[-2:]
                    found_subset_1[-2] += score
                    subsets = np.delete(subsets, joint_found_subset_index[1], axis=0)
                else:
                    if found_subset_1[joint_a] == -1:
                        found_subset_1[joint_a] = ind_a
                        found_subset_1[-1] += 1
                        found_subset_1[-2] += candidate_peaks[ind_a, 3] + score
                    elif found_subset_1[joint_b] == -1:
                        found_subset_1[joint_b] = ind_b
                        found_subset_1[-1] += 1
                        found_subset_1[-2] += candidate_peaks[ind_b, 3] + score
                    if found_subset_2[joint_a] == -1:
                        found_subset_2[joint_a] = ind_a
                        found_subset_2[-1] += 1
                        found_subset_2[-2] += candidate_peaks[ind_a, 3] + score
                    elif found_subset_2[joint_b] == -1:
                        found_subset_2[joint_b] = ind_b
                        found_subset_2[-1] += 1
                        found_subset_2[-2] += candidate_peaks[ind_b, 3] + score

            elif joint_found_cnt == 0 and l != 9 and l != 13:
                # neither endpoint is assigned yet: start a new person
                # (limbs 9 and 13 are the redundant ear-shoulder connections)
                row = -1 * np.ones(20)
                row[joint_a] = ind_a
                row[joint_b] = ind_b
                row[-1] = 2
                row[-2] = sum(candidate_peaks[[ind_a, ind_b], 3]) + score
                subsets = np.vstack([subsets, row])

            elif joint_found_cnt >= 3:
                pass

    # delete low score subsets
    keep = np.logical_and(subsets[:, -1] >= params_['n_subset_limbs_thresh'],
                          subsets[:, -2] / subsets[:, -1] >= params_['subset_score_thresh'])
    subsets = subsets[keep]
    return subsets
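
# Each subset row has 20 columns: indices 0-17 hold the peak_id assigned to the
# corresponding joint type (-1 if that joint was not found), column -2
# accumulates the total peak/connection score, and column -1 counts the joints
# in the subset.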

def subsets_to_pose_array(subsets, all_peaks):
    person_pose_array = []
    for subset in subsets:
        joints = []
        for joint_index in subset[:18].astype('i'):
            if joint_index >= 0:
                joint = all_peaks[joint_index][1:3].tolist()
                joint.append(2)  # COCO visibility flag: 2 = labeled and visible
                joints.append(joint)
            else:
                joints.append([0, 0, 0])  # missing joint
        person_pose_array.append(np.array(joints))
    person_pose_array = np.array(person_pose_array)
    return person_pose_array

def detect(img, network):
    orig_img = img.copy()
    orig_img_h, orig_img_w, _ = orig_img.shape

    input_w, input_h = compute_optimal_size(orig_img, params['inference_img_size'])  # 368
    map_w, map_h = compute_optimal_size(orig_img, params['inference_img_size'])

    resized_image = cv2.resize(orig_img, (input_w, input_h))
    x_data = preprocess(resized_image)
    x_data = Tensor(x_data, mstype.float32)
    x_data.requires_grad = False

    logit_pafs, logit_heatmap = network(x_data)
    # keep only the last refinement stage and drop the batch dimension
    logit_pafs = logit_pafs[-1].asnumpy()[0]
    logit_heatmap = logit_heatmap[-1].asnumpy()[0]

    pafs = np.zeros((logit_pafs.shape[0], map_h, map_w))
    for i in range(logit_pafs.shape[0]):
        pafs[i] = cv2.resize(logit_pafs[i], (map_w, map_h))
        if show_gt:
            save_path = "./test_output/" + str(i) + "pafs.png"
            cv2.imwrite(save_path, pafs[i] * 255)

    heatmaps = np.zeros((logit_heatmap.shape[0], map_h, map_w))
    for i in range(logit_heatmap.shape[0]):
        heatmaps[i] = cv2.resize(logit_heatmap[i], (map_w, map_h))
        if show_gt:
            save_path = "./test_output/" + str(i) + "heatmap.png"
            cv2.imwrite(save_path, heatmaps[i] * 255)

    all_peaks = compute_peaks_from_heatmaps(heatmaps)
    if all_peaks.shape[0] == 0:
        return np.empty((0, len(JointType), 3)), np.empty(0)

    all_connections = compute_connections(pafs, all_peaks, map_w, params)
    subsets = grouping_key_points(all_connections, all_peaks, params)
    # map peak coordinates back to the original image resolution
    all_peaks[:, 1] *= orig_img_w / map_w
    all_peaks[:, 2] *= orig_img_h / map_h
    poses = subsets_to_pose_array(subsets, all_peaks)
    scores = subsets[:, -2]
    return poses, scores
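
# A minimal standalone usage sketch (assumes a checkpoint and a BGR image on
# disk; 'test.jpg' is a hypothetical path), mirroring what val() does per image:
#
#     network = OpenPoseNet(vgg_with_bn=params['vgg_with_bn'])
#     network.set_train(False)
#     load_model(network, './0-33_170000.ckpt')
#     poses, scores = detect(cv2.imread('test.jpg'), network)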

def draw_person_pose(orig_img, poses):
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
    if poses.shape[0] == 0:
        return orig_img

    limb_colors = [
        [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255],
        [0, 85, 255], [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
        [255, 0, 85], [170, 255, 0], [85, 255, 0], [170, 0, 255], [0, 0, 255],
        [0, 0, 255], [255, 0, 255], [170, 0, 255], [255, 0, 170],
    ]

    joint_colors = [
        [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
        [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
        [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
        [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    canvas = orig_img.copy()

    # limbs
    for pose in poses.round().astype('i'):
        for i, (limb, color) in enumerate(zip(params['limbs_point'], limb_colors)):
            if i not in (9, 13):  # don't show ear-shoulder connection
                limb_ind = np.array(limb)
                if np.all(pose[limb_ind][:, 2] != 0):
                    joint1, joint2 = pose[limb_ind][:, :2]
                    cv2.line(canvas, tuple(joint1), tuple(joint2), color, 2)

    # joints
    for pose in poses.round().astype('i'):
        for i, ((x, y, v), color) in enumerate(zip(pose, joint_colors)):
            if v != 0:
                cv2.circle(canvas, (x, y), 3, color, -1)
    return canvas

def depreprocess(img):
    # inverse of preprocess: back to a uint8 HWC image
    x_data = img[0]
    x_data += 0.5
    x_data *= 255
    x_data = x_data.astype('uint8')
    x_data = x_data.transpose(1, 2, 0)
    return x_data

def val():
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
    if not os.path.exists(args.output_path):
        os.mkdir(args.output_path)

    network = OpenPoseNet(vgg_with_bn=params['vgg_with_bn'])
    network.set_train(False)
    load_model(network, args.model_path)
    print("load model success")

    dataset = valdata(args.ann, args.imgpath_val, args.rank, args.group_size, mode='val')
    dataset_size = dataset.get_dataset_size()
    de_dataset = dataset.create_tuple_iterator()
    print("eval dataset size: ", dataset_size)

    kpt_json = []
    for _, (img, img_id) in tqdm(enumerate(de_dataset), total=dataset_size):
        img = img.asnumpy()
        img_id = int((img_id.asnumpy())[0])
        poses, scores = detect(img, network)

        if poses.shape[0] > 0:
            for index, pose in enumerate(poses):
                data = dict()
                # reorder the 18 OpenPose joints into COCO keypoint order ...
                pose = pose[[0, 15, 14, 17, 16, 5, 2, 6, 3, 7, 4, 11, 8, 12, 9, 13, 10, 1], :].round().astype('i')
                keypoints = pose.reshape(-1).tolist()
                keypoints = keypoints[:-3]  # ... and drop the last joint (neck), which COCO does not use
                data['image_id'] = img_id
                data['score'] = scores[index]
                data['category_id'] = 1
                data['keypoints'] = keypoints
                kpt_json.append(data)
        else:
            print("Predict poses size is zero.", flush=True)

        img = draw_person_pose(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), poses)
        save_path = os.path.join(args.output_path, str(img_id) + ".png")
        cv2.imwrite(save_path, img)

    result_json = 'eval_result.json'
    with open(os.path.join(args.output_path, result_json), 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP(os.path.join(args.output_path, result_json), ann_file=args.ann)
    print('result: ', res)


if __name__ == "__main__":
    val()