
eval.py 18 kB

# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import json
import os
import argparse
import warnings
import sys
import numpy as np
from tqdm import tqdm
import cv2
from scipy.ndimage.filters import gaussian_filter
from pycocotools.coco import COCO as LoadAnn
from pycocotools.cocoeval import COCOeval as MapEval
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import dtype as mstype
from src.config import params, JointType
from src.openposenet import OpenPoseNet
from src.dataset import valdata

warnings.filterwarnings("ignore")
devid = int(os.getenv('DEVICE_ID', '0'))  # fall back to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend", save_graphs=False, device_id=devid)
show_gt = 0
parser = argparse.ArgumentParser('mindspore openpose_net test')
parser.add_argument('--model_path', type=str, default='./0-33_170000.ckpt', help='path of testing model')
parser.add_argument('--imgpath_val', type=str, default='./dataset/coco/val2017', help='path of testing imgs')
parser.add_argument('--ann', type=str, default='./dataset/coco/annotations/person_keypoints_val2017.json',
                    help='path of annotations')
parser.add_argument('--output_path', type=str, default='./output_img', help='path to save output imgs')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args, _ = parser.parse_known_args()
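
# A minimal invocation sketch; the paths simply mirror the argparse defaults above:
#   python eval.py --model_path ./0-33_170000.ckpt \
#                  --imgpath_val ./dataset/coco/val2017 \
#                  --ann ./dataset/coco/annotations/person_keypoints_val2017.json \
#                  --output_path ./output_img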

def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):
    class NullWriter():
        def write(self, arg):
            pass

    if silence:
        nullwrite = NullWriter()
        oldstdout = sys.stdout
        sys.stdout = nullwrite  # disable output
    Gt = LoadAnn(ann_file)
    Dt = Gt.loadRes(res_file)
    Eval = MapEval(Gt, Dt, ann_type)
    Eval.evaluate()
    Eval.accumulate()
    Eval.summarize()
    if silence:
        sys.stdout = oldstdout  # restore output
    stats_names = ['AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)',
                   'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
    info_str = {}
    for ind, name in enumerate(stats_names):
        info_str[name] = Eval.stats[ind]
    return info_str
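
# evaluate_mAP returns a dict keyed by pycocotools' fixed 10-metric keypoint summary
# (Eval.stats), e.g. with illustrative values only:
#   {'AP': 0.40, 'AP .5': 0.65, ..., 'AR (L)': 0.50}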

def load_model(test_net, model_path):
    assert os.path.exists(model_path)
    param_dict = load_checkpoint(model_path)
    param_dict_new = {}
    for key, values in param_dict.items():
        if key.startswith('moment'):
            continue
        elif key.startswith('network'):
            param_dict_new[key[8:]] = values
        # else:
        #     param_dict_new[key] = values
    load_param_into_net(test_net, param_dict_new)
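
# Training checkpoints store parameters under the wrapper prefix 'network.' (and
# optimizer state under 'moment...'); key[8:] strips the 8-character 'network.'
# prefix so the names match the bare eval network.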

def preprocess(img):
    x_data = img.astype('f')
    x_data /= 255
    x_data -= 0.5
    x_data = x_data.transpose(2, 0, 1)[None]
    return x_data
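
# Example: an HWC uint8 image of shape (368, 496, 3) becomes a float NCHW batch of
# shape (1, 3, 368, 496), with pixel values mapped from [0, 255] to [-0.5, 0.5].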

def getImgsPath(img_dir_path):
    filepaths = []
    dirpaths = []
    pathName = img_dir_path
    for root, dirs, files in os.walk(pathName):
        for file in files:
            file_path = os.path.join(root, file)
            filepaths.append(file_path)
        for d in dirs:
            dir_path = os.path.join(root, d)
            dirpaths.append(dir_path)
    return filepaths

def compute_optimal_size(orig_img, img_size, stride=8):
    orig_img_h, orig_img_w, _ = orig_img.shape
    aspect = orig_img_h / orig_img_w
    if orig_img_h < orig_img_w:
        img_h = img_size
        img_w = np.round(img_size / aspect).astype(int)
        surplus = img_w % stride
        if surplus != 0:
            img_w += stride - surplus
    else:
        img_w = img_size
        img_h = np.round(img_size * aspect).astype(int)
        surplus = img_h % stride
        if surplus != 0:
            img_h += stride - surplus
    return (img_w, img_h)
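
# Worked example: a 480x640 (h, w) image with img_size=368 and stride=8 gives
# aspect = 0.75; since h < w, img_h = 368 and img_w = round(368 / 0.75) = 491,
# padded up to the next multiple of 8: 491 + (8 - 3) = 496 -> returns (496, 368).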

def compute_peaks_from_heatmaps(heatmaps):
    heatmaps = heatmaps[:-1]  # the last channel is the background heatmap
    all_peaks = []
    peak_counter = 0
    for i, heatmap in enumerate(heatmaps):
        heatmap = gaussian_filter(heatmap, sigma=params['gaussian_sigma'])
        map_left = np.zeros(heatmap.shape)
        map_right = np.zeros(heatmap.shape)
        map_top = np.zeros(heatmap.shape)
        map_bottom = np.zeros(heatmap.shape)
        map_left[1:, :] = heatmap[:-1, :]
        map_right[:-1, :] = heatmap[1:, :]
        map_top[:, 1:] = heatmap[:, :-1]
        map_bottom[:, :-1] = heatmap[:, 1:]
        # a peak must exceed the threshold and all four direct neighbours
        peaks_binary = np.logical_and.reduce((
            heatmap > params['heatmap_peak_thresh'],
            heatmap > map_left,
            heatmap > map_right,
            heatmap > map_top,
            heatmap > map_bottom,
        ))
        peaks = zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])  # (x, y) order
        peaks_with_score = [(i,) + peak_pos + (heatmap[peak_pos[1], peak_pos[0]],) for peak_pos in peaks]
        peaks_id = range(peak_counter, peak_counter + len(peaks_with_score))
        peaks_with_score_and_id = [peaks_with_score[i] + (peaks_id[i],) for i in range(len(peaks_id))]
        peak_counter += len(peaks_with_score_and_id)
        all_peaks.append(peaks_with_score_and_id)
    all_peaks = np.array([peak for peaks_each_category in all_peaks for peak in peaks_each_category])
    return all_peaks
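
# Each returned row is (joint_type, x, y, score, peak_id); for instance, a single
# neck detection at pixel (120, 48) with confidence 0.9 would appear as
# [1., 120., 48., 0.9, 0.], since JointType.Neck == 1 in src.config.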

def compute_candidate_connections(paf, cand_a, cand_b, img_len, params_):
    candidate_connections = []
    for joint_a in cand_a:
        for joint_b in cand_b:
            vector = joint_b[:2] - joint_a[:2]
            norm = np.linalg.norm(vector)
            if norm == 0:
                continue
            ys = np.linspace(joint_a[1], joint_b[1], num=params_['n_integ_points'])
            xs = np.linspace(joint_a[0], joint_b[0], num=params_['n_integ_points'])
            integ_points = np.stack([ys, xs]).T.round().astype('i')
            paf_in_edge = np.hstack([paf[0][np.hsplit(integ_points, 2)], paf[1][np.hsplit(integ_points, 2)]])
            unit_vector = vector / norm
            inner_products = np.dot(paf_in_edge, unit_vector)
            integ_value = inner_products.sum() / len(inner_products)
            integ_value_with_dist_prior = integ_value + min(params_['limb_length_ratio'] * img_len / norm -
                                                            params_['length_penalty_value'], 0)
            n_valid_points = sum(inner_products > params_['inner_product_thresh'])
            if n_valid_points > params_['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
                candidate_connections.append([int(joint_a[3]), int(joint_b[3]), integ_value_with_dist_prior])
    candidate_connections = sorted(candidate_connections, key=lambda x: x[2], reverse=True)
    return candidate_connections
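
# The connection score is the mean inner product between the PAF sampled along the
# candidate limb and the limb's unit direction, roughly
#   E = (1/n) * sum_k <paf(p_k), (b - a) / ||b - a||>,
# plus a non-positive prior that penalizes limbs much longer than
# limb_length_ratio * img_len.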

def compute_connections(pafs, all_peaks, img_len, params_):
    all_connections = []
    for i in range(len(params_['limbs_point'])):
        paf_index = [i * 2, i * 2 + 1]
        paf = pafs[paf_index]  # shape: (2, 320, 320)
        limb_point = params_['limbs_point'][i]  # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
        cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:]
        cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:]
        if cand_a.shape[0] > 0 and cand_b.shape[0] > 0:
            candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, params_)
            connections = np.zeros((0, 3))
            for index_a, index_b, score in candidate_connections:
                if index_a not in connections[:, 0] and index_b not in connections[:, 1]:
                    connections = np.vstack([connections, [index_a, index_b, score]])
                    if len(connections) >= min(len(cand_a), len(cand_b)):
                        break
            all_connections.append(connections)
        else:
            all_connections.append(np.zeros((0, 3)))
    return all_connections
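
# Greedy bipartite matching: candidates arrive sorted by descending score, each peak
# joins at most one connection per limb type, and matching stops once the smaller
# candidate set is exhausted.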

def grouping_key_points(all_connections, candidate_peaks, params_):
    subsets = -1 * np.ones((0, 20))
    for l, connections in enumerate(all_connections):
        joint_a, joint_b = params_['limbs_point'][l]
        for ind_a, ind_b, score in connections[:, :3]:
            ind_a, ind_b = int(ind_a), int(ind_b)
            joint_found_cnt = 0
            joint_found_subset_index = [-1, -1]
            for subset_ind, subset in enumerate(subsets):
                if subset[joint_a] == ind_a or subset[joint_b] == ind_b:
                    if joint_found_cnt < 2:  # guard: only two matching subsets are tracked
                        joint_found_subset_index[joint_found_cnt] = subset_ind
                    joint_found_cnt += 1
            if joint_found_cnt == 1:
                # one endpoint already belongs to a subset: extend that subset
                found_subset = subsets[joint_found_subset_index[0]]
                if found_subset[joint_b] != ind_b:
                    found_subset[joint_b] = ind_b
                    found_subset[-1] += 1  # increment joint count
                    found_subset[-2] += candidate_peaks[ind_b, 3] + score
            elif joint_found_cnt == 2:
                found_subset_1 = subsets[joint_found_subset_index[0]]
                found_subset_2 = subsets[joint_found_subset_index[1]]
                membership = ((found_subset_1 >= 0).astype(int) + (found_subset_2 >= 0).astype(int))[:-2]
                if not np.any(membership == 2):  # merge two subsets when no duplication
                    found_subset_1[:-2] += found_subset_2[:-2] + 1  # default is -1
                    found_subset_1[-2:] += found_subset_2[-2:]
                    found_subset_1[-2] += score
                    subsets = np.delete(subsets, joint_found_subset_index[1], axis=0)
                else:
                    if found_subset_1[joint_a] == -1:
                        found_subset_1[joint_a] = ind_a
                        found_subset_1[-1] += 1
                        found_subset_1[-2] += candidate_peaks[ind_a, 3] + score
                    elif found_subset_1[joint_b] == -1:
                        found_subset_1[joint_b] = ind_b
                        found_subset_1[-1] += 1
                        found_subset_1[-2] += candidate_peaks[ind_b, 3] + score
                    if found_subset_2[joint_a] == -1:
                        found_subset_2[joint_a] = ind_a
                        found_subset_2[-1] += 1
                        found_subset_2[-2] += candidate_peaks[ind_a, 3] + score
                    elif found_subset_2[joint_b] == -1:
                        found_subset_2[joint_b] = ind_b
                        found_subset_2[-1] += 1
                        found_subset_2[-2] += candidate_peaks[ind_b, 3] + score
            elif joint_found_cnt == 0 and l != 9 and l != 13:  # limbs 9 and 13 are the ear-shoulder links
                row = -1 * np.ones(20)
                row[joint_a] = ind_a
                row[joint_b] = ind_b
                row[-1] = 2
                row[-2] = sum(candidate_peaks[[ind_a, ind_b], 3]) + score
                subsets = np.vstack([subsets, row])
            elif joint_found_cnt >= 3:
                pass  # ambiguous case: leave subsets unchanged
    # delete low score subsets
    keep = np.logical_and(subsets[:, -1] >= params_['n_subset_limbs_thresh'],
                          subsets[:, -2] / subsets[:, -1] >= params_['subset_score_thresh'])
    subsets = subsets[keep]
    return subsets
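
# Each subset row has 20 columns: indices 0-17 hold the peak id assigned to each of
# the 18 joints (-1 when missing), column -2 accumulates the subset's total score,
# and column -1 counts how many joints were found for that person.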

def subsets_to_pose_array(subsets, all_peaks):
    person_pose_array = []
    for subset in subsets:
        joints = []
        for joint_index in subset[:18].astype('i'):
            if joint_index >= 0:
                joint = all_peaks[joint_index][1:3].tolist()
                joint.append(2)
                joints.append(joint)
            else:
                joints.append([0, 0, 0])
        person_pose_array.append(np.array(joints))
    person_pose_array = np.array(person_pose_array)
    return person_pose_array
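
# Detected joints get visibility flag 2 (COCO's "labeled and visible"); joints the
# grouping stage could not assign are emitted as [0, 0, 0].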

def detect(img, network):
    orig_img = img.copy()
    orig_img_h, orig_img_w, _ = orig_img.shape
    input_w, input_h = compute_optimal_size(orig_img, params['inference_img_size'])  # 368
    # map_w, map_h = compute_optimal_size(orig_img, params['heatmap_size'])  # 320
    map_w, map_h = compute_optimal_size(orig_img, params['inference_img_size'])
    # print("image size is: ", input_w, input_h)
    resized_image = cv2.resize(orig_img, (input_w, input_h))
    x_data = preprocess(resized_image)
    x_data = Tensor(x_data, mstype.float32)
    x_data.requires_grad = False
    logit_pafs, logit_heatmap = network(x_data)
    # take the last stage's output and drop the batch dimension
    logit_pafs = logit_pafs[-1].asnumpy()[0]
    logit_heatmap = logit_heatmap[-1].asnumpy()[0]
    pafs = np.zeros((logit_pafs.shape[0], map_h, map_w))
    for i in range(logit_pafs.shape[0]):
        pafs[i] = cv2.resize(logit_pafs[i], (map_w, map_h))
        if show_gt:
            save_path = "./test_output/" + str(i) + "pafs.png"
            cv2.imwrite(save_path, pafs[i] * 255)
    heatmaps = np.zeros((logit_heatmap.shape[0], map_h, map_w))
    for i in range(logit_heatmap.shape[0]):
        heatmaps[i] = cv2.resize(logit_heatmap[i], (map_w, map_h))
        if show_gt:
            save_path = "./test_output/" + str(i) + "heatmap.png"
            cv2.imwrite(save_path, heatmaps[i] * 255)
    all_peaks = compute_peaks_from_heatmaps(heatmaps)
    if all_peaks.shape[0] == 0:
        return np.empty((0, len(JointType), 3)), np.empty(0)
    all_connections = compute_connections(pafs, all_peaks, map_w, params)
    subsets = grouping_key_points(all_connections, all_peaks, params)
    # rescale peak coordinates from heatmap space back to the original image
    all_peaks[:, 1] *= orig_img_w / map_w
    all_peaks[:, 2] *= orig_img_h / map_h
    poses = subsets_to_pose_array(subsets, all_peaks)
    scores = subsets[:, -2]
    return poses, scores
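
# A minimal inference sketch (the file name is hypothetical):
#   img = cv2.imread('sample.jpg')
#   poses, scores = detect(img, network)  # poses: (n_person, 18, 3) rows of (x, y, visibility)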

def draw_person_pose(orig_img, poses):
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
    if poses.shape[0] == 0:
        return orig_img
    limb_colors = [
        [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255],
        [0, 85, 255], [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
        [255, 0, 85], [170, 255, 0], [85, 255, 0], [170, 0, 255], [0, 0, 255],
        [0, 0, 255], [255, 0, 255], [170, 0, 255], [255, 0, 170],
    ]
    joint_colors = [
        [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
        [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
        [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
        [255, 0, 255], [255, 0, 170], [255, 0, 85]]
    canvas = orig_img.copy()
    # limbs
    for pose in poses.round().astype('i'):
        for i, (limb, color) in enumerate(zip(params['limbs_point'], limb_colors)):
            if i not in (9, 13):  # don't show ear-shoulder connection
                limb_ind = np.array(limb)
                if np.all(pose[limb_ind][:, 2] != 0):
                    joint1, joint2 = pose[limb_ind][:, :2]
                    cv2.line(canvas, tuple(joint1), tuple(joint2), color, 2)
    # joints
    for pose in poses.round().astype('i'):
        for i, ((x, y, v), color) in enumerate(zip(pose, joint_colors)):
            if v != 0:
                cv2.circle(canvas, (x, y), 3, color, -1)
    return canvas

def depreprocess(img):
    # x_data = img.astype('f')
    x_data = img[0]
    x_data += 0.5
    x_data *= 255
    x_data = x_data.astype('uint8')
    x_data = x_data.transpose(1, 2, 0)
    return x_data
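
# depreprocess inverts preprocess for visualization: a (1, 3, H, W) float array in
# [-0.5, 0.5] becomes an (H, W, 3) uint8 image.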

def val():
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()
    if not os.path.exists(args.output_path):
        os.mkdir(args.output_path)
    network = OpenPoseNet(vgg_with_bn=params['vgg_with_bn'])
    network.set_train(False)
    load_model(network, args.model_path)
    print("load model success")
    dataset = valdata(args.ann, args.imgpath_val, args.rank, args.group_size, mode='val')
    dataset_size = dataset.get_dataset_size()
    de_dataset = dataset.create_tuple_iterator()
    print("eval dataset size: ", dataset_size)
    kpt_json = []
    for _, (img, img_id) in tqdm(enumerate(de_dataset), total=dataset_size):
        img = img.asnumpy()
        img_id = int((img_id.asnumpy())[0])
        poses, scores = detect(img, network)
        if poses.shape[0] > 0:
            # print("got poses")
            for index, pose in enumerate(poses):
                data = dict()
                # reorder the 18 OpenPose joints into COCO keypoint order
                pose = pose[[0, 15, 14, 17, 16, 5, 2, 6, 3, 7, 4, 11, 8, 12, 9, 13, 10, 1], :].round().astype('i')
                keypoints = pose.reshape(-1).tolist()
                keypoints = keypoints[:-3]  # drop the trailing neck joint: COCO uses only 17 keypoints
                data['image_id'] = img_id
                data['score'] = scores[index]
                data['category_id'] = 1
                data['keypoints'] = keypoints
                kpt_json.append(data)
        else:
            print("Predict poses size is zero.", flush=True)
        img = draw_person_pose(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), poses)
        # print('Saving result into', str(img_id) + '.png...')
        save_path = os.path.join(args.output_path, str(img_id) + ".png")
        cv2.imwrite(save_path, img)
    result_json = 'eval_result.json'
    with open(os.path.join(args.output_path, result_json), 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP(os.path.join(args.output_path, result_json), ann_file=args.ann)
    print('result: ', res)


if __name__ == "__main__":
    val()
  375. val()