You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Face Recognition eval."""
  16. import os
  17. import time
  18. import math
  19. from pprint import pformat
  20. import numpy as np
  21. import cv2
  22. import mindspore.dataset.transforms.py_transforms as transforms
  23. import mindspore.dataset.vision.py_transforms as vision
  24. import mindspore.dataset as de
  25. from mindspore import Tensor, context
  26. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  27. from src.config import config_inference
  28. from src.backbone.resnet import get_backbone
  29. from src.my_logging import get_logger
  30. devid = int(os.getenv('DEVICE_ID'))
  31. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
  32. class TxtDataset():
  33. '''TxtDataset'''
  34. def __init__(self, root_all, filenames):
  35. super(TxtDataset, self).__init__()
  36. self.imgs = []
  37. self.labels = []
  38. for root, filename in zip(root_all, filenames):
  39. fin = open(filename, "r")
  40. for line in fin:
  41. self.imgs.append(os.path.join(root, line.strip().split(" ")[0]))
  42. self.labels.append(line.strip())
  43. fin.close()
  44. def __getitem__(self, index):
  45. try:
  46. img = cv2.cvtColor(cv2.imread(self.imgs[index]), cv2.COLOR_BGR2RGB)
  47. except:
  48. print(self.imgs[index])
  49. raise
  50. return img, index
  51. def __len__(self):
  52. return len(self.imgs)
  53. def get_all_labels(self):
  54. return self.labels
  55. class DistributedSampler():
  56. '''DistributedSampler'''
  57. def __init__(self, dataset):
  58. self.dataset = dataset
  59. self.num_replicas = 1
  60. self.rank = 0
  61. self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
  62. def __iter__(self):
  63. indices = list(range(len(self.dataset)))
  64. indices = indices[self.rank::self.num_replicas]
  65. return iter(indices)
  66. def __len__(self):
  67. return self.num_samples
  68. def get_dataloader(img_predix_all, img_list_all, batch_size, img_transforms):
  69. dataset = TxtDataset(img_predix_all, img_list_all)
  70. sampler = DistributedSampler(dataset)
  71. dataset_column_names = ["image", "index"]
  72. ds = de.GeneratorDataset(dataset, column_names=dataset_column_names, sampler=sampler)
  73. ds = ds.map(input_columns=["image"], operations=img_transforms)
  74. ds = ds.batch(batch_size, num_parallel_workers=8, drop_remainder=False)
  75. ds = ds.repeat(1)
  76. return ds, len(dataset), dataset.get_all_labels()
  77. def generate_test_pair(jk_list, zj_list):
  78. '''generate_test_pair'''
  79. file_paths = [jk_list, zj_list]
  80. jk_dict = {}
  81. zj_dict = {}
  82. jk_zj_dict_list = [jk_dict, zj_dict]
  83. for path, x_dict in zip(file_paths, jk_zj_dict_list):
  84. with open(path, 'r') as fr:
  85. for line in fr:
  86. label = line.strip().split(' ')[1]
  87. tmp = x_dict.get(label, [])
  88. tmp.append(line.strip())
  89. x_dict[label] = tmp
  90. zj2jk_pairs = []
  91. for key in jk_dict:
  92. jk_file_list = jk_dict[key]
  93. zj_file_list = zj_dict[key]
  94. for zj_file in zj_file_list:
  95. zj2jk_pairs.append([zj_file, jk_file_list])
  96. return zj2jk_pairs
  97. def check_minmax(args, data, min_value=0.99, max_value=1.01):
  98. min_data = data.min()
  99. max_data = data.max()
  100. if np.isnan(min_data) or np.isnan(max_data):
  101. args.logger.info('ERROR, nan happened, please check if used fp16 or other error')
  102. raise Exception
  103. if min_data < min_value or max_data > max_value:
  104. args.logger.info('ERROR, min or max is out if range, range=[{}, {}], minmax=[{}, {}]'.format(
  105. min_value, max_value, min_data, max_data))
  106. raise Exception
  107. def get_model(args):
  108. '''get_model'''
  109. net = get_backbone(args)
  110. if args.fp16:
  111. net.add_flags_recursive(fp16=True)
  112. if args.weight.endswith('.ckpt'):
  113. param_dict = load_checkpoint(args.weight)
  114. param_dict_new = {}
  115. for key, value in param_dict.items():
  116. if key.startswith('moments.'):
  117. continue
  118. elif key.startswith('network.'):
  119. param_dict_new[key[8:]] = value
  120. else:
  121. param_dict_new[key] = value
  122. load_param_into_net(net, param_dict_new)
  123. args.logger.info('INFO, ------------- load model success--------------')
  124. else:
  125. args.logger.info('ERROR, not supprot file:{}, please check weight in config.py'.format(args.weight))
  126. return 0
  127. net.set_train(False)
  128. return net
  129. def topk(matrix, k, axis=1):
  130. '''topk'''
  131. if axis == 0:
  132. row_index = np.arange(matrix.shape[1 - axis])
  133. topk_index = np.argpartition(-matrix, k, axis=axis)[0:k, :]
  134. topk_data = matrix[topk_index, row_index]
  135. topk_index_sort = np.argsort(-topk_data, axis=axis)
  136. topk_data_sort = topk_data[topk_index_sort, row_index]
  137. topk_index_sort = topk_index[0:k, :][topk_index_sort, row_index]
  138. else:
  139. column_index = np.arange(matrix.shape[1 - axis])[:, None]
  140. topk_index = np.argpartition(-matrix, k, axis=axis)[:, 0:k]
  141. topk_data = matrix[column_index, topk_index]
  142. topk_index_sort = np.argsort(-topk_data, axis=axis)
  143. topk_data_sort = topk_data[column_index, topk_index_sort]
  144. topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
  145. return topk_data_sort, topk_index_sort
  146. def cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot):
  147. '''cal_topk'''
  148. args.logger.info('start idx:{} subprocess...'.format(idx))
  149. correct = np.array([0] * 2)
  150. tot = np.array([0])
  151. zj, jk_all = zj2jk_pairs[idx]
  152. zj_embedding = test_embedding_tot[zj]
  153. jk_all_embedding = np.concatenate([np.expand_dims(test_embedding_tot[jk], axis=0) for jk in jk_all], axis=0)
  154. args.logger.info('INFO, calculate top1 acc index:{}, zj_embedding shape:{}'.format(idx, zj_embedding.shape))
  155. args.logger.info('INFO, calculate top1 acc index:{}, jk_all_embedding shape:{}'.format(idx, jk_all_embedding.shape))
  156. test_time = time.time()
  157. mm = np.matmul(np.expand_dims(zj_embedding, axis=0), dis_embedding_tot)
  158. top100_jk2zj = np.squeeze(topk(mm, 100)[0], axis=0)
  159. top100_zj2jk = topk(np.matmul(jk_all_embedding, dis_embedding_tot), 100)[0]
  160. test_time_used = time.time() - test_time
  161. args.logger.info('INFO, calculate top1 acc index:{}, np.matmul().top(100) time used:{:.2f}s'.format(
  162. idx, test_time_used))
  163. tot[0] = len(jk_all)
  164. for i, jk in enumerate(jk_all):
  165. jk_embedding = test_embedding_tot[jk]
  166. similarity = np.dot(jk_embedding, zj_embedding)
  167. if similarity > top100_jk2zj[0]:
  168. correct[0] += 1
  169. if similarity > top100_zj2jk[i, 0]:
  170. correct[1] += 1
  171. return correct, tot
  172. def l2normalize(features):
  173. epsilon = 1e-12
  174. l2norm = np.sum(np.abs(features) ** 2, axis=1, keepdims=True) ** (1./2)
  175. l2norm[np.logical_and(l2norm < 0, l2norm > -epsilon)] = -epsilon
  176. l2norm[np.logical_and(l2norm >= 0, l2norm < epsilon)] = epsilon
  177. return features/l2norm
  178. def main(args):
  179. if not os.path.exists(args.test_dir):
  180. args.logger.info('ERROR, test_dir is not exists, please set test_dir in config.py.')
  181. return 0
  182. all_start_time = time.time()
  183. net = get_model(args)
  184. compile_time_used = time.time() - all_start_time
  185. args.logger.info('INFO, graph compile finished, time used:{:.2f}s, start calculate img embedding'.
  186. format(compile_time_used))
  187. img_transforms = transforms.Compose([vision.ToTensor(), vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  188. #for test images
  189. args.logger.info('INFO, start step1, calculate test img embedding, weight file = {}'.format(args.weight))
  190. step1_start_time = time.time()
  191. ds, img_tot, all_labels = get_dataloader(args.test_img_predix, args.test_img_list,
  192. args.test_batch_size, img_transforms)
  193. args.logger.info('INFO, dataset total test img:{}, total test batch:{}'.format(img_tot, ds.get_dataset_size()))
  194. test_embedding_tot_np = np.zeros((img_tot, args.emb_size))
  195. test_img_labels = all_labels
  196. data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
  197. for i, data in enumerate(data_loader):
  198. img, idxs = data["image"], data["index"]
  199. out = net(Tensor(img)).asnumpy().astype(np.float32)
  200. embeddings = l2normalize(out)
  201. for batch in range(embeddings.shape[0]):
  202. test_embedding_tot_np[idxs[batch]] = embeddings[batch]
  203. try:
  204. check_minmax(args, np.linalg.norm(test_embedding_tot_np, ord=2, axis=1))
  205. except ValueError:
  206. return 0
  207. test_embedding_tot = {}
  208. for idx, label in enumerate(test_img_labels):
  209. test_embedding_tot[label] = test_embedding_tot_np[idx]
  210. step2_start_time = time.time()
  211. step1_time_used = step2_start_time - step1_start_time
  212. args.logger.info('INFO, step1 finished, time used:{:.2f}s, start step2, calculate dis img embedding'.
  213. format(step1_time_used))
  214. # for dis images
  215. ds_dis, img_tot, _ = get_dataloader(args.dis_img_predix, args.dis_img_list, args.dis_batch_size, img_transforms)
  216. dis_embedding_tot_np = np.zeros((img_tot, args.emb_size))
  217. total_batch = ds_dis.get_dataset_size()
  218. args.logger.info('INFO, dataloader total dis img:{}, total dis batch:{}'.format(img_tot, total_batch))
  219. start_time = time.time()
  220. img_per_gpu = int(math.ceil(1.0 * img_tot / args.world_size))
  221. delta_num = img_per_gpu * args.world_size - img_tot
  222. start_idx = img_per_gpu * args.local_rank - max(0, args.local_rank - (args.world_size - delta_num))
  223. data_loader = ds_dis.create_dict_iterator(output_numpy=True, num_epochs=1)
  224. for idx, data in enumerate(data_loader):
  225. img = data["image"]
  226. out = net(Tensor(img)).asnumpy().astype(np.float32)
  227. embeddings = l2normalize(out)
  228. dis_embedding_tot_np[start_idx:(start_idx + embeddings.shape[0])] = embeddings
  229. start_idx += embeddings.shape[0]
  230. if args.local_rank % 8 == 0 and idx % args.log_interval == 0 and idx > 0:
  231. speed = 1.0 * (args.dis_batch_size * args.log_interval * args.world_size) / (time.time() - start_time)
  232. time_left = (total_batch - idx - 1) * args.dis_batch_size *args.world_size / speed
  233. args.logger.info('INFO, processed [{}/{}], speed: {:.2f} img/s, left:{:.2f}s'.
  234. format(idx, total_batch, speed, time_left))
  235. start_time = time.time()
  236. try:
  237. check_minmax(args, np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1))
  238. except ValueError:
  239. return 0
  240. step3_start_time = time.time()
  241. step2_time_used = step3_start_time - step2_start_time
  242. args.logger.info('INFO, step2 finished, time used:{:.2f}s, start step3, calculate top1 acc'.format(step2_time_used))
  243. # clear npu memory
  244. img = None
  245. net = None
  246. dis_embedding_tot_np = np.transpose(dis_embedding_tot_np, (1, 0))
  247. args.logger.info('INFO, calculate top1 acc dis_embedding_tot_np shape:{}'.format(dis_embedding_tot_np.shape))
  248. # find best match
  249. assert len(args.test_img_list) % 2 == 0
  250. task_num = int(len(args.test_img_list) / 2)
  251. correct = np.array([0] * (2 * task_num))
  252. tot = np.array([0] * task_num)
  253. for i in range(int(len(args.test_img_list) / 2)):
  254. jk_list = args.test_img_list[2 * i]
  255. zj_list = args.test_img_list[2 * i + 1]
  256. zj2jk_pairs = sorted(generate_test_pair(jk_list, zj_list))
  257. sampler = DistributedSampler(zj2jk_pairs)
  258. args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler)))
  259. for idx in sampler:
  260. out1, out2 = cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np)
  261. correct[2 * i] += out1[0]
  262. correct[2 * i + 1] += out1[1]
  263. tot[i] += out2[0]
  264. args.logger.info('local_rank={},tot={},correct={}'.format(args.local_rank, tot, correct))
  265. step3_time_used = time.time() - step3_start_time
  266. args.logger.info('INFO, step3 finished, time used:{:.2f}s'.format(step3_time_used))
  267. args.logger.info('weight:{}'.format(args.weight))
  268. for i in range(int(len(args.test_img_list) / 2)):
  269. test_set_name = 'test_dataset'
  270. zj2jk_acc = correct[2 * i] / tot[i]
  271. jk2zj_acc = correct[2 * i + 1] / tot[i]
  272. avg_acc = (zj2jk_acc + jk2zj_acc) / 2
  273. results = '[{}]: zj2jk={:.4f}, jk2zj={:.4f}, avg={:.4f}'.format(test_set_name, zj2jk_acc, jk2zj_acc, avg_acc)
  274. args.logger.info(results)
  275. args.logger.info('INFO, tot time used: {:.2f}s'.format(time.time() - all_start_time))
  276. return 0
  277. if __name__ == '__main__':
  278. arg = config_inference
  279. arg.test_img_predix = [arg.test_dir, arg.test_dir]
  280. arg.test_img_list = [os.path.join(arg.test_dir, 'lists/jk_list.txt'),
  281. os.path.join(arg.test_dir, 'lists/zj_list.txt')]
  282. arg.dis_img_predix = [arg.test_dir,]
  283. arg.dis_img_list = [os.path.join(arg.test_dir, 'lists/dis_list.txt'),]
  284. log_path = os.path.join(arg.ckpt_path, 'logs')
  285. arg.logger = get_logger(log_path, arg.local_rank)
  286. arg.logger.info('Config\n\n%s\n' % pformat(arg))
  287. main(arg)