
eval.py

# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import json
import os
import argparse
import warnings
import sys

import cv2
from tqdm import tqdm
import numpy as np
from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated
from pycocotools.coco import COCO as LoadAnn
from pycocotools.cocoeval import COCOeval as MapEval

from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import dtype as mstype

from src.dataset import valdata
from src.openposenet import OpenPoseNet
from src.config import params, JointType

warnings.filterwarnings("ignore")

# default to device 0 when DEVICE_ID is not set in the environment
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend", save_graphs=False, device_id=devid)

show_gt = 0


def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):

    class NullWriter():
        def write(self, arg):
            pass

    if silence:
        nullwrite = NullWriter()
        oldstdout = sys.stdout
        sys.stdout = nullwrite  # disable output

    Gt = LoadAnn(ann_file)
    Dt = Gt.loadRes(res_file)

    Eval = MapEval(Gt, Dt, ann_type)
    Eval.evaluate()
    Eval.accumulate()
    Eval.summarize()

    if silence:
        sys.stdout = oldstdout  # enable output

    stats_names = ['AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)',
                   'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']

    info_str = {}
    for ind, name in enumerate(stats_names):
        info_str[name] = Eval.stats[ind]

    return info_str
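
# A minimal usage sketch (hedged; paths are illustrative): given the detection
# file this script produces and the matching COCO annotation file,
# evaluate_mAP returns the ten standard COCO keypoint metrics keyed by name:
#   res = evaluate_mAP('./output_img/eval_result.json',
#                      '/path/to/annotations/person_keypoints_val2017.json')
#   print(res['AP'])   # overall keypoint AP at OKS=0.50:0.95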


def parse_args():
    """Parse arguments."""
    parser = argparse.ArgumentParser('mindspore openpose_net test')
    parser.add_argument('--model_path', type=str,
                        default='./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt',
                        help='path of testing model')
    parser.add_argument('--imgpath_val', type=str, default='/data0/zhy/dataset/coco/val2017',
                        help='path of testing imgs')
    parser.add_argument('--ann', type=str, default='/data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json',
                        help='path of annotations')
    parser.add_argument('--output_path', type=str, default='./output_img',
                        help='path to save output imgs')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')

    args, _ = parser.parse_known_args()
    return args
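
# A typical invocation (hedged; adjust the paths to your environment):
#   python eval.py --model_path ./ckpt/openpose.ckpt \
#                  --imgpath_val /path/to/coco/val2017 \
#                  --ann /path/to/annotations/person_keypoints_val2017.json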


def load_model(test_net, model_path):
    assert os.path.exists(model_path)
    param_dict = load_checkpoint(model_path)
    param_dict_new = {}
    for key, values in param_dict.items():
        # skip optimizer state and strip the 'network.' prefix (8 characters)
        # that the training wrapper adds to every parameter name
        if key.startswith('moment'):
            continue
        elif key.startswith('network'):
            param_dict_new[key[8:]] = values
        # else:
        #     param_dict_new[key] = values
    load_param_into_net(test_net, param_dict_new)


def preprocess(img):
    x_data = img.astype('f')
    x_data /= 255
    x_data -= 0.5
    x_data = x_data.transpose(2, 0, 1)[None]
    return x_data
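
# preprocess maps a uint8 HWC image to a float NCHW tensor in [-0.5, 0.5]:
# pixel value 0 becomes -0.5 and 255 becomes 0.5, and [None] adds the batch
# dimension the network expects.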


def getImgsPath(img_dir_path):
    filepaths = []
    dirpaths = []
    pathName = img_dir_path
    for root, dirs, files in os.walk(pathName):
        for file in files:
            file_path = os.path.join(root, file)
            filepaths.append(file_path)
        for d in dirs:
            dir_path = os.path.join(root, d)
            dirpaths.append(dir_path)
    return filepaths
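
# getImgsPath is a small utility that walks a directory tree and returns all
# file paths; it is not called by the evaluation flow below, which reads
# images through src.dataset.valdata instead.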


def compute_optimal_size(orig_img, img_size, stride=8):
    orig_img_h, orig_img_w, _ = orig_img.shape
    aspect = orig_img_h / orig_img_w
    if orig_img_h < orig_img_w:
        img_h = img_size
        img_w = np.round(img_size / aspect).astype(int)
        surplus = img_w % stride
        if surplus != 0:
            img_w += stride - surplus
    else:
        img_w = img_size
        img_h = np.round(img_size * aspect).astype(int)
        surplus = img_h % stride
        if surplus != 0:
            img_h += stride - surplus
    return (img_w, img_h)
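
# Worked example (hedged): for a 480x640 image with img_size=368, the aspect
# ratio is 480/640 = 0.75; the short side gets 368 and the long side gets
# round(368 / 0.75) = 491, padded up to the next multiple of stride 8, so the
# function returns (496, 368).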


def compute_peaks_from_heatmaps(heatmaps):
    heatmaps = heatmaps[:-1]  # drop the background channel, keep the joint maps

    all_peaks = []
    peak_counter = 0

    for i, heatmap in enumerate(heatmaps):
        heatmap = gaussian_filter(heatmap, sigma=params['gaussian_sigma'])

        map_left = np.zeros(heatmap.shape)
        map_right = np.zeros(heatmap.shape)
        map_top = np.zeros(heatmap.shape)
        map_bottom = np.zeros(heatmap.shape)
        map_left[1:, :] = heatmap[:-1, :]
        map_right[:-1, :] = heatmap[1:, :]
        map_top[:, 1:] = heatmap[:, :-1]
        map_bottom[:, :-1] = heatmap[:, 1:]

        peaks_binary = np.logical_and.reduce((
            heatmap > params['heatmap_peak_thresh'],
            heatmap > map_left,
            heatmap > map_right,
            heatmap > map_top,
            heatmap > map_bottom,
        ))

        # (x, y) coordinates of local maxima above the threshold
        peaks = zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
        peaks_with_score = [(i,) + peak_pos + (heatmap[peak_pos[1], peak_pos[0]],) for peak_pos in peaks]
        peaks_id = range(peak_counter, peak_counter + len(peaks_with_score))
        peaks_with_score_and_id = [peaks_with_score[i] + (peaks_id[i],) for i in range(len(peaks_id))]
        peak_counter += len(peaks_with_score_and_id)
        all_peaks.append(peaks_with_score_and_id)

    all_peaks = np.array([peak for peaks_each_category in all_peaks for peak in peaks_each_category])

    return all_peaks
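
# Each row of the returned array is (joint_type, x, y, score, peak_id). The
# peak_id is unique across all joint types because the counter is global, so
# it doubles as a row index into this array for grouping_key_points below.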


def compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg):
    candidate_connections = []
    for joint_a in cand_a:
        for joint_b in cand_b:  # each joint is (x, y, score, id)
            vector = joint_b[:2] - joint_a[:2]
            norm = np.linalg.norm(vector)
            if norm == 0:
                continue

            ys = np.linspace(joint_a[1], joint_b[1], num=cfg['n_integ_points'])
            xs = np.linspace(joint_a[0], joint_b[0], num=cfg['n_integ_points'])
            integ_points = np.stack([ys, xs]).T.round().astype('i')
            # index with a tuple; list-of-arrays indexing is rejected by modern NumPy
            paf_in_edge = np.hstack([paf[0][tuple(np.hsplit(integ_points, 2))],
                                     paf[1][tuple(np.hsplit(integ_points, 2))]])

            unit_vector = vector / norm
            inner_products = np.dot(paf_in_edge, unit_vector)
            integ_value = inner_products.sum() / len(inner_products)
            integ_value_with_dist_prior = integ_value + min(
                cfg['limb_length_ratio'] * img_len / norm - cfg['length_penalty_value'], 0)

            n_valid_points = sum(inner_products > cfg['inner_product_thresh'])
            if n_valid_points > cfg['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
                candidate_connections.append([int(joint_a[3]), int(joint_b[3]), integ_value_with_dist_prior])

    candidate_connections = sorted(candidate_connections, key=lambda x: x[2], reverse=True)
    return candidate_connections
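
# The connection score is the mean dot product between the PAF field sampled
# along the candidate limb and the limb's unit direction, plus a distance
# prior, min(limb_length_ratio * img_len / norm - length_penalty_value, 0),
# which is never positive and penalizes limbs that are long relative to the
# image.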


def compute_connections(pafs, all_peaks, img_len, cfg):
    all_connections = []
    for i in range(len(cfg['limbs_point'])):
        paf_index = [i * 2, i * 2 + 1]
        paf = pafs[paf_index]  # shape: (2, 320, 320)
        limb_point = cfg['limbs_point'][i]  # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
        cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:]
        cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:]

        # NumPy arrays are ambiguous in a boolean context, so test lengths
        # explicitly rather than `if cand_a and cand_b`
        if len(cand_a) > 0 and len(cand_b) > 0:
            candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg)
            connections = np.zeros((0, 3))
            for index_a, index_b, score in candidate_connections:
                if index_a not in connections[:, 0] and index_b not in connections[:, 1]:
                    connections = np.vstack([connections, [index_a, index_b, score]])
                    if len(connections) >= min(len(cand_a), len(cand_b)):
                        break
            all_connections.append(connections)
        else:
            all_connections.append(np.zeros((0, 3)))
    return all_connections
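
# compute_candidate_connections returns candidates sorted by descending score,
# so the loop above is a greedy bipartite matching: each peak is used at most
# once per limb type, and matching stops once the smaller candidate set is
# exhausted.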


def grouping_key_points(all_connections, candidate_peaks, cfg):
    subsets = -1 * np.ones((0, 20))

    for l, connections in enumerate(all_connections):
        joint_a, joint_b = cfg['limbs_point'][l]
        for ind_a, ind_b, score in connections[:, :3]:
            ind_a, ind_b = int(ind_a), int(ind_b)
            joint_found_cnt = 0
            joint_found_subset_index = [-1, -1]
            for subset_ind, subset in enumerate(subsets):
                if subset[joint_a] == ind_a or subset[joint_b] == ind_b:
                    joint_found_subset_index[joint_found_cnt] = subset_ind
                    joint_found_cnt += 1

            if joint_found_cnt == 1:
                found_subset = subsets[joint_found_subset_index[0]]
                if found_subset[joint_b] != ind_b:
                    found_subset[joint_b] = ind_b
                    found_subset[-1] += 1  # increment joint count
                    found_subset[-2] += candidate_peaks[ind_b, 3] + score

            elif joint_found_cnt == 2:
                found_subset_1 = subsets[joint_found_subset_index[0]]
                found_subset_2 = subsets[joint_found_subset_index[1]]

                membership = ((found_subset_1 >= 0).astype(int) + (found_subset_2 >= 0).astype(int))[:-2]
                if not np.any(membership == 2):  # merge two subsets when no duplication
                    found_subset_1[:-2] += found_subset_2[:-2] + 1  # default is -1
                    found_subset_1[-2:] += found_subset_2[-2:]
                    found_subset_1[-2] += score
                    subsets = np.delete(subsets, joint_found_subset_index[1], axis=0)
                else:
                    if found_subset_1[joint_a] == -1:
                        found_subset_1[joint_a] = ind_a
                        found_subset_1[-1] += 1
                        found_subset_1[-2] += candidate_peaks[ind_a, 3] + score
                    elif found_subset_1[joint_b] == -1:
                        found_subset_1[joint_b] = ind_b
                        found_subset_1[-1] += 1
                        found_subset_1[-2] += candidate_peaks[ind_b, 3] + score
                    if found_subset_2[joint_a] == -1:
                        found_subset_2[joint_a] = ind_a
                        found_subset_2[-1] += 1
                        found_subset_2[-2] += candidate_peaks[ind_a, 3] + score
                    elif found_subset_2[joint_b] == -1:
                        found_subset_2[joint_b] = ind_b
                        found_subset_2[-1] += 1
                        found_subset_2[-2] += candidate_peaks[ind_b, 3] + score

            elif joint_found_cnt == 0 and l != 9 and l != 13:  # limbs 9 and 13 are the ear-shoulder links
                row = -1 * np.ones(20)
                row[joint_a] = ind_a
                row[joint_b] = ind_b
                row[-1] = 2
                row[-2] = sum(candidate_peaks[[ind_a, ind_b], 3]) + score
                subsets = np.vstack([subsets, row])

            elif joint_found_cnt >= 3:
                pass

    # delete low score subsets
    keep = np.logical_and(subsets[:, -1] >= cfg['n_subset_limbs_thresh'],
                          subsets[:, -2] / subsets[:, -1] >= cfg['subset_score_thresh'])
    # cfg['n_subset_limbs_thresh'] = 3
    # cfg['subset_score_thresh'] = 0.2
    subsets = subsets[keep]
    return subsets
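
# Each subset row has 20 entries: columns 0-17 hold the peak id assigned to
# each JointType (-1 if missing), column -2 accumulates the subset's total
# score, and column -1 counts the joints found for that person.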


def subsets_to_pose_array(subsets, all_peaks):
    person_pose_array = []
    for subset in subsets:
        joints = []
        for joint_index in subset[:18].astype('i'):
            if joint_index >= 0:
                joint = all_peaks[joint_index][1:3].tolist()
                joint.append(2)
                joints.append(joint)
            else:
                joints.append([0, 0, 0])
        person_pose_array.append(np.array(joints))
    person_pose_array = np.array(person_pose_array)
    return person_pose_array


def detect(img, network):
    orig_img = img.copy()
    orig_img_h, orig_img_w, _ = orig_img.shape

    input_w, input_h = compute_optimal_size(orig_img, params['inference_img_size'])  # 368
    # map_w, map_h = compute_optimal_size(orig_img, params['heatmap_size'])  # 320
    map_w, map_h = compute_optimal_size(orig_img, params['inference_img_size'])
    print("image size is: ", input_w, input_h)

    resized_image = cv2.resize(orig_img, (input_w, input_h))
    x_data = preprocess(resized_image)
    x_data = Tensor(x_data, mstype.float32)
    x_data.requires_grad = False

    logit_pafs, logit_heatmap = network(x_data)
    # take the output of the last refinement stage for the single batch item
    logit_pafs = logit_pafs[-1].asnumpy()[0]
    logit_heatmap = logit_heatmap[-1].asnumpy()[0]

    pafs = np.zeros((logit_pafs.shape[0], map_h, map_w))
    for i in range(logit_pafs.shape[0]):
        pafs[i] = cv2.resize(logit_pafs[i], (map_w, map_h))
        if show_gt:
            save_path = "./test_output/" + str(i) + "pafs.png"
            cv2.imwrite(save_path, pafs[i] * 255)

    heatmaps = np.zeros((logit_heatmap.shape[0], map_h, map_w))
    for i in range(logit_heatmap.shape[0]):
        heatmaps[i] = cv2.resize(logit_heatmap[i], (map_w, map_h))
        if show_gt:
            save_path = "./test_output/" + str(i) + "heatmap.png"
            cv2.imwrite(save_path, heatmaps[i] * 255)

    all_peaks = compute_peaks_from_heatmaps(heatmaps)
    # `all_peaks` is a NumPy array, so test its length rather than its truth value
    if len(all_peaks) == 0:
        return np.empty((0, len(JointType), 3)), np.empty(0)

    all_connections = compute_connections(pafs, all_peaks, map_w, params)
    subsets = grouping_key_points(all_connections, all_peaks, params)
    # rescale peak coordinates back to the original image resolution
    all_peaks[:, 1] *= orig_img_w / map_w
    all_peaks[:, 2] *= orig_img_h / map_h
    poses = subsets_to_pose_array(subsets, all_peaks)
    scores = subsets[:, -2]
    return poses, scores
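
# Standalone usage sketch (hedged; the checkpoint and image paths here are
# hypothetical, and weights must be loaded first as _eval does below):
#   net = OpenPoseNet()
#   net.set_train(False)
#   load_model(net, './ckpt/openpose.ckpt')
#   img = cv2.imread('example.jpg')   # any BGR test image
#   poses, scores = detect(img, net)
#   canvas = draw_person_pose(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), poses)
#   cv2.imwrite('result.png', canvas)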


def draw_person_pose(orig_img, poses):
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
    if len(poses) == 0:
        return orig_img

    limb_colors = [
        [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255],
        [0, 85, 255], [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
        [255, 0, 85], [170, 255, 0], [85, 255, 0], [170, 0, 255], [0, 0, 255],
        [0, 0, 255], [255, 0, 255], [170, 0, 255], [255, 0, 170],
    ]

    joint_colors = [
        [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
        [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
        [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
        [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    canvas = orig_img.copy()

    # limbs
    for pose in poses.round().astype('i'):
        for i, (limb, color) in enumerate(zip(params['limbs_point'], limb_colors)):
            if i not in (9, 13):  # don't show ear-shoulder connection
                limb_ind = np.array(limb)
                if np.all(pose[limb_ind][:, 2] != 0):
                    joint1, joint2 = pose[limb_ind][:, :2]
                    cv2.line(canvas, tuple(joint1), tuple(joint2), color, 2)

    # joints
    for pose in poses.round().astype('i'):
        for i, ((x, y, v), color) in enumerate(zip(pose, joint_colors)):
            if v != 0:
                cv2.circle(canvas, (x, y), 3, color, -1)
    return canvas
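
# Note: draw_person_pose swaps BGR to RGB internally, so _eval below passes an
# already-converted image; the two conversions cancel out and cv2.imwrite
# receives BGR data as it expects.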


def depreprocess(img):
    # x_data = img.astype('f')
    x_data = img[0]
    x_data += 0.5
    x_data *= 255
    x_data = x_data.astype('uint8')
    x_data = x_data.transpose(1, 2, 0)
    return x_data
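
# depreprocess inverts preprocess for visualization: an NCHW float tensor in
# [-0.5, 0.5] becomes an HWC uint8 image again.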


def _eval():
    args = parse_args()
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
    if not os.path.exists(args.output_path):
        os.mkdir(args.output_path)

    network = OpenPoseNet()
    network.set_train(False)
    load_model(network, args.model_path)
    print("Model loaded successfully.")

    dataset = valdata(args.ann, args.imgpath_val, args.rank, args.group_size, mode='val')
    dataset_size = dataset.get_dataset_size()
    de_dataset = dataset.create_tuple_iterator()
    print("eval dataset size: ", dataset_size)

    kpt_json = []
    for _, (img, img_id) in tqdm(enumerate(de_dataset), total=dataset_size):
        img = img.asnumpy()
        img_id = int((img_id.asnumpy())[0])

        poses, scores = detect(img, network)

        if len(poses) > 0:
            # print("got poses")
            for index, pose in enumerate(poses):
                data = dict()

                # reorder from the model's JointType layout to the COCO keypoint
                # order; the trailing neck entry is cut by the [:-3] below,
                # leaving the 17 keypoints COCO annotates
                pose = pose[[0, 15, 14, 17, 16, 5, 2, 6, 3, 7, 4, 11, 8, 12, 9, 13, 10, 1], :].round().astype('i')
                keypoints = pose.reshape(-1).tolist()
                keypoints = keypoints[:-3]

                data['image_id'] = img_id
                data['score'] = scores[index]
                data['category_id'] = 1  # COCO category 'person'
                data['keypoints'] = keypoints
                kpt_json.append(data)
        else:
            print("No poses predicted for this image.", flush=True)

        img = draw_person_pose(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), poses)
        # print('Saving result into', str(img_id) + '.png...')
        save_path = os.path.join(args.output_path, str(img_id) + ".png")
        cv2.imwrite(save_path, img)

    result_json = 'eval_result.json'
    with open(os.path.join(args.output_path, result_json), 'w') as fid:
        json.dump(kpt_json, fid)

    res = evaluate_mAP(os.path.join(args.output_path, result_json), ann_file=args.ann)
    print('result: ', res)


if __name__ == "__main__":
    _eval()