From 676c219bc4a2dba085c7dba086c134f589ce6562 Mon Sep 17 00:00:00 2001
From: huenrui
Date: Mon, 22 Mar 2021 20:52:19 +0800
Subject: [PATCH] add tprr 8p version

---
 model_zoo/research/nlp/tprr/retriever_eval.py | 60 ++++++++++++-------
 model_zoo/research/nlp/tprr/src/config.py     |  3 +-
 .../research/nlp/tprr/src/process_data.py     |  3 +
 model_zoo/research/nlp/tprr/src/utils.py      |  9 +--
 4 files changed, 47 insertions(+), 28 deletions(-)

diff --git a/model_zoo/research/nlp/tprr/retriever_eval.py b/model_zoo/research/nlp/tprr/retriever_eval.py
index b0b7a12504..fd9dd58bdd 100644
--- a/model_zoo/research/nlp/tprr/retriever_eval.py
+++ b/model_zoo/research/nlp/tprr/retriever_eval.py
@@ -19,6 +19,7 @@ Retriever Evaluation.
 
 import time
 import json
+from multiprocessing import Pool
 
 import numpy as np
 from mindspore import Tensor
@@ -69,16 +70,20 @@ def eval_output(out_2, last_out, path_raw, gold_path, val, true_count):
     return val, true_count, topk_titles
 
 
-def evaluation():
+def evaluation(d_id):
     """evaluation"""
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target='Ascend',
+                        device_id=d_id,
+                        save_graphs=False)
     print('********************** loading corpus ********************** ')
     s_lc = time.time()
     data_generator = DataGen(config)
-    queries = read_query(config)
+    queries = read_query(config, d_id)
     print("loading corpus time (h):", (time.time() - s_lc) / 3600)
 
     print('********************** loading model ********************** ')
-    s_lm = time.time() 
+    s_lm = time.time()
     model_onehop_bert = ModelOneHop()
     param_dict = load_checkpoint(config.onehop_bert_path)
     load_param_into_net(model_onehop_bert, param_dict)
@@ -90,10 +95,10 @@
     print("loading model time (h):", (time.time() - s_lm) / 3600)
 
     print('********************** evaluation ********************** ')
-    s_tr = time.time()
 
     f_dev = open(config.dev_path, 'rb')
     dev_data = json.load(f_dev)
+    f_dev.close()
     q_gold = {}
     q_2id = {}
     for onedata in dev_data:
@@ -101,10 +106,10 @@
         q_gold[onedata["question"]] = [get_new_title(get_raw_title(item)) for item in onedata['path']]
         q_2id[onedata["question"]] = onedata['_id']
     val, true_count, count, step = 0, 0, 0, 0
-    batch_queries = split_queries(config, queries)[:-1]
+    batch_queries = split_queries(config, queries)
     output_path = []
     for _, batch in enumerate(batch_queries):
-        print("###step###: ", step)
+        print("###step###: ", str(step) + "_" + str(d_id))
         query = batch[0]
         temp_dict = {}
         temp_dict['q_id'] = q_2id[query]
@@ -158,23 +163,36 @@
         val, true_count, topk_titles = eval_output(out_2, last_out, path_raw,
                                                    gold_path, val, true_count)
         temp_dict['topk_titles'] = topk_titles
         output_path.append(temp_dict)
-        count += 1
-        print("val:", val)
-        print("count:", count)
-        print("true count:", true_count)
-        if count:
-            print("PEM:", val / count)
-        if true_count:
-            print("true top8 PEM:", val / true_count)
         step += 1
-    save_json(output_path, config.save_path, config.save_name)
-    print("evaluation time (h):", (time.time() - s_tr) / 3600)
+        count += 1
+    return {'val': val, 'count': count, 'true_count': true_count, 'path': output_path}
 
 if __name__ == "__main__":
+    t_s = time.time()
     config = ThinkRetrieverConfig()
-    context.set_context(mode=context.GRAPH_MODE,
-                        device_target='Ascend',
-                        device_id=config.device_id,
-                        save_graphs=False)
-    evaluation()
+    pool = Pool(processes=config.device_num)
+    results = []
+    for device_id in range(config.device_num):
+        results.append(pool.apply_async(evaluation, (device_id,)))
+
+    print("Waiting for all subprocess done...")
+
+    pool.close()
+    pool.join()
+
+    val_all, true_count_all, count_all = 0, 0, 0
+    output_path_all = []
+    for res in results:
+        output = res.get()
+        val_all += output['val']
+        count_all += output['count']
+        true_count_all += output['true_count']
+        output_path_all += output['path']
+    print("val:", val_all)
+    print("count:", count_all)
+    print("true count:", true_count_all)
+    print("PEM:", val_all / count_all)
+    print("true top8 PEM:", val_all / true_count_all)
+    save_json(output_path_all, config.save_path, config.save_name)
+    print("evaluation time (h):", (time.time() - t_s) / 3600)
diff --git a/model_zoo/research/nlp/tprr/src/config.py b/model_zoo/research/nlp/tprr/src/config.py
index edbe493968..2b8519e7f7 100644
--- a/model_zoo/research/nlp/tprr/src/config.py
+++ b/model_zoo/research/nlp/tprr/src/config.py
@@ -31,7 +31,7 @@ def ThinkRetrieverConfig():
     parser.add_argument("--topk", type=int, default=8, help="top num")
     parser.add_argument("--onehop_num", type=int, default=8, help="onehop num")
     parser.add_argument("--batch_size", type=int, default=1, help="batch size")
-    parser.add_argument("--device_id", type=int, default=0, help="device id")
+    parser.add_argument("--device_num", type=int, default=8, help="device num")
     parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
     parser.add_argument("--save_path", type=str, default='../', help='path of output')
     parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path")
@@ -43,4 +43,5 @@
     parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path")
     parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path")
     parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path")
+    parser.add_argument("--q_path", type=str, default="../queries", help="queries data path")
     return parser.parse_args()
diff --git a/model_zoo/research/nlp/tprr/src/process_data.py b/model_zoo/research/nlp/tprr/src/process_data.py
index e8028d1036..65a46a08da 100644
--- a/model_zoo/research/nlp/tprr/src/process_data.py
+++ b/model_zoo/research/nlp/tprr/src/process_data.py
@@ -53,6 +53,9 @@ class DataGen:
             data_db = pkl.load(f_wiki, encoding="gbk")
             dev_data = json.load(f_train)
             q_doc_text = pkl.load(f_doc, encoding='gbk')
+        f_wiki.close()
+        f_train.close()
+        f_doc.close()
         return data_db, dev_data, q_doc_text
 
     def process_data(self):
diff --git a/model_zoo/research/nlp/tprr/src/utils.py b/model_zoo/research/nlp/tprr/src/utils.py
index cb3f94746e..4838866f2e 100644
--- a/model_zoo/research/nlp/tprr/src/utils.py
+++ b/model_zoo/research/nlp/tprr/src/utils.py
@@ -28,13 +28,10 @@ def normalize(text):
     return text[0].capitalize() + text[1:]
 
 
-def read_query(config):
+def read_query(config, device_id):
     """get query data"""
-    with open(config.dev_data_path, 'rb') as f:
-        temp_dic = pkl.load(f, encoding='gbk')
-        queries = []
-        for item in temp_dic:
-            queries.append(temp_dic[item]["query"])
+    with open(config.q_path + str(device_id), 'rb') as f:
+        queries = pkl.load(f, encoding='gbk')
    return queries
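
Note (not part of the commit above): the reworked read_query() loads one pre-split query shard per device from config.q_path + str(device_id) (e.g. ../queries0 ... ../queries7 with the defaults), but the patch itself does not show how those shard files are produced. Below is a minimal sketch of one way to generate them from the dev-data pickle that the removed read_query() consumed; the helper name, the round-robin split, and the placeholder paths are assumptions for illustration, not code from this patch.

# Illustrative only -- not part of the patch. Assumes the dev-data pickle has the layout
# the old read_query() relied on: a gbk-encoded pickle of a dict whose values carry a
# "query" field.
import pickle as pkl


def split_queries_by_device(dev_data_path, q_path, device_num=8):
    """Write one query shard per device: q_path + '0', ..., q_path + str(device_num - 1)."""
    with open(dev_data_path, 'rb') as f:
        temp_dic = pkl.load(f, encoding='gbk')
    queries = [temp_dic[item]["query"] for item in temp_dic]
    # Round-robin split keeps the per-device workloads roughly the same size.
    for device_id in range(device_num):
        with open(q_path + str(device_id), 'wb') as f:
            pkl.dump(queries[device_id::device_num], f)


if __name__ == "__main__":
    # Placeholder paths; point these at the real dev-data pickle and the --q_path prefix.
    split_queries_by_device("path/to/dev_data.pkl", "../queries", device_num=8)

Splitting the queries ahead of time keeps each of the eight evaluation() subprocesses independent: every device only opens its own shard, so no inter-process communication is needed beyond the final result aggregation in the __main__ block.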