| @@ -109,7 +109,7 @@ def generate_test_pair(jk_list, zj_list): | |||||
| zj2jk_pairs.append([zj_file, jk_file_list]) | zj2jk_pairs.append([zj_file, jk_file_list]) | ||||
| return zj2jk_pairs | return zj2jk_pairs | ||||
| def check_minmax(data, min_value=0.99, max_value=1.01): | |||||
| def check_minmax(args, data, min_value=0.99, max_value=1.01): | |||||
| min_data = data.min() | min_data = data.min() | ||||
| max_data = data.max() | max_data = data.max() | ||||
| if np.isnan(min_data) or np.isnan(max_data): | if np.isnan(min_data) or np.isnan(max_data): | ||||
| @@ -162,7 +162,7 @@ def topk(matrix, k, axis=1): | |||||
| topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort] | topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort] | ||||
| return topk_data_sort, topk_index_sort | return topk_data_sort, topk_index_sort | ||||
| def cal_topk(idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot): | |||||
| def cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot): | |||||
| '''cal_topk''' | '''cal_topk''' | ||||
| args.logger.info('start idx:{} subprocess...'.format(idx)) | args.logger.info('start idx:{} subprocess...'.format(idx)) | ||||
| correct = np.array([0] * 2) | correct = np.array([0] * 2) | ||||
| @@ -230,7 +230,7 @@ def main(args): | |||||
| for batch in range(embeddings.shape[0]): | for batch in range(embeddings.shape[0]): | ||||
| test_embedding_tot_np[idxs[batch]] = embeddings[batch] | test_embedding_tot_np[idxs[batch]] = embeddings[batch] | ||||
| try: | try: | ||||
| check_minmax(np.linalg.norm(test_embedding_tot_np, ord=2, axis=1)) | |||||
| check_minmax(args, np.linalg.norm(test_embedding_tot_np, ord=2, axis=1)) | |||||
| except ValueError: | except ValueError: | ||||
| return 0 | return 0 | ||||
| @@ -266,7 +266,7 @@ def main(args): | |||||
| format(idx, total_batch, speed, time_left)) | format(idx, total_batch, speed, time_left)) | ||||
| start_time = time.time() | start_time = time.time() | ||||
| try: | try: | ||||
| check_minmax(np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1)) | |||||
| check_minmax(args, np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1)) | |||||
| except ValueError: | except ValueError: | ||||
| return 0 | return 0 | ||||
| @@ -295,7 +295,7 @@ def main(args): | |||||
| sampler = DistributedSampler(zj2jk_pairs) | sampler = DistributedSampler(zj2jk_pairs) | ||||
| args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler))) | args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler))) | ||||
| for idx in sampler: | for idx in sampler: | ||||
| out1, out2 = cal_topk(idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np) | |||||
| out1, out2 = cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np) | |||||
| correct[2 * i] += out1[0] | correct[2 * i] += out1[0] | ||||
| correct[2 * i + 1] += out1[1] | correct[2 * i + 1] += out1[1] | ||||
| tot[i] += out2[0] | tot[i] += out2[0] | ||||