
eval.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Face Recognition eval."""
  16. import os
  17. import time
  18. import math
  19. from pprint import pformat
  20. import numpy as np
  21. import cv2
  22. import mindspore.dataset.transforms.py_transforms as transforms
  23. import mindspore.dataset.vision.py_transforms as vision
  24. import mindspore.dataset as de
  25. from mindspore import Tensor, context
  26. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  27. from src.backbone.resnet import get_backbone
  28. from src.my_logging import get_logger
  29. from utils.config import config
  30. from utils.moxing_adapter import moxing_wrapper
  31. from utils.device_adapter import get_device_id, get_device_num, get_rank_id
  32. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
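
# NOTE: graph mode on Ascend is configured at import time; get_device_id()
# (from utils.device_adapter) typically reads the DEVICE_ID environment
# variable, so the target device is fixed before any network is built.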


class TxtDataset():
    '''TxtDataset'''
    def __init__(self, root_all, filenames):
        super(TxtDataset, self).__init__()
        self.imgs = []
        self.labels = []
        for root, filename in zip(root_all, filenames):
            fin = open(filename, "r")
            for line in fin:
                self.imgs.append(os.path.join(root, line.strip().split(" ")[0]))
                self.labels.append(line.strip())
            fin.close()

    def __getitem__(self, index):
        try:
            img = cv2.cvtColor(cv2.imread(self.imgs[index]), cv2.COLOR_BGR2RGB)
        except Exception:
            # cv2.imread returns None for unreadable files; log the offending
            # path before re-raising so the bad sample can be located.
            print(self.imgs[index])
            raise
        return img, index

    def __len__(self):
        return len(self.imgs)

    def get_all_labels(self):
        return self.labels
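
# Each list file line is expected to look like "relative/path.jpg label":
# the first field is joined with its root directory to form the image path,
# while the full stripped line is kept as the label key so identical lines
# in the jk/zj lists can be matched up again in generate_test_pair.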


class DistributedSampler():
    '''DistributedSampler'''
    def __init__(self, dataset):
        self.dataset = dataset
        self.num_replicas = 1
        self.rank = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        indices = indices[self.rank::self.num_replicas]
        return iter(indices)

    def __len__(self):
        return self.num_samples
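
# NOTE: with num_replicas fixed to 1 and rank 0, this sampler simply yields
# every index in order; the slicing above sketches the sharding it would do
# if several replicas were configured (hypothetical values):
#   list(range(8))[1::4]  ->  [1, 5]   # rank=1 of num_replicas=4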


def get_dataloader(img_predix_all, img_list_all, batch_size, img_transforms):
    dataset = TxtDataset(img_predix_all, img_list_all)
    sampler = DistributedSampler(dataset)
    dataset_column_names = ["image", "index"]
    ds = de.GeneratorDataset(dataset, column_names=dataset_column_names, sampler=sampler)
    ds = ds.map(input_columns=["image"], operations=img_transforms)
    ds = ds.batch(batch_size, num_parallel_workers=8, drop_remainder=False)
    ds = ds.repeat(1)
    return ds, len(dataset), dataset.get_all_labels()
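
# The pipeline yields (image, index) batches; the caller uses the index
# column to scatter each embedding back to its row in the result array.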


def generate_test_pair(jk_list, zj_list):
    '''generate_test_pair'''
    file_paths = [jk_list, zj_list]
    jk_dict = {}
    zj_dict = {}
    jk_zj_dict_list = [jk_dict, zj_dict]
    for path, x_dict in zip(file_paths, jk_zj_dict_list):
        with open(path, 'r') as fr:
            for line in fr:
                label = line.strip().split(' ')[1]
                tmp = x_dict.get(label, [])
                tmp.append(line.strip())
                x_dict[label] = tmp
    zj2jk_pairs = []
    for key in jk_dict:
        jk_file_list = jk_dict[key]
        zj_file_list = zj_dict[key]
        for zj_file in zj_file_list:
            zj2jk_pairs.append([zj_file, jk_file_list])
    return zj2jk_pairs
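
# A minimal sketch of the pairing, assuming the two-field "path label" line
# format (hypothetical file contents):
#   jk_list:  "jk/a.jpg 7"  "jk/b.jpg 7"      zj_list:  "zj/c.jpg 7"
#   -> zj2jk_pairs == [['zj/c.jpg 7', ['jk/a.jpg 7', 'jk/b.jpg 7']]]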


def check_minmax(args, data, min_value=0.99, max_value=1.01):
    min_data = data.min()
    max_data = data.max()
    if np.isnan(min_data) or np.isnan(max_data):
        args.logger.info('ERROR, nan encountered, please check whether fp16 or some other error caused it')
        raise ValueError('embedding norm is nan')
    if min_data < min_value or max_data > max_value:
        args.logger.info('ERROR, min or max is out of range, range=[{}, {}], minmax=[{}, {}]'.format(
            min_value, max_value, min_data, max_data))
        raise ValueError('embedding norm is out of range')
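
# check_minmax guards that every embedding norm is close to 1 after
# l2normalize: a NaN or any norm outside [0.99, 1.01] raises ValueError,
# which the callers in run_eval catch to abort the evaluation early.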


def get_model(args):
    '''get_model'''
    net = get_backbone(args)
    if args.fp16:
        net.add_flags_recursive(fp16=True)
    if args.weight.endswith('.ckpt'):
        param_dict = load_checkpoint(args.weight)
        param_dict_new = {}
        for key, value in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = value
            else:
                param_dict_new[key] = value
        load_param_into_net(net, param_dict_new)
        args.logger.info('INFO, ------------- load model success --------------')
    else:
        args.logger.info('ERROR, unsupported weight file: {}, please check weight in config.py'.format(args.weight))
        return 0
    net.set_train(False)
    return net
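
# Checkpoint keys are remapped on load: optimizer state ('moments.*') is
# dropped and any 'network.' prefix (as left by a training-time cell
# wrapper) is stripped, so the bare backbone parameters load for inference.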


def topk(matrix, k, axis=1):
    '''topk'''
    if axis == 0:
        row_index = np.arange(matrix.shape[1 - axis])
        topk_index = np.argpartition(-matrix, k, axis=axis)[0:k, :]
        topk_data = matrix[topk_index, row_index]
        topk_index_sort = np.argsort(-topk_data, axis=axis)
        topk_data_sort = topk_data[topk_index_sort, row_index]
        topk_index_sort = topk_index[0:k, :][topk_index_sort, row_index]
    else:
        column_index = np.arange(matrix.shape[1 - axis])[:, None]
        topk_index = np.argpartition(-matrix, k, axis=axis)[:, 0:k]
        topk_data = matrix[column_index, topk_index]
        topk_index_sort = np.argsort(-topk_data, axis=axis)
        topk_data_sort = topk_data[column_index, topk_index_sort]
        topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
    return topk_data_sort, topk_index_sort
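
# A hand-checked example of the topk semantics along axis=1:
#   >>> m = np.array([[1., 3., 2.],
#   ...               [9., 5., 7.]])
#   >>> topk(m, 2)[0]     # top-2 values per row, sorted descending
#   array([[3., 2.],
#          [9., 7.]])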


def cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot):
    '''cal_topk'''
    args.logger.info('start idx:{} subprocess...'.format(idx))
    correct = np.array([0] * 2)
    tot = np.array([0])

    zj, jk_all = zj2jk_pairs[idx]
    zj_embedding = test_embedding_tot[zj]
    jk_all_embedding = np.concatenate([np.expand_dims(test_embedding_tot[jk], axis=0) for jk in jk_all], axis=0)
    args.logger.info('INFO, calculate top1 acc index:{}, zj_embedding shape:{}'.format(idx, zj_embedding.shape))
    args.logger.info('INFO, calculate top1 acc index:{}, jk_all_embedding shape:{}'.format(idx, jk_all_embedding.shape))

    test_time = time.time()
    mm = np.matmul(np.expand_dims(zj_embedding, axis=0), dis_embedding_tot)
    top100_jk2zj = np.squeeze(topk(mm, 100)[0], axis=0)
    top100_zj2jk = topk(np.matmul(jk_all_embedding, dis_embedding_tot), 100)[0]
    test_time_used = time.time() - test_time
    args.logger.info('INFO, calculate top1 acc index:{}, np.matmul().top(100) time used:{:.2f}s'.format(
        idx, test_time_used))

    tot[0] = len(jk_all)
    for i, jk in enumerate(jk_all):
        jk_embedding = test_embedding_tot[jk]
        similarity = np.dot(jk_embedding, zj_embedding)
        if similarity > top100_jk2zj[0]:
            correct[0] += 1
        if similarity > top100_zj2jk[i, 0]:
            correct[1] += 1
    return correct, tot
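
# For one zj (reference) image and its matching jk images, a hit is counted
# when the genuine similarity beats the best distractor similarity (element
# [0] of the top-100): correct[0] compares against the zj-side distractor
# maximum, correct[1] against the per-jk maximum; tot[0] is len(jk_all).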


def l2normalize(features):
    epsilon = 1e-12
    l2norm = np.sum(np.abs(features) ** 2, axis=1, keepdims=True) ** (1. / 2)
    l2norm[np.logical_and(l2norm < 0, l2norm > -epsilon)] = -epsilon
    l2norm[np.logical_and(l2norm >= 0, l2norm < epsilon)] = epsilon
    return features / l2norm
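
# l2normalize scales each row to unit L2 norm, clamping the norm away from
# zero by epsilon. A hand-checked example:
#   >>> l2normalize(np.array([[3., 4.]]))
#   array([[0.6, 0.8]])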


def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, "face_recognition_dataset")):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                # guard against division by zero for archives with < 100 files
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not a zip file.")
        else:
            print("Zip has been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, "face_recognition_dataset.zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"
        # Each server contains at most 8 devices.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))

    config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path)
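
# NOTE: /tmp/unzip_sync.lock above acts as a simple cross-process barrier:
# the first device on each server extracts the archive and then creates the
# lock file, while the remaining devices poll for it before continuing.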


@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval(args):
    '''run eval function.'''
    if not os.path.exists(args.test_dir):
        args.logger.info('ERROR, test_dir does not exist, please set test_dir in config.py.')
        return 0

    all_start_time = time.time()
    net = get_model(args)
    compile_time_used = time.time() - all_start_time
    args.logger.info('INFO, graph compile finished, time used:{:.2f}s, start calculate img embedding'.
                     format(compile_time_used))

    img_transforms = transforms.Compose([vision.ToTensor(), vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # for test images
    args.logger.info('INFO, start step1, calculate test img embedding, weight file = {}'.format(args.weight))
    step1_start_time = time.time()

    ds, img_tot, all_labels = get_dataloader(args.test_img_predix, args.test_img_list,
                                             args.test_batch_size, img_transforms)
    args.logger.info('INFO, dataset total test img:{}, total test batch:{}'.format(img_tot, ds.get_dataset_size()))
    test_embedding_tot_np = np.zeros((img_tot, args.emb_size))
    test_img_labels = all_labels
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
    for i, data in enumerate(data_loader):
        img, idxs = data["image"], data["index"]
        out = net(Tensor(img)).asnumpy().astype(np.float32)
        embeddings = l2normalize(out)
        for batch in range(embeddings.shape[0]):
            test_embedding_tot_np[idxs[batch]] = embeddings[batch]

    try:
        check_minmax(args, np.linalg.norm(test_embedding_tot_np, ord=2, axis=1))
    except ValueError:
        return 0

    test_embedding_tot = {}
    for idx, label in enumerate(test_img_labels):
        test_embedding_tot[label] = test_embedding_tot_np[idx]

    step2_start_time = time.time()
    step1_time_used = step2_start_time - step1_start_time
    args.logger.info('INFO, step1 finished, time used:{:.2f}s, start step2, calculate dis img embedding'.
                     format(step1_time_used))

    # for dis images
    ds_dis, img_tot, _ = get_dataloader(args.dis_img_predix, args.dis_img_list, args.dis_batch_size, img_transforms)
    dis_embedding_tot_np = np.zeros((img_tot, args.emb_size))
    total_batch = ds_dis.get_dataset_size()
    args.logger.info('INFO, dataloader total dis img:{}, total dis batch:{}'.format(img_tot, total_batch))
    start_time = time.time()
    img_per_gpu = int(math.ceil(1.0 * img_tot / args.world_size))
    delta_num = img_per_gpu * args.world_size - img_tot
    start_idx = img_per_gpu * args.local_rank - max(0, args.local_rank - (args.world_size - delta_num))
    data_loader = ds_dis.create_dict_iterator(output_numpy=True, num_epochs=1)
    for idx, data in enumerate(data_loader):
        img = data["image"]
        out = net(Tensor(img)).asnumpy().astype(np.float32)
        embeddings = l2normalize(out)
        dis_embedding_tot_np[start_idx:(start_idx + embeddings.shape[0])] = embeddings
        start_idx += embeddings.shape[0]
        if args.local_rank % 8 == 0 and idx % args.log_interval == 0 and idx > 0:
            speed = 1.0 * (args.dis_batch_size * args.log_interval * args.world_size) / (time.time() - start_time)
            time_left = (total_batch - idx - 1) * args.dis_batch_size * args.world_size / speed
            args.logger.info('INFO, processed [{}/{}], speed: {:.2f} img/s, left:{:.2f}s'.
                             format(idx, total_batch, speed, time_left))
            start_time = time.time()

    try:
        check_minmax(args, np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1))
    except ValueError:
        return 0

    step3_start_time = time.time()
    step2_time_used = step3_start_time - step2_start_time
    args.logger.info('INFO, step2 finished, time used:{:.2f}s, start step3, calculate top1 acc'.format(step2_time_used))

    # clear npu memory
    img = None
    net = None

    dis_embedding_tot_np = np.transpose(dis_embedding_tot_np, (1, 0))
    args.logger.info('INFO, calculate top1 acc dis_embedding_tot_np shape:{}'.format(dis_embedding_tot_np.shape))

    # find best match
    assert len(args.test_img_list) % 2 == 0
    task_num = int(len(args.test_img_list) / 2)
    correct = np.array([0] * (2 * task_num))
    tot = np.array([0] * task_num)
    for i in range(int(len(args.test_img_list) / 2)):
        jk_list = args.test_img_list[2 * i]
        zj_list = args.test_img_list[2 * i + 1]
        zj2jk_pairs = sorted(generate_test_pair(jk_list, zj_list))
        sampler = DistributedSampler(zj2jk_pairs)
        args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler)))
        for idx in sampler:
            out1, out2 = cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np)
            correct[2 * i] += out1[0]
            correct[2 * i + 1] += out1[1]
            tot[i] += out2[0]

    args.logger.info('local_rank={},tot={},correct={}'.format(args.local_rank, tot, correct))
    step3_time_used = time.time() - step3_start_time
    args.logger.info('INFO, step3 finished, time used:{:.2f}s'.format(step3_time_used))
    args.logger.info('weight:{}'.format(args.weight))

    for i in range(int(len(args.test_img_list) / 2)):
        test_set_name = 'test_dataset'
        zj2jk_acc = correct[2 * i] / tot[i]
        jk2zj_acc = correct[2 * i + 1] / tot[i]
        avg_acc = (zj2jk_acc + jk2zj_acc) / 2
        results = '[{}]: zj2jk={:.4f}, jk2zj={:.4f}, avg={:.4f}'.format(test_set_name, zj2jk_acc, jk2zj_acc, avg_acc)
        args.logger.info(results)
    args.logger.info('INFO, tot time used: {:.2f}s'.format(time.time() - all_start_time))
    return 0


if __name__ == '__main__':
    config.test_img_predix = [os.path.join(config.test_dir, 'test_dataset/'),
                              os.path.join(config.test_dir, 'test_dataset/')]
    config.test_img_list = [os.path.join(config.test_dir, 'lists/jk_list.txt'),
                            os.path.join(config.test_dir, 'lists/zj_list.txt')]
    config.dis_img_predix = [os.path.join(config.test_dir, 'dis_dataset/')]
    config.dis_img_list = [os.path.join(config.test_dir, 'lists/dis_list.txt')]

    log_path = os.path.join(config.ckpt_path, 'logs')
    config.logger = get_logger(log_path, config.local_rank)
    config.logger.info('Config %s', pformat(config))

    run_eval(config)