| @@ -38,6 +38,9 @@ Wikipedia data: the 2017 English Wikipedia dump version with bidirectional hyper | |||||
| dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs. | dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs. | ||||
dev tf-idf data: the top-500 candidate paragraphs for each question in the dev data, retrieved from the 5M Wikipedia paragraphs via TF-IDF.
The dataset of the re-ranker consists of two parts:
Wikipedia data: the 2017 English Wikipedia dump.
dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs.
| # [Features](#contents) | # [Features](#contents) | ||||
| @@ -64,6 +67,7 @@ After installing MindSpore via the official website and Dataset is correctly gen | |||||
| ```python | ```python | ||||
| # run evaluation example with HotPotQA dev dataset | # run evaluation example with HotPotQA dev dataset | ||||
| sh run_eval_ascend.sh | sh run_eval_ascend.sh | ||||
| sh run_eval_ascend_reranker_reader.sh | |||||
| ``` | ``` | ||||
| # [Script Description](#contents) | # [Script Description](#contents) | ||||
| @@ -75,25 +79,39 @@ After installing MindSpore via the official website and Dataset is correctly gen | |||||
| └─tprr | └─tprr | ||||
| ├─README.md | ├─README.md | ||||
| ├─scripts | ├─scripts | ||||
| | ├─run_eval_ascend.sh # Launch evaluation in ascend | |||||
| | ├─run_eval_ascend.sh # Launch retriever evaluation in ascend | |||||
| └─run_eval_ascend_reranker_reader.sh # Launch re-ranker and reader evaluation in ascend
| | | | | ||||
| ├─src | ├─src | ||||
| | ├─config.py # Evaluation configurations | |||||
| | ├─onehop.py # Onehop model | |||||
| | ├─onehop_bert.py # Onehop bert model | |||||
| | ├─process_data.py # Data preprocessing | |||||
| | ├─twohop.py # Twohop model | |||||
| | ├─twohop_bert.py # Twohop bert model | |||||
| | └─utils.py # Utils for evaluation | |||||
| | ├─build_reranker_data.py # build data for re-ranker from result of retriever | |||||
| | ├─config.py # Evaluation configurations for retriever | |||||
| | ├─hotpot_evaluate_v1.py # Hotpotqa evaluation script | |||||
| | ├─onehop.py # Onehop model of retriever | |||||
| | ├─onehop_bert.py # Onehop bert model of retriever | |||||
| | ├─process_data.py # Data preprocessing for retriever | |||||
| | ├─reader.py # Reader model | |||||
| | ├─reader_albert_xxlarge.py # Albert-xxlarge module of reader model | |||||
| | ├─reader_downstream.py # Downstream module of reader model | |||||
| | ├─reader_eval.py # Reader evaluation script | |||||
| | ├─rerank_albert_xxlarge.py # Albert-xxlarge module of re-ranker model | |||||
| | ├─rerank_and_reader_data_generator.py # Data generator for re-ranker and reader | |||||
| | ├─rerank_and_reader_utils.py # Utils for re-ranker and reader | |||||
| | ├─rerank_downstream.py # Downstream module of re-ranker model | |||||
| | ├─reranker.py # Re-ranker model | |||||
| | ├─reranker_eval.py # Re-ranker evaluation script | |||||
| | ├─twohop.py # Twohop model of retriever | |||||
| | ├─twohop_bert.py # Twohop bert model of retriever | |||||
| | └─utils.py # Utils for retriever | |||||
| | | | | ||||
| └─retriever_eval.py # Evaluation net for retriever | |||||
| ├─retriever_eval.py # Evaluation net for retriever | |||||
| └─reranker_and_reader_eval.py # Evaluation net for re-ranker and reader | |||||
| ``` | ``` | ||||
| ## [Script Parameters](#contents) | ## [Script Parameters](#contents) | ||||
| Parameters for evaluation can be set in config.py. | |||||
| Parameters for retriever evaluation can be set in config.py. | |||||
| - config for TPRR retriever dataset | |||||
| - config for TPRR retriever | |||||
| ```python | ```python | ||||
| "q_len": 64, # Max query length | "q_len": 64, # Max query length | ||||
| @@ -108,17 +126,30 @@ Parameters for evaluation can be set in config.py. | |||||
| config.py for more configuration. | config.py for more configuration. | ||||
| Parameters for re-ranker and reader evaluation can be passed directly at execution time. | |||||
| - parameters for TPRR re-ranker and reader | |||||
| ```python | |||||
| "seq_len": 512, # sequence length | |||||
| "rerank_batch_size": 32, # batch size for re-ranker evaluation | |||||
| "reader_batch_size": 448, # batch size for reader evaluation | |||||
| "sp_threshold": 8 # threshold for picking supporting sentence | |||||
| ``` | |||||
See the argument parser in src/rerank_and_reader_utils.py for the full list of configurable options.
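For example, assuming the argument names defined in src/rerank_and_reader_utils.py mirror the parameter names listed above, they could be overridden directly when launching the evaluation (hypothetical invocation):

```python
# hypothetical example; flag names are assumed to mirror the parameters above
python reranker_and_reader_eval.py --seq_len 512 --rerank_batch_size 32 --reader_batch_size 448 --sp_threshold 8
```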
| ## [Evaluation Process](#contents) | ## [Evaluation Process](#contents) | ||||
| ### Evaluation | ### Evaluation | ||||
| - Evaluation on Ascend | |||||
| - Retriever evaluation on Ascend | |||||
| ```python | ```python | ||||
| sh run_eval_ascend.sh | sh run_eval_ascend.sh | ||||
| ``` | ``` | ||||
| Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the | |||||
Evaluation result will be stored in the scripts path, in a folder whose name begins with "eval_tr". You can find results like the
following in the log.
| ```python | ```python | ||||
| @@ -138,6 +169,35 @@ Parameters for evaluation can be set in config.py. | |||||
| evaluation time (h): 20.155506462653477 | evaluation time (h): 20.155506462653477 | ||||
| ``` | ``` | ||||
| - Re-ranker and reader evaluation on Ascend | |||||
Use the output of the retriever as the input of the re-ranker:
| ```python | |||||
| sh run_eval_ascend_reranker_reader.sh | |||||
| ``` | |||||
Evaluation result will be stored in the scripts path, in a folder whose name begins with "eval". You can find results like the
following in the log.
| ```python | |||||
| total top1 pem: 0.8803511141120864 | |||||
| ... | |||||
| em: 0.67440918298447 | |||||
| f1: 0.8025625656569652 | |||||
| prec: 0.8292800393689271 | |||||
| recall: 0.8136908451841731 | |||||
| sp_em: 0.6009453072248481 | |||||
| sp_f1: 0.844555664157302 | |||||
| sp_prec: 0.8640844345841021 | |||||
| sp_recall: 0.8446123918845106 | |||||
| joint_em: 0.4537474679270763 | |||||
| joint_f1: 0.715119580346802 | |||||
| joint_prec: 0.7540052057184267 | |||||
| joint_recall: 0.7250240424067661 | |||||
| ``` | |||||
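The joint metrics above combine the answer and supporting-fact scores per example before averaging (see src/hotpot_evaluate_v1.py): joint precision and recall are the products of the corresponding answer and sp values, and joint F1 is their harmonic mean. A minimal sketch with illustrative numbers:

```python
# how the joint metrics are combined per example in src/hotpot_evaluate_v1.py
prec, recall, em = 0.83, 0.81, 1.0            # answer precision / recall / exact match
sp_prec, sp_recall, sp_em = 0.86, 0.84, 1.0   # supporting-fact precision / recall / exact match
joint_prec = prec * sp_prec
joint_recall = recall * sp_recall
joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) if joint_prec + joint_recall > 0 else 0.0
joint_em = em * sp_em
print(joint_prec, joint_recall, joint_f1, joint_em)
```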
| # [Model Description](#contents) | # [Model Description](#contents) | ||||
| ## [Performance](#contents) | ## [Performance](#contents) | ||||
| @@ -154,6 +214,8 @@ Parameters for evaluation can be set in config.py. | |||||
| | Batch_size | 1 | | | Batch_size | 1 | | ||||
| | Output | inference path | | | Output | inference path | | ||||
| | PEM | 0.9188 | | | PEM | 0.9188 | | ||||
| | total top1 pem | 0.88 | | |||||
| | joint_f1 | 0.7151 | | |||||
| # [Description of random situation](#contents) | # [Description of random situation](#contents) | ||||
| @@ -0,0 +1,55 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """main file""" | |||||
| from mindspore import context | |||||
| from src.rerank_and_reader_utils import get_parse, cal_reranker_metrics, select_reader_dev_data | |||||
| from src.reranker_eval import rerank | |||||
| from src.reader_eval import read | |||||
| from src.hotpot_evaluate_v1 import hotpotqa_eval | |||||
| from src.build_reranker_data import get_rerank_data | |||||
def rerank_and_reader_eval():
| """main function""" | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| parser = get_parse() | |||||
| args = parser.parse_args() | |||||
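# The pipeline stages are gated by individual flags and are meant to run in order:
# 1) build re-ranker features from the retriever output, 2) run the re-ranker,
# 3) report top-1 PEM, 4) select the re-ranked data the reader will consume,
# 5) run the reader, 6) compute the HotpotQA answer / supporting-fact metrics.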
| if args.get_reranker_data: | |||||
| get_rerank_data(args) | |||||
| if args.run_reranker: | |||||
| rerank(args) | |||||
| if args.cal_reranker_metrics: | |||||
| total_top1_pem, _, _ = \ | |||||
| cal_reranker_metrics(dev_gold_file=args.dev_gold_file, rerank_result_file=args.rerank_result_file) | |||||
| print(f"total top1 pem: {total_top1_pem}") | |||||
| if args.select_reader_data: | |||||
| select_reader_dev_data(args) | |||||
| if args.run_reader: | |||||
| read(args) | |||||
| if args.cal_reader_metrics: | |||||
| metrics = hotpotqa_eval(args.reader_result_file, args.dev_gold_file) | |||||
| for k in metrics: | |||||
| print(f"{k}: {metrics[k]}") | |||||
| if __name__ == "__main__": | |||||
rerank_and_reader_eval()
| @@ -21,16 +21,16 @@ export DEVICE_NUM=1 | |||||
| export RANK_SIZE=$DEVICE_NUM | export RANK_SIZE=$DEVICE_NUM | ||||
| export RANK_ID=0 | export RANK_ID=0 | ||||
| if [ -d "eval" ]; | |||||
| if [ -d "eval_tr" ]; | |||||
| then | then | ||||
| rm -rf ./eval | |||||
| rm -rf ./eval_tr | |||||
| fi | fi | ||||
| mkdir ./eval | |||||
| mkdir ./eval_tr | |||||
| cp ../*.py ./eval | |||||
| cp *.sh ./eval | |||||
| cp -r ../src ./eval | |||||
| cd ./eval || exit | |||||
| cp ../*.py ./eval_tr | |||||
| cp *.sh ./eval_tr | |||||
| cp -r ../src ./eval_tr | |||||
| cd ./eval_tr || exit | |||||
| env > env.log | env > env.log | ||||
| echo "start evaluation" | echo "start evaluation" | ||||
| @@ -0,0 +1,39 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| # eval script | |||||
| ulimit -u unlimited | |||||
| export DEVICE_NUM=1 | |||||
| export RANK_SIZE=$DEVICE_NUM | |||||
| export RANK_ID=0 | |||||
| if [ -d "eval" ]; | |||||
| then | |||||
| rm -rf ./eval | |||||
| fi | |||||
| mkdir ./eval | |||||
| cp ../*.py ./eval | |||||
| cp *.sh ./eval | |||||
| cp -r ../src ./eval | |||||
| cd ./eval || exit | |||||
| env > env.log | |||||
| echo "start evaluation" | |||||
| python reranker_and_reader_eval.py --get_reranker_data --run_reranker --cal_reranker_metrics --select_reader_data --run_reader --cal_reader_metrics > log_reranker_and_reader.txt 2>&1 & | |||||
| cd .. | |||||
| @@ -0,0 +1,430 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """build reranker data from retriever result""" | |||||
| import pickle | |||||
| import gzip | |||||
| from tqdm import tqdm | |||||
| from src.rerank_and_reader_utils import read_json, make_wiki_id, convert_text_to_tokens, normalize_title, \ | |||||
| whitespace_tokenize, DocDB, _largest_valid_index, generate_mapping, InputFeatures, Example | |||||
| from transformers import AutoTokenizer | |||||
| def judge_para(data): | |||||
| """judge whether is valid para""" | |||||
| for _, para_tokens in data["context"].items(): | |||||
| if len(para_tokens) == 1: | |||||
| return False | |||||
| return True | |||||
| def judge_sp(data, sent_name2id, para2id): | |||||
| """judge whether is valid sp""" | |||||
| for sp in data['sp']: | |||||
| title = normalize_title(sp[0]) | |||||
| name = normalize_title(sp[0]) + '_{}'.format(sp[1]) | |||||
| if title in para2id and name not in sent_name2id: | |||||
| return False | |||||
| return True | |||||
| def judge(path, path_set, reverse=False, golds=None, mode='or'): | |||||
| """judge function""" | |||||
| if path[0] == path[-1]: | |||||
| return False | |||||
| if path in path_set: | |||||
| return False | |||||
| if reverse and path[::-1] in path_set: | |||||
| return False | |||||
| if not golds: | |||||
| return True | |||||
| if mode == 'or': | |||||
| return any(gold not in path for gold in golds) | |||||
| if mode == 'and': | |||||
| return all(gold not in path for gold in golds) | |||||
| return False | |||||
| def get_context_and_sents(path, doc_db): | |||||
| """get context ans sentences""" | |||||
| context = {} | |||||
| sents = {} | |||||
| for title in path: | |||||
| para_info = doc_db.get_doc_info(title) | |||||
| if title.endswith('_0'): | |||||
| title = title[:-2] | |||||
| context[title] = pickle.loads(para_info[1]) | |||||
| sents[title] = pickle.loads(para_info[2]) | |||||
| return context, sents | |||||
| def gen_dev_data(dev_file, db_path, topk_file): | |||||
| """generate dev data""" | |||||
| # ----------------------------------------db info----------------------------------------------- | |||||
| topk_data = read_json(topk_file) # path | |||||
| doc_db = DocDB(db_path) # db get offset | |||||
| print('load db successfully!') | |||||
| # ---------------------------------------------supervision ------------------------------------------ | |||||
| dev_data = read_json(dev_file) | |||||
| qid2sp = {} | |||||
| qid2ans = {} | |||||
| qid2type = {} | |||||
| qid2path = {} | |||||
| for _, data in enumerate(dev_data): | |||||
| sp_facts = data['supporting_facts'] if 'supporting_facts' in data else None | |||||
| qid2sp[data['_id']] = sp_facts | |||||
| qid2ans[data['_id']] = data['answer'] if 'answer' in data else None | |||||
| qid2type[data['_id']] = data['type'] if 'type' in data else None | |||||
| qid2path[data['_id']] = list(set(list(zip(*sp_facts))[0])) if sp_facts else None | |||||
| new_dev_data = [] | |||||
| for _, data in enumerate(tqdm(topk_data)): | |||||
| qid = data['q_id'] | |||||
| question = data['question'] | |||||
| topk_titles = data['topk_titles'] | |||||
| gold_path = list(map(normalize_title, qid2path[qid])) if qid2path[qid] else None | |||||
| all_titles = [] | |||||
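# build candidate two-hop paths from the retriever's top-k title chains: keep the
# first two titles of each chain and, for three-title chains, also the last two,
# skipping degenerate or duplicate paths via judge()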
| for titles in topk_titles: | |||||
| titles = list(map(normalize_title, titles)) | |||||
| if len(titles) == 1: | |||||
| continue | |||||
| path = titles[:2] | |||||
| if judge(path, all_titles): | |||||
| all_titles.append(titles[:2]) | |||||
| if len(titles) == 3: | |||||
| path = titles[1:] | |||||
| if judge(path, all_titles): | |||||
| all_titles.append(titles[1:]) | |||||
| # --------------------------------------------------process query----------------------------------- | |||||
| question = " ".join(whitespace_tokenize(question)) | |||||
| question = question.strip() | |||||
| q_tokens, _ = convert_text_to_tokens(question) | |||||
| gold_path = list(map(lambda x: make_wiki_id(x, 0), gold_path)) if gold_path else None | |||||
| for path in all_titles: | |||||
| context, sents = get_context_and_sents(path, doc_db) | |||||
| ans_label = int(gold_path[0] in path and gold_path[1] in path) if gold_path else None | |||||
| new_dev_data.append({ | |||||
| 'qid': qid, | |||||
| 'type': qid2type[qid], | |||||
| 'question': question, | |||||
| 'q_tokens': q_tokens, | |||||
| 'context': context, | |||||
| 'sents': sents, | |||||
| 'answer': qid2ans[qid], | |||||
| 'sp': qid2sp[qid], | |||||
| 'ans_para': None, | |||||
| 'is_impossible': not ans_label == 1 | |||||
| }) | |||||
| return new_dev_data | |||||
| def read_hotpot_examples(path_data): | |||||
| """reader examples""" | |||||
| examples = [] | |||||
| max_sent_cnt = 0 | |||||
| failed = 0 | |||||
| for _, data in enumerate(path_data): | |||||
| if not judge_para(data): | |||||
| failed += 1 | |||||
| continue | |||||
| question = data['question'] | |||||
| question = " ".join(whitespace_tokenize(question)) | |||||
| question = question.strip() | |||||
| path = list(map(normalize_title, data["context"].keys())) | |||||
| qid = data['qid'] | |||||
| q_tokens = data['q_tokens'] | |||||
| # -------------------------------------add para------------------------------------------------------------ | |||||
| doc_tokens = [] | |||||
| para_start_end_position = [] | |||||
| title_start_end_position = [] | |||||
| sent_start_end_position = [] | |||||
| sent_names = [] | |||||
| sent_name2id = {} | |||||
| para2id = {} | |||||
| for para, para_tokens in data["context"].items(): | |||||
| sents = data["sents"][para] | |||||
| para = normalize_title(para) | |||||
| title_tokens = convert_text_to_tokens(para)[0] | |||||
| para_node_id = len(para_start_end_position) | |||||
| para2id[para] = para_node_id | |||||
| doc_offset = len(doc_tokens) | |||||
| doc_tokens += title_tokens | |||||
| doc_tokens += para_tokens | |||||
| title_start_end_position.append((doc_offset, doc_offset + len(title_tokens) - 1)) | |||||
| doc_offset += len(title_tokens) | |||||
| para_start_end_position.append((doc_offset, doc_offset + len(para_tokens) - 1, para)) | |||||
| for idx, sent in enumerate(sents): | |||||
| if sent[0] == -1 and sent[1] == -1: | |||||
| continue | |||||
| sent_names.append([para, idx]) # local name | |||||
| sent_node_id = len(sent_start_end_position) | |||||
| sent_name2id[normalize_title(para) + '_{}'.format(str(idx))] = sent_node_id | |||||
| sent_start_end_position.append((doc_offset + sent[0], | |||||
| doc_offset + sent[1])) | |||||
| # add sp and ans | |||||
| sp_facts = [] | |||||
| sup_fact_id = [] | |||||
| for sp in sp_facts: | |||||
| name = normalize_title(sp[0]) + '_{}'.format(sp[1]) | |||||
| if name in sent_name2id: | |||||
| sup_fact_id.append(sent_name2id[name]) | |||||
| sup_para_id = set() # use set | |||||
| if sp_facts: | |||||
| for para in list(zip(*sp_facts))[0]: | |||||
| para = normalize_title(para) | |||||
| if para in para2id: | |||||
| sup_para_id.add(para2id[para]) | |||||
| sup_para_id = list(sup_para_id) | |||||
| example = Example( | |||||
| qas_id=qid, | |||||
| path=path, | |||||
| unique_id=qid + '_' + '_'.join(path), | |||||
| question_tokens=q_tokens, | |||||
| doc_tokens=doc_tokens, # multi-para tokens w/o query | |||||
| sent_names=sent_names, | |||||
| sup_fact_id=sup_fact_id, # global sent id | |||||
| sup_para_id=sup_para_id, # global para id | |||||
| para_start_end_position=para_start_end_position, | |||||
| sent_start_end_position=sent_start_end_position, | |||||
| title_start_end_position=title_start_end_position, | |||||
| question_text=question) | |||||
| examples.append(example) | |||||
| max_sent_cnt = max(max_sent_cnt, len(sent_start_end_position)) | |||||
| print(f"Maximum sentence cnt: {max_sent_cnt}") | |||||
| print(f'failed examples: {failed}') | |||||
| print(f'convert {len(examples)} examples successfully!') | |||||
| return examples | |||||
| def add_sub_token(sub_tokens, idx, tok_to_orig_index, all_query_tokens): | |||||
| """add sub tokens""" | |||||
| for sub_token in sub_tokens: | |||||
| tok_to_orig_index.append(idx) | |||||
| all_query_tokens.append(sub_token) | |||||
| return tok_to_orig_index, all_query_tokens | |||||
| def get_sent_spans(example, orig_to_tok_index, orig_to_tok_back_index): | |||||
| """get sentences' spans""" | |||||
| sentence_spans = [] | |||||
| for sent_span in example.sent_start_end_position: | |||||
| sent_start_position = orig_to_tok_index[sent_span[0]] | |||||
| sent_end_position = orig_to_tok_back_index[sent_span[1]] | |||||
| sentence_spans.append((sent_start_position, sent_end_position + 1)) | |||||
| return sentence_spans | |||||
| def get_para_spans(example, orig_to_tok_index, orig_to_tok_back_index, all_doc_tokens, marker): | |||||
| """get paragraphs' spans""" | |||||
| para_spans = [] | |||||
| for title_span, para_span in zip(example.title_start_end_position, example.para_start_end_position): | |||||
| para_start_position = orig_to_tok_index[title_span[0]] | |||||
| para_end_position = orig_to_tok_back_index[para_span[1]] | |||||
| if para_end_position + 1 < len(all_doc_tokens) and all_doc_tokens[para_end_position + 1] == \ | |||||
| marker['sent'][0]: | |||||
| para_spans.append((para_start_position - 1, para_end_position + 1, para_span[2])) | |||||
| else: | |||||
| para_spans.append((para_start_position - 1, para_end_position, para_span[2])) | |||||
| return para_spans | |||||
| def build_feature(example, all_doc_tokens, doc_input_ids, doc_input_mask, doc_segment_ids, all_query_tokens, | |||||
| query_input_ids, query_input_mask, query_segment_ids, para_spans, sentence_spans, tok_to_orig_index): | |||||
| """build a input feature""" | |||||
| feature = InputFeatures( | |||||
| qas_id=example.qas_id, | |||||
| path=example.path, | |||||
| unique_id=example.qas_id + '_' + '_'.join(example.path), | |||||
| sent_names=example.sent_names, | |||||
| doc_tokens=all_doc_tokens, | |||||
| doc_input_ids=doc_input_ids, | |||||
| doc_input_mask=doc_input_mask, | |||||
| doc_segment_ids=doc_segment_ids, | |||||
| query_tokens=all_query_tokens, | |||||
| query_input_ids=query_input_ids, | |||||
| query_input_mask=query_input_mask, | |||||
| query_segment_ids=query_segment_ids, | |||||
| para_spans=para_spans, | |||||
| sent_spans=sentence_spans, | |||||
| token_to_orig_map=tok_to_orig_index) | |||||
| return feature | |||||
| def convert_example_to_features(tokenizer, args, examples): | |||||
| """convert examples to features""" | |||||
| features = [] | |||||
| failed = 0 | |||||
| marker = {'q': ['[q]', '[/q]'], 'para': ['<t>', '</t>'], 'sent': ['[s]']} | |||||
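# marker tokens: [q] ... [/q] wrap the query, <t> ... </t> wrap each paragraph
# title, and [s] is appended after every sentence end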
| for (_, example) in enumerate(tqdm(examples)): | |||||
| all_query_tokens = [tokenizer.cls_token, marker['q'][0]] | |||||
| tok_to_orig_index = [-1, -1] # orig: query + doc tokens | |||||
| ques_orig_to_tok_index = [] # start position | |||||
| ques_orig_to_tok_back_index = [] # end position | |||||
| q_spans = [] | |||||
| # -------------------------------------------for query--------------------------------------------- | |||||
| for (idx, token) in enumerate(example.question_tokens): | |||||
| sub_tokens = tokenizer.tokenize(token) | |||||
| ques_orig_to_tok_index.append(len(all_query_tokens)) | |||||
| tok_to_orig_index, all_query_tokens = add_sub_token(sub_tokens, idx, tok_to_orig_index, all_query_tokens) | |||||
| ques_orig_to_tok_back_index.append(len(all_query_tokens) - 1) | |||||
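# cap the query at 63 tokens (including [CLS] and [q]) so that appending the
# closing [/q] marker keeps it within the fixed 64-token query length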
| all_query_tokens = all_query_tokens[:63] | |||||
| tok_to_orig_index = tok_to_orig_index[:63] | |||||
| all_query_tokens.append(marker['q'][-1]) | |||||
| tok_to_orig_index.append(-1) | |||||
| q_spans.append((1, len(all_query_tokens) - 1)) | |||||
| # ---------------------------------------add doc tokens------------------------------------------------ | |||||
| all_doc_tokens = [] | |||||
| orig_to_tok_index = [] # orig: token in doc | |||||
| orig_to_tok_back_index = [] | |||||
| title_start_mapping, title_end_mapping = generate_mapping(len(example.doc_tokens), | |||||
| example.title_start_end_position) | |||||
| _, sent_end_mapping = generate_mapping(len(example.doc_tokens), | |||||
| example.sent_start_end_position) | |||||
| all_doc_tokens += all_query_tokens | |||||
| for (idx, token) in enumerate(example.doc_tokens): | |||||
| sub_tokens = tokenizer.tokenize(token) | |||||
| if title_start_mapping[idx] == 1: | |||||
| all_doc_tokens.append(marker['para'][0]) | |||||
| tok_to_orig_index.append(-1) | |||||
| # orig: position in doc tokens tok: global tokenized tokens (start) | |||||
| orig_to_tok_index.append(len(all_doc_tokens)) | |||||
| tok_to_orig_index, all_doc_tokens = add_sub_token(sub_tokens, idx + len(example.question_tokens), | |||||
| tok_to_orig_index, all_doc_tokens) | |||||
| orig_to_tok_back_index.append(len(all_doc_tokens) - 1) | |||||
| if title_end_mapping[idx] == 1: | |||||
| all_doc_tokens.append(marker['para'][1]) | |||||
| tok_to_orig_index.append(-1) | |||||
| if sent_end_mapping[idx] == 1: | |||||
| all_doc_tokens.append(marker['sent'][0]) | |||||
| tok_to_orig_index.append(-1) | |||||
| # -----------------------------------for sentence------------------------------------------------- | |||||
| sentence_spans = get_sent_spans(example, orig_to_tok_index, orig_to_tok_back_index) | |||||
| # -----------------------------------for para------------------------------------------------------- | |||||
| para_spans = get_para_spans(example, orig_to_tok_index, orig_to_tok_back_index, all_doc_tokens, marker) | |||||
| # -----------------------------------remove sent > max seq length----------------------------------------- | |||||
| sent_max_index = _largest_valid_index(sentence_spans, args.seq_len) | |||||
| max_sent_cnt = len(sentence_spans) | |||||
| if sent_max_index != len(sentence_spans): | |||||
| if sent_max_index == 0: | |||||
failed += 1
| continue | |||||
| sentence_spans = sentence_spans[:sent_max_index] | |||||
| max_tok_length = sentence_spans[-1][1] # max_tok_length [s] | |||||
| # max end index: max_tok_length | |||||
| para_max_index = _largest_valid_index(para_spans, max_tok_length + 1) | |||||
| if para_max_index == 0: # only one para | |||||
failed += 1
| continue | |||||
| if orig_to_tok_back_index[example.title_start_end_position[1][1]] + 1 >= max_tok_length: | |||||
failed += 1
| continue | |||||
| max_para_span = para_spans[para_max_index] | |||||
| para_spans = para_spans[:para_max_index] | |||||
| para_spans.append((max_para_span[0], max_tok_length, max_para_span[2])) | |||||
| all_doc_tokens = all_doc_tokens[:max_tok_length + 1] | |||||
| sentence_spans = sentence_spans[:min(max_sent_cnt, args.max_sent_num)] | |||||
| # ----------------------------------------Padding Document----------------------------------------------------- | |||||
| if len(all_doc_tokens) > args.seq_len: | |||||
| st, _, title = para_spans[-1] | |||||
| para_spans[-1] = (st, args.seq_len - 1, title) | |||||
| all_doc_tokens = all_doc_tokens[:args.seq_len - 1] + [marker['sent'][0]] | |||||
| doc_input_ids = tokenizer.convert_tokens_to_ids(all_doc_tokens) | |||||
| query_input_ids = tokenizer.convert_tokens_to_ids(all_query_tokens) | |||||
| doc_input_mask = [1] * len(doc_input_ids) | |||||
| doc_segment_ids = [0] * len(query_input_ids) + [1] * (len(doc_input_ids) - len(query_input_ids)) | |||||
| doc_pad_length = args.seq_len - len(doc_input_ids) | |||||
| doc_input_ids += [0] * doc_pad_length | |||||
| doc_input_mask += [0] * doc_pad_length | |||||
| doc_segment_ids += [0] * doc_pad_length | |||||
| # Padding Question | |||||
| query_input_mask = [1] * len(query_input_ids) | |||||
| query_segment_ids = [0] * len(query_input_ids) | |||||
| query_pad_length = 64 - len(query_input_ids) | |||||
| query_input_ids += [0] * query_pad_length | |||||
| query_input_mask += [0] * query_pad_length | |||||
| query_segment_ids += [0] * query_pad_length | |||||
| feature = build_feature(example, all_doc_tokens, doc_input_ids, doc_input_mask, doc_segment_ids, | |||||
| all_query_tokens, query_input_ids, query_input_mask, query_segment_ids, para_spans, | |||||
| sentence_spans, tok_to_orig_index) | |||||
| features.append(feature) | |||||
| return features | |||||
| def get_rerank_data(args): | |||||
| """function for generating reranker's data""" | |||||
| new_dev_data = gen_dev_data(dev_file=args.dev_gold_file, | |||||
| db_path=args.wiki_db_file, | |||||
| topk_file=args.retriever_result_file) | |||||
| tokenizer = AutoTokenizer.from_pretrained(args.albert_model_path) | |||||
| new_tokens = ['[q]', '[/q]', '<t>', '</t>', '[s]'] | |||||
| tokenizer.add_tokens(new_tokens) | |||||
| examples = read_hotpot_examples(new_dev_data) | |||||
| features = convert_example_to_features(tokenizer=tokenizer, args=args, examples=examples) | |||||
| with gzip.open(args.rerank_example_file, "wb") as f: | |||||
| pickle.dump(examples, f) | |||||
| with gzip.open(args.rerank_feature_file, "wb") as f: | |||||
| pickle.dump(features, f) | |||||
| @@ -33,14 +33,14 @@ def ThinkRetrieverConfig(): | |||||
| parser.add_argument("--batch_size", type=int, default=1, help="batch size") | parser.add_argument("--batch_size", type=int, default=1, help="batch size") | ||||
| parser.add_argument("--device_id", type=int, default=0, help="device id") | parser.add_argument("--device_id", type=int, default=0, help="device id") | ||||
| parser.add_argument("--save_name", type=str, default='doc_path', help='name of output') | parser.add_argument("--save_name", type=str, default='doc_path', help='name of output') | ||||
| parser.add_argument("--save_path", type=str, default='./', help='path of output') | |||||
| parser.add_argument("--vocab_path", type=str, default='./scripts/vocab.txt', help="vocab path") | |||||
| parser.add_argument("--wiki_path", type=str, default='./scripts/db_docs_bidirection_new.pkl', help="wiki path") | |||||
| parser.add_argument("--dev_path", type=str, default='./scripts/hotpot_dev_fullwiki_v1_for_retriever.json', | |||||
| parser.add_argument("--save_path", type=str, default='../', help='path of output') | |||||
| parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path") | |||||
| parser.add_argument("--wiki_path", type=str, default='../db_docs_bidirection_new.pkl', help="wiki path") | |||||
| parser.add_argument("--dev_path", type=str, default='../hotpot_dev_fullwiki_v1_for_retriever.json', | |||||
| help="dev path") | help="dev path") | ||||
| parser.add_argument("--dev_data_path", type=str, default='./scripts/dev_tf_idf_data_raw.pkl', help="dev data path") | |||||
| parser.add_argument("--onehop_bert_path", type=str, default='./scripts/onehop.ckpt', help="onehop bert ckpt path") | |||||
| parser.add_argument("--onehop_mlp_path", type=str, default='./scripts/onehop_mlp.ckpt', help="onehop mlp ckpt path") | |||||
| parser.add_argument("--twohop_bert_path", type=str, default='./scripts/twohop.ckpt', help="twohop bert ckpt path") | |||||
| parser.add_argument("--twohop_mlp_path", type=str, default='./scripts/twohop_mlp.ckpt', help="twohop mlp ckpt path") | |||||
| parser.add_argument("--dev_data_path", type=str, default='../dev_tf_idf_data_raw.pkl', help="dev data path") | |||||
| parser.add_argument("--onehop_bert_path", type=str, default='../onehop.ckpt', help="onehop bert ckpt path") | |||||
| parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path") | |||||
| parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path") | |||||
| parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path") | |||||
| return parser.parse_args() | return parser.parse_args() | ||||
| @@ -0,0 +1,153 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hotpotqa evaluate script""" | |||||
| import re | |||||
| import string | |||||
| from collections import Counter | |||||
| import ujson as json | |||||
| def normalize_answer(s): | |||||
| """normalize answer""" | |||||
| def remove_articles(text): | |||||
| """remove articles""" | |||||
| return re.sub(r'\b(a|an|the)\b', ' ', text) | |||||
| def white_space_fix(text): | |||||
| """fix whitespace""" | |||||
| return ' '.join(text.split()) | |||||
| def remove_punc(text): | |||||
| """remove punctuation from text""" | |||||
| exclude = set(string.punctuation) | |||||
| return ''.join(ch for ch in text if ch not in exclude) | |||||
| def lower(text): | |||||
| """lower text""" | |||||
| return text.lower() | |||||
| return white_space_fix(remove_articles(remove_punc(lower(s)))) | |||||
| def f1_score(prediction, ground_truth): | |||||
| """calculate f1 score""" | |||||
| normalized_prediction = normalize_answer(prediction) | |||||
| normalized_ground_truth = normalize_answer(ground_truth) | |||||
| ZERO_METRIC = (0, 0, 0) | |||||
| if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: | |||||
| return ZERO_METRIC | |||||
| if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: | |||||
| return ZERO_METRIC | |||||
| prediction_tokens = normalized_prediction.split() | |||||
| ground_truth_tokens = normalized_ground_truth.split() | |||||
| common = Counter(prediction_tokens) & Counter(ground_truth_tokens) | |||||
| num_same = sum(common.values()) | |||||
| if num_same == 0: | |||||
| return ZERO_METRIC | |||||
| precision = 1.0 * num_same / len(prediction_tokens) | |||||
| recall = 1.0 * num_same / len(ground_truth_tokens) | |||||
| f1 = (2 * precision * recall) / (precision + recall) | |||||
| return f1, precision, recall | |||||
| def exact_match_score(prediction, ground_truth): | |||||
| """calculate exact match score""" | |||||
| return normalize_answer(prediction) == normalize_answer(ground_truth) | |||||
| def update_answer(metrics, prediction, gold): | |||||
| """update answer""" | |||||
| em = exact_match_score(prediction, gold) | |||||
| f1, prec, recall = f1_score(prediction, gold) | |||||
| metrics['em'] += float(em) | |||||
| metrics['f1'] += f1 | |||||
| metrics['prec'] += prec | |||||
| metrics['recall'] += recall | |||||
| return em, prec, recall | |||||
| def update_sp(metrics, prediction, gold): | |||||
| """update supporting sentences""" | |||||
| cur_sp_pred = set(map(tuple, prediction)) | |||||
| gold_sp_pred = set(map(tuple, gold)) | |||||
| tp, fp, fn = 0, 0, 0 | |||||
| for e in cur_sp_pred: | |||||
| if e in gold_sp_pred: | |||||
| tp += 1 | |||||
| else: | |||||
| fp += 1 | |||||
| for e in gold_sp_pred: | |||||
| if e not in cur_sp_pred: | |||||
| fn += 1 | |||||
| prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 | |||||
| recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 | |||||
| f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 | |||||
| em = 1.0 if fp + fn == 0 else 0.0 | |||||
| metrics['sp_em'] += em | |||||
| metrics['sp_f1'] += f1 | |||||
| metrics['sp_prec'] += prec | |||||
| metrics['sp_recall'] += recall | |||||
| return em, prec, recall | |||||
| def hotpotqa_eval(prediction_file, gold_file): | |||||
| """hotpotqa evaluate function""" | |||||
| with open(prediction_file) as f: | |||||
| prediction = json.load(f) | |||||
| with open(gold_file) as f: | |||||
| gold = json.load(f) | |||||
| metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, | |||||
| 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, | |||||
| 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0} | |||||
| for dp in gold: | |||||
| cur_id = dp['_id'] | |||||
| can_eval_joint = True | |||||
| if cur_id not in prediction['answer']: | |||||
| print('missing answer {}'.format(cur_id)) | |||||
| can_eval_joint = False | |||||
| else: | |||||
| em, prec, recall = update_answer( | |||||
| metrics, prediction['answer'][cur_id], dp['answer']) | |||||
| if cur_id not in prediction['sp']: | |||||
| print('missing sp fact {}'.format(cur_id)) | |||||
| can_eval_joint = False | |||||
| else: | |||||
| sp_em, sp_prec, sp_recall = update_sp( | |||||
| metrics, prediction['sp'][cur_id], dp['supporting_facts']) | |||||
| if can_eval_joint: | |||||
| joint_prec = prec * sp_prec | |||||
| joint_recall = recall * sp_recall | |||||
| if joint_prec + joint_recall > 0: | |||||
| joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) | |||||
| else: | |||||
| joint_f1 = 0. | |||||
| joint_em = em * sp_em | |||||
| metrics['joint_em'] += joint_em | |||||
| metrics['joint_f1'] += joint_f1 | |||||
| metrics['joint_prec'] += joint_prec | |||||
| metrics['joint_recall'] += joint_recall | |||||
| num = len(gold) | |||||
| for k in metrics: | |||||
| metrics[k] /= num | |||||
| return metrics | |||||
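# expected prediction file layout (sketch inferred from the lookups above):
# {
#     "answer": {"<question_id>": "predicted answer string", ...},
#     "sp": {"<question_id>": [["Paragraph title", sentence_index], ...], ...}
# }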
| @@ -0,0 +1,73 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Reader model""" | |||||
| import mindspore.nn as nn | |||||
| from mindspore import load_checkpoint, load_param_into_net | |||||
| from mindspore.ops import BatchMatMul | |||||
| from mindspore import ops | |||||
| from mindspore import dtype as mstype | |||||
| from src.reader_albert_xxlarge import Reader_Albert | |||||
| from src.reader_downstream import Reader_Downstream | |||||
| dst_type = mstype.float16 | |||||
| dst_type2 = mstype.float32 | |||||
| class Reader(nn.Cell): | |||||
| """Reader model""" | |||||
| def __init__(self, batch_size, encoder_ck_file, downstream_ck_file): | |||||
| """init function""" | |||||
| super(Reader, self).__init__(auto_prefix=False) | |||||
| self.encoder = Reader_Albert(batch_size) | |||||
| param_dict = load_checkpoint(encoder_ck_file) | |||||
| not_load_params = load_param_into_net(self.encoder, param_dict) | |||||
| print(f"not loaded: {not_load_params}") | |||||
| self.downstream = Reader_Downstream() | |||||
| param_dict = load_checkpoint(downstream_ck_file) | |||||
| not_load_params = load_param_into_net(self.downstream, param_dict) | |||||
| print(f"not loaded: {not_load_params}") | |||||
| self.bmm = BatchMatMul() | |||||
| def construct(self, input_ids, attn_mask, token_type_ids, | |||||
| context_mask, square_mask, packing_mask, cache_mask, | |||||
| para_start_mapping, sent_end_mapping): | |||||
| """construct function""" | |||||
| state = self.encoder(attn_mask, input_ids, token_type_ids) | |||||
| para_state = self.bmm(ops.Cast()(para_start_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, 2, D] | |||||
| sent_state = self.bmm(ops.Cast()(sent_end_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, max_sent, D] | |||||
| q_type, start, end, para_logit, sent_logit = self.downstream(ops.Cast()(para_state, dst_type2), | |||||
| ops.Cast()(sent_state, dst_type2), | |||||
| state, | |||||
| context_mask) | |||||
outer = start[:, :, None] + end[:, None]  # outer[b, i, j] = start_logit[i] + end_logit[j]
outer_mask = cache_mask
outer_mask = square_mask * outer_mask[None]
outer = outer - 1e30 * (1 - outer_mask)  # mask out invalid (start, end) span pairs
outer = outer - 1e30 * packing_mask[:, :, None]
max_row = ops.ReduceMax()(outer, 2)
y1 = ops.Argmax()(max_row)  # predicted answer start index
max_col = ops.ReduceMax()(outer, 1)
y2 = ops.Argmax()(max_col)  # predicted answer end index
| return start, end, q_type, para_logit, sent_logit, y1, y2 | |||||
| @@ -0,0 +1,263 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """albert-xxlarge Model for reader""" | |||||
| import numpy as np | |||||
| from mindspore import nn, ops | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import dtype as mstype | |||||
| dst_type = mstype.float16 | |||||
| dst_type2 = mstype.float32 | |||||
| class LayerNorm(nn.Cell): | |||||
| """LayerNorm layer""" | |||||
| def __init__(self, mul_7_w_shape, add_8_bias_shape): | |||||
| """init function""" | |||||
| super(LayerNorm, self).__init__() | |||||
| self.reducemean_0 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_1 = P.Sub() | |||||
| self.pow_2 = P.Pow() | |||||
| self.pow_2_input_weight = 2.0 | |||||
| self.reducemean_3 = P.ReduceMean(keep_dims=True) | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = 9.999999960041972e-13 | |||||
| self.sqrt_5 = P.Sqrt() | |||||
| self.div_6 = P.Div() | |||||
| self.mul_7 = P.Mul() | |||||
| self.mul_7_w = Parameter(Tensor(np.random.uniform(0, 1, mul_7_w_shape).astype(np.float32)), name=None) | |||||
| self.add_8 = P.Add() | |||||
| self.add_8_bias = Parameter(Tensor(np.random.uniform(0, 1, add_8_bias_shape).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_reducemean_0 = self.reducemean_0(x, -1) | |||||
| opt_sub_1 = self.sub_1(x, opt_reducemean_0) | |||||
| opt_pow_2 = self.pow_2(opt_sub_1, self.pow_2_input_weight) | |||||
| opt_reducemean_3 = self.reducemean_3(opt_pow_2, -1) | |||||
| opt_add_4 = self.add_4(opt_reducemean_3, self.add_4_bias) | |||||
| opt_sqrt_5 = self.sqrt_5(opt_add_4) | |||||
| opt_div_6 = self.div_6(opt_sub_1, opt_sqrt_5) | |||||
| opt_mul_7 = self.mul_7(opt_div_6, self.mul_7_w) | |||||
| opt_add_8 = self.add_8(opt_mul_7, self.add_8_bias) | |||||
| return opt_add_8 | |||||
| class Linear(nn.Cell): | |||||
| """Linear layer""" | |||||
| def __init__(self, matmul_0_weight_shape, add_1_bias_shape): | |||||
| """init function""" | |||||
| super(Linear, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)), | |||||
| name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias) | |||||
| return opt_add_1 | |||||
| class MultiHeadAttn(nn.Cell): | |||||
| """Multi-head attention layer""" | |||||
| def __init__(self, batch_size): | |||||
| """init function""" | |||||
| super(MultiHeadAttn, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.matmul_1 = nn.MatMul() | |||||
| self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.matmul_2 = nn.MatMul() | |||||
| self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.add_3 = P.Add() | |||||
| self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.reshape_6 = P.Reshape() | |||||
| self.reshape_6_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.reshape_7 = P.Reshape() | |||||
| self.reshape_7_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.reshape_8 = P.Reshape() | |||||
| self.reshape_8_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.transpose_9 = P.Transpose() | |||||
| self.transpose_10 = P.Transpose() | |||||
| self.transpose_11 = P.Transpose() | |||||
| self.matmul_12 = nn.MatMul() | |||||
| self.div_13 = P.Div() | |||||
| self.div_13_w = 8.0 | |||||
| self.add_14 = P.Add() | |||||
| self.softmax_15 = nn.Softmax(axis=3) | |||||
| self.matmul_16 = nn.MatMul() | |||||
| self.transpose_17 = P.Transpose() | |||||
| self.matmul_18 = P.MatMul() | |||||
| self.matmul_18_weight = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None) | |||||
| self.add_19 = P.Add() | |||||
| self.add_19_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_matmul_1 = self.matmul_1(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_1_w, dst_type)) | |||||
| opt_matmul_2 = self.matmul_2(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_2_w, dst_type)) | |||||
| opt_add_3 = self.add_3(ops.Cast()(opt_matmul_0, dst_type2), self.add_3_bias) | |||||
| opt_add_4 = self.add_4(ops.Cast()(opt_matmul_1, dst_type2), self.add_4_bias) | |||||
| opt_add_5 = self.add_5(ops.Cast()(opt_matmul_2, dst_type2), self.add_5_bias) | |||||
| opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape) | |||||
| opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape) | |||||
| opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape) | |||||
| opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3)) | |||||
| opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1)) | |||||
| opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3)) | |||||
| opt_matmul_12 = self.matmul_12(ops.Cast()(opt_transpose_9, dst_type), ops.Cast()(opt_transpose_10, dst_type)) | |||||
| opt_div_13 = self.div_13(ops.Cast()(opt_matmul_12, dst_type2), ops.Cast()(self.div_13_w, dst_type2)) | |||||
| opt_add_14 = self.add_14(opt_div_13, x0) | |||||
| opt_softmax_15 = self.softmax_15(opt_add_14) | |||||
| opt_matmul_16 = self.matmul_16(ops.Cast()(opt_softmax_15, dst_type), ops.Cast()(opt_transpose_11, dst_type)) | |||||
| opt_transpose_17 = self.transpose_17(ops.Cast()(opt_matmul_16, dst_type2), (0, 2, 1, 3)) | |||||
| opt_matmul_18 = self.matmul_18(ops.Cast()(opt_transpose_17, dst_type).view(self.batch_size * 512, -1), | |||||
| ops.Cast()(self.matmul_18_weight, dst_type).view(-1, 4096))\ | |||||
| .view(self.batch_size, 512, 4096) | |||||
| opt_add_19 = self.add_19(ops.Cast()(opt_matmul_18, dst_type2), self.add_19_bias) | |||||
| return opt_add_19 | |||||
| class NewGeLU(nn.Cell): | |||||
| """new gelu layer""" | |||||
| def __init__(self): | |||||
| """init function""" | |||||
| super(NewGeLU, self).__init__() | |||||
| self.mul_0 = P.Mul() | |||||
| self.mul_0_w = 0.5 | |||||
| self.pow_1 = P.Pow() | |||||
| self.pow_1_input_weight = 3.0 | |||||
| self.mul_2 = P.Mul() | |||||
| self.mul_2_w = 0.044714998453855515 | |||||
| self.add_3 = P.Add() | |||||
| self.mul_4 = P.Mul() | |||||
| self.mul_4_w = 0.7978845834732056 | |||||
| self.tanh_5 = nn.Tanh() | |||||
| self.add_6 = P.Add() | |||||
| self.add_6_bias = 1.0 | |||||
| self.mul_7 = P.Mul() | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_mul_0 = self.mul_0(x, self.mul_0_w) | |||||
| opt_pow_1 = self.pow_1(x, self.pow_1_input_weight) | |||||
| opt_mul_2 = self.mul_2(opt_pow_1, self.mul_2_w) | |||||
| opt_add_3 = self.add_3(x, opt_mul_2) | |||||
| opt_mul_4 = self.mul_4(opt_add_3, self.mul_4_w) | |||||
| opt_tanh_5 = self.tanh_5(opt_mul_4) | |||||
| opt_add_6 = self.add_6(opt_tanh_5, self.add_6_bias) | |||||
| opt_mul_7 = self.mul_7(opt_mul_0, opt_add_6) | |||||
| return opt_mul_7 | |||||
| class TransformerLayer(nn.Cell): | |||||
| """Transformer layer""" | |||||
| def __init__(self, batch_size, layernorm1_0_mul_7_w_shape, layernorm1_0_add_8_bias_shape, | |||||
| linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape, | |||||
| linear3_1_add_1_bias_shape): | |||||
| """init function""" | |||||
| super(TransformerLayer, self).__init__() | |||||
| self.multiheadattn_0 = MultiHeadAttn(batch_size) | |||||
| self.add_0 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm(mul_7_w_shape=layernorm1_0_mul_7_w_shape, | |||||
| add_8_bias_shape=layernorm1_0_add_8_bias_shape) | |||||
| self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape, | |||||
| add_1_bias_shape=linear3_0_add_1_bias_shape) | |||||
| self.newgelu2_0 = NewGeLU() | |||||
| self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape, | |||||
| add_1_bias_shape=linear3_1_add_1_bias_shape) | |||||
| self.add_1 = P.Add() | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| multiheadattn_0_opt = self.multiheadattn_0(x, x0) | |||||
| opt_add_0 = self.add_0(x, multiheadattn_0_opt) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_0) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| newgelu2_0_opt = self.newgelu2_0(linear3_0_opt) | |||||
| linear3_1_opt = self.linear3_1(newgelu2_0_opt) | |||||
| opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt) | |||||
| return opt_add_1 | |||||
| class Reader_Albert(nn.Cell): | |||||
| """Albert model for reader""" | |||||
| def __init__(self, batch_size): | |||||
| """init function""" | |||||
| super(Reader_Albert, self).__init__() | |||||
| self.expanddims_0 = P.ExpandDims() | |||||
| self.expanddims_0_axis = 1 | |||||
| self.expanddims_3 = P.ExpandDims() | |||||
| self.expanddims_3_axis = 2 | |||||
| self.cast_5 = P.Cast() | |||||
| self.cast_5_to = mstype.float32 | |||||
| self.sub_7 = P.Sub() | |||||
| self.sub_7_bias = 1.0 | |||||
| self.mul_9 = P.Mul() | |||||
| self.mul_9_w = -10000.0 | |||||
| self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)), | |||||
| name=None) | |||||
| self.gather_1_axis = 0 | |||||
| self.gather_1 = P.Gather() | |||||
| self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None) | |||||
| self.gather_2_axis = 0 | |||||
| self.gather_2 = P.Gather() | |||||
| self.add_4 = P.Add() | |||||
| self.add_6 = P.Add() | |||||
| self.add_6_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)), name=None) | |||||
| self.layernorm1_0 = LayerNorm(mul_7_w_shape=(128,), add_8_bias_shape=(128,)) | |||||
| self.linear3_0 = Linear(matmul_0_weight_shape=(128, 4096), add_1_bias_shape=(4096,)) | |||||
| self.module34_0 = TransformerLayer(batch_size, | |||||
| layernorm1_0_mul_7_w_shape=(4096,), | |||||
| layernorm1_0_add_8_bias_shape=(4096,), | |||||
| linear3_0_matmul_0_weight_shape=(4096, 16384), | |||||
| linear3_0_add_1_bias_shape=(16384,), | |||||
| linear3_1_matmul_0_weight_shape=(16384, 4096), | |||||
| linear3_1_add_1_bias_shape=(4096,)) | |||||
| self.layernorm1_1 = LayerNorm(mul_7_w_shape=(4096,), add_8_bias_shape=(4096,)) | |||||
| def construct(self, x, x0, x1): | |||||
| """construct function""" | |||||
| opt_expanddims_0 = self.expanddims_0(x, self.expanddims_0_axis) | |||||
| opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis) | |||||
| opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to) | |||||
| opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5) | |||||
| opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w) | |||||
| opt_gather_1_axis = self.gather_1_axis | |||||
| opt_gather_1 = self.gather_1(self.gather_1_input_weight, x0, opt_gather_1_axis) | |||||
| opt_gather_2_axis = self.gather_2_axis | |||||
| opt_gather_2 = self.gather_2(self.gather_2_input_weight, x1, opt_gather_2_axis) | |||||
| opt_add_4 = self.add_4(opt_gather_1, opt_gather_2) | |||||
| opt_add_6 = self.add_6(opt_add_4, self.add_6_bias) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_6) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| module34_0_opt = self.module34_0(linear3_0_opt, opt_mul_9) | |||||
| out = self.layernorm1_1(module34_0_opt) | |||||
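# ALBERT shares one set of transformer weights across layers: the same block
# (and LayerNorm) is applied 11 more times below, for 12 layers in total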
| for _ in range(11): | |||||
| out = self.module34_0(out, opt_mul_9) | |||||
| out = self.layernorm1_1(out) | |||||
| return out | |||||
| @@ -0,0 +1,213 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """downstream Model for reader""" | |||||
| import numpy as np | |||||
| from mindspore import nn, ops | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import dtype as mstype | |||||
| dst_type = mstype.float16 | |||||
| dst_type2 = mstype.float32 | |||||
| class Module15(nn.Cell): | |||||
| """module of reader downstream""" | |||||
| def __init__(self, matmul_0_weight_shape, add_1_bias_shape): | |||||
| """init function""" | |||||
| super(Module15, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)), | |||||
| name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None) | |||||
| self.relu_2 = nn.ReLU() | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias) | |||||
| opt_relu_2 = self.relu_2(opt_add_1) | |||||
| return opt_relu_2 | |||||
| class NormModule(nn.Cell): | |||||
| """Normalization module of reader downstream""" | |||||
| def __init__(self, mul_8_w_shape, add_9_bias_shape): | |||||
| """init function""" | |||||
| super(NormModule, self).__init__() | |||||
| self.reducemean_0 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_1 = P.Sub() | |||||
| self.sub_2 = P.Sub() | |||||
| self.pow_3 = P.Pow() | |||||
| self.pow_3_input_weight = 2.0 | |||||
| self.reducemean_4 = P.ReduceMean(keep_dims=True) | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = 9.999999960041972e-13 | |||||
| self.sqrt_6 = P.Sqrt() | |||||
| self.div_7 = P.Div() | |||||
| self.mul_8 = P.Mul() | |||||
| self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, mul_8_w_shape).astype(np.float32)), name=None) | |||||
| self.add_9 = P.Add() | |||||
| self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, add_9_bias_shape).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_reducemean_0 = self.reducemean_0(x, -1) | |||||
| opt_sub_1 = self.sub_1(x, opt_reducemean_0) | |||||
| opt_sub_2 = self.sub_2(x, opt_reducemean_0) | |||||
| opt_pow_3 = self.pow_3(opt_sub_1, self.pow_3_input_weight) | |||||
| opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1) | |||||
| opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias) | |||||
| opt_sqrt_6 = self.sqrt_6(opt_add_5) | |||||
| opt_div_7 = self.div_7(opt_sub_2, opt_sqrt_6) | |||||
| opt_mul_8 = self.mul_8(self.mul_8_w, opt_div_7) | |||||
| opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias) | |||||
| return opt_add_9 | |||||
| class Module16(nn.Cell): | |||||
| """module of reader downstream""" | |||||
| def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape, | |||||
| normmodule_0_add_9_bias_shape): | |||||
| """init function""" | |||||
| super(Module16, self).__init__() | |||||
| self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape, | |||||
| add_1_bias_shape=module15_0_add_1_bias_shape) | |||||
| self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape, | |||||
| add_9_bias_shape=normmodule_0_add_9_bias_shape) | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (8192, 1)).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| module15_0_opt = self.module15_0(x) | |||||
| normmodule_0_opt = self.normmodule_0(module15_0_opt) | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| return ops.Cast()(opt_matmul_0, dst_type2) | |||||
| class Module17(nn.Cell): | |||||
| """module of reader downstream""" | |||||
| def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape, | |||||
| normmodule_0_add_9_bias_shape): | |||||
| """init function""" | |||||
| super(Module17, self).__init__() | |||||
| self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape, | |||||
| add_1_bias_shape=module15_0_add_1_bias_shape) | |||||
| self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape, | |||||
| add_9_bias_shape=normmodule_0_add_9_bias_shape) | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 1)).astype(np.float32)), name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| module15_0_opt = self.module15_0(x) | |||||
| normmodule_0_opt = self.normmodule_0(module15_0_opt) | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias) | |||||
| return opt_add_1 | |||||
| class Module5(nn.Cell): | |||||
| """module of reader downstream""" | |||||
| def __init__(self): | |||||
| """init function""" | |||||
| super(Module5, self).__init__() | |||||
| self.sub_0 = P.Sub() | |||||
| self.sub_0_bias = 1.0 | |||||
| self.mul_1 = P.Mul() | |||||
| self.mul_1_w = 1.0000000150474662e+30 | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_sub_0 = self.sub_0(self.sub_0_bias, x) | |||||
| opt_mul_1 = self.mul_1(opt_sub_0, self.mul_1_w) | |||||
| return opt_mul_1 | |||||
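| # Module5 maps a 0/1 mask m to (1 - m) * 1e30; Module10 below subtracts this from its squeezed | |||||
| # input so that positions where the mask is 0 receive a very large negative value. | |||||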
| class Module10(nn.Cell): | |||||
| """module of reader downstream""" | |||||
| def __init__(self): | |||||
| """init function""" | |||||
| super(Module10, self).__init__() | |||||
| self.squeeze_0 = P.Squeeze(2) | |||||
| self.module5_0 = Module5() | |||||
| self.sub_1 = P.Sub() | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| opt_squeeze_0 = self.squeeze_0(x) | |||||
| module5_0_opt = self.module5_0(x0) | |||||
| opt_sub_1 = self.sub_1(opt_squeeze_0, module5_0_opt) | |||||
| return opt_sub_1 | |||||
| class Reader_Downstream(nn.Cell): | |||||
| """Downstream model for reader""" | |||||
| def __init__(self): | |||||
| """init function""" | |||||
| super(Reader_Downstream, self).__init__() | |||||
| self.module16_0 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192), | |||||
| module15_0_add_1_bias_shape=(8192,), | |||||
| normmodule_0_mul_8_w_shape=(8192,), | |||||
| normmodule_0_add_9_bias_shape=(8192,)) | |||||
| self.add_74 = P.Add() | |||||
| self.add_74_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| self.module16_1 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192), | |||||
| module15_0_add_1_bias_shape=(8192,), | |||||
| normmodule_0_mul_8_w_shape=(8192,), | |||||
| normmodule_0_add_9_bias_shape=(8192,)) | |||||
| self.add_75 = P.Add() | |||||
| self.add_75_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| self.module17_0 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096), | |||||
| module15_0_add_1_bias_shape=(4096,), | |||||
| normmodule_0_mul_8_w_shape=(4096,), | |||||
| normmodule_0_add_9_bias_shape=(4096,)) | |||||
| self.module10_0 = Module10() | |||||
| self.module17_1 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096), | |||||
| module15_0_add_1_bias_shape=(4096,), | |||||
| normmodule_0_mul_8_w_shape=(4096,), | |||||
| normmodule_0_add_9_bias_shape=(4096,)) | |||||
| self.module10_1 = Module10() | |||||
| self.gather_6_input_weight = Tensor(np.array(0)) | |||||
| self.gather_6_axis = 1 | |||||
| self.gather_6 = P.Gather() | |||||
| self.dense_13 = nn.Dense(in_channels=4096, out_channels=4096, has_bias=True) | |||||
| self.relu_18 = nn.ReLU() | |||||
| self.normmodule_0 = NormModule(mul_8_w_shape=(4096,), add_9_bias_shape=(4096,)) | |||||
| self.dense_73 = nn.Dense(in_channels=4096, out_channels=3, has_bias=True) | |||||
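| # The 3-way head above presumably produces the question-type logits (yes / no / span) that | |||||
| # convert_to_tokens in rerank_and_reader_utils.py interprets as types 0, 1 and 2. | |||||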
| def construct(self, x, x0, x1, x2): | |||||
| """construct function""" | |||||
| module16_0_opt = self.module16_0(x) | |||||
| opt_add_74 = self.add_74(module16_0_opt, self.add_74_bias) | |||||
| module16_1_opt = self.module16_1(x0) | |||||
| opt_add_75 = self.add_75(module16_1_opt, self.add_75_bias) | |||||
| module17_0_opt = self.module17_0(x1) | |||||
| opt_module10_0 = self.module10_0(module17_0_opt, x2) | |||||
| module17_1_opt = self.module17_1(x1) | |||||
| opt_module10_1 = self.module10_1(module17_1_opt, x2) | |||||
| opt_gather_6_axis = self.gather_6_axis | |||||
| opt_gather_6 = self.gather_6(x1, self.gather_6_input_weight, opt_gather_6_axis) | |||||
| opt_dense_13 = self.dense_13(opt_gather_6) | |||||
| opt_relu_18 = self.relu_18(opt_dense_13) | |||||
| normmodule_0_opt = self.normmodule_0(opt_relu_18) | |||||
| opt_dense_73 = self.dense_73(normmodule_0_opt) | |||||
| return opt_dense_73, opt_module10_0, opt_module10_1, opt_add_74, opt_add_75 | |||||
| @@ -0,0 +1,142 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """execute reader""" | |||||
| from collections import defaultdict | |||||
| import random | |||||
| from time import time | |||||
| import json | |||||
| from tqdm import tqdm | |||||
| import numpy as np | |||||
| from transformers import AlbertTokenizer | |||||
| from mindspore import Tensor, ops | |||||
| from mindspore import dtype as mstype | |||||
| from src.rerank_and_reader_data_generator import DataGenerator | |||||
| from src.rerank_and_reader_utils import convert_to_tokens, make_wiki_id, DocDB | |||||
| from src.reader import Reader | |||||
| def read(args): | |||||
| """reader function""" | |||||
| db_file = args.wiki_db_file | |||||
| reader_feature_file = args.reader_feature_file | |||||
| reader_example_file = args.reader_example_file | |||||
| encoder_ck_file = args.reader_encoder_ck_file | |||||
| downstream_ck_file = args.reader_downstream_ck_file | |||||
| albert_model_path = args.albert_model_path | |||||
| reader_result_file = args.reader_result_file | |||||
| seed = args.seed | |||||
| sp_threshold = args.sp_threshold | |||||
| seq_len = args.seq_len | |||||
| batch_size = args.reader_batch_size | |||||
| para_limit = args.max_para_num | |||||
| sent_limit = args.max_sent_num | |||||
| random.seed(seed) | |||||
| np.random.seed(seed) | |||||
| t1 = time() | |||||
| doc_db = DocDB(db_file) | |||||
| generator = DataGenerator(feature_file_path=reader_feature_file, | |||||
| example_file_path=reader_example_file, | |||||
| batch_size=batch_size, seq_len=seq_len, | |||||
| para_limit=para_limit, sent_limit=sent_limit, | |||||
| task_type="reader") | |||||
| example_dict = generator.example_dict | |||||
| feature_dict = generator.feature_dict | |||||
| answer_dict = defaultdict(lambda: defaultdict(list)) | |||||
| new_answer_dict = {} | |||||
| total_sp_dict = defaultdict(list) | |||||
| new_total_sp_dict = defaultdict(list) | |||||
| tokenizer = AlbertTokenizer.from_pretrained(albert_model_path) | |||||
| new_tokens = ['[q]', '[/q]', '<t>', '</t>', '[s]'] | |||||
| tokenizer.add_tokens(new_tokens) | |||||
| reader = Reader(batch_size=batch_size, | |||||
| encoder_ck_file=encoder_ck_file, | |||||
| downstream_ck_file=downstream_ck_file) | |||||
| print("start reading ...") | |||||
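| # For each batch: run the reader to get question-type logits, supporting-sentence logits and the | |||||
| # answer span boundaries (y1, y2), then map the predicted span back to raw text via convert_to_tokens. | |||||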
| for _, batch in tqdm(enumerate(generator)): | |||||
| input_ids = Tensor(batch["context_idxs"], mstype.int32) | |||||
| attn_mask = Tensor(batch["context_mask"], mstype.int32) | |||||
| token_type_ids = Tensor(batch["segment_idxs"], mstype.int32) | |||||
| context_mask = Tensor(batch["context_mask"], mstype.float32) | |||||
| square_mask = Tensor(batch["square_mask"], mstype.float32) | |||||
| packing_mask = Tensor(batch["query_mapping"], mstype.float32) | |||||
| para_start_mapping = Tensor(batch["para_start_mapping"], mstype.float32) | |||||
| sent_end_mapping = Tensor(batch["sent_end_mapping"], mstype.float32) | |||||
| unique_ids = batch["unique_ids"] | |||||
| sent_names = batch["sent_names"] | |||||
| cache_mask = Tensor(np.tril(np.triu(np.ones((seq_len, seq_len)), 0), 30), mstype.float32) | |||||
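| # cache_mask is a band matrix (1 where 0 <= j - i <= 30); it presumably restricts predicted | |||||
| # answer spans to at most ~30 tokens beyond their start position. | |||||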
| _, _, q_type, _, sent_logit, y1, y2 = reader(input_ids, attn_mask, token_type_ids, | |||||
| context_mask, square_mask, packing_mask, cache_mask, | |||||
| para_start_mapping, sent_end_mapping) | |||||
| type_prob = ops.Softmax()(q_type).asnumpy() | |||||
| answer_dict_ = convert_to_tokens(example_dict, | |||||
| feature_dict, | |||||
| batch['ids'], | |||||
| y1.asnumpy().tolist(), | |||||
| y2.asnumpy().tolist(), | |||||
| type_prob, | |||||
| tokenizer, | |||||
| sent_logit.asnumpy(), | |||||
| sent_names, | |||||
| unique_ids) | |||||
| for q_id in answer_dict_: | |||||
| answer_dict[q_id] = answer_dict_[q_id] | |||||
| for q_id in answer_dict: | |||||
| res = answer_dict[q_id] | |||||
| answer_text_ = res[0] | |||||
| sent_ = res[1] | |||||
| sent_names_ = res[2] | |||||
| new_answer_dict[q_id] = answer_text_ | |||||
| predict_support_np = ops.Sigmoid()(Tensor(sent_, mstype.float32)).asnumpy() | |||||
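| # A sentence is kept as a supporting fact when its sigmoid score exceeds sp_threshold | |||||
| # (default 0.65, see get_parse in rerank_and_reader_utils.py). | |||||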
| for j in range(predict_support_np.shape[0]): | |||||
| if j >= len(sent_names_): | |||||
| break | |||||
| if predict_support_np[j] > sp_threshold: | |||||
| total_sp_dict[q_id].append(sent_names_[j]) | |||||
| for _id in total_sp_dict: | |||||
| _sent_names = total_sp_dict[_id] | |||||
| for para in _sent_names: | |||||
| title = make_wiki_id(para[0], 0) | |||||
| para_original_title = doc_db.get_doc_info(title)[-1] | |||||
| para[0] = para_original_title | |||||
| new_total_sp_dict[_id].append(para) | |||||
| prediction = {'answer': new_answer_dict, | |||||
| 'sp': new_total_sp_dict} | |||||
| with open(reader_result_file, 'w') as f: | |||||
| json.dump(prediction, f, indent=4) | |||||
| t2 = time() | |||||
| print(f"reader cost time: {t2-t1} s") | |||||
| @@ -0,0 +1,276 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """albert-xxlarge Model for reranker""" | |||||
| import numpy as np | |||||
| from mindspore import nn, ops | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import dtype as mstype | |||||
| dst_type = mstype.float16 | |||||
| dst_type2 = mstype.float32 | |||||
| class LayerNorm(nn.Cell): | |||||
| """LayerNorm layer""" | |||||
| def __init__(self, passthrough_w_0, passthrough_w_1): | |||||
| """init function""" | |||||
| super(LayerNorm, self).__init__() | |||||
| self.reducemean_0 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_1 = P.Sub() | |||||
| self.pow_2 = P.Pow() | |||||
| self.pow_2_input_weight = 2.0 | |||||
| self.reducemean_3 = P.ReduceMean(keep_dims=True) | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = 9.999999960041972e-13 | |||||
| self.sqrt_5 = P.Sqrt() | |||||
| self.div_6 = P.Div() | |||||
| self.mul_7 = P.Mul() | |||||
| self.mul_7_w = passthrough_w_0 | |||||
| self.add_8 = P.Add() | |||||
| self.add_8_bias = passthrough_w_1 | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_reducemean_0 = self.reducemean_0(x, -1) | |||||
| opt_sub_1 = self.sub_1(x, opt_reducemean_0) | |||||
| opt_pow_2 = self.pow_2(opt_sub_1, self.pow_2_input_weight) | |||||
| opt_reducemean_3 = self.reducemean_3(opt_pow_2, -1) | |||||
| opt_add_4 = self.add_4(opt_reducemean_3, self.add_4_bias) | |||||
| opt_sqrt_5 = self.sqrt_5(opt_add_4) | |||||
| opt_div_6 = self.div_6(opt_sub_1, opt_sqrt_5) | |||||
| opt_mul_7 = self.mul_7(opt_div_6, self.mul_7_w) | |||||
| opt_add_8 = self.add_8(opt_mul_7, self.add_8_bias) | |||||
| return opt_add_8 | |||||
| class Linear(nn.Cell): | |||||
| """Linear layer""" | |||||
| def __init__(self, matmul_0_w_shape, passthrough_w_0): | |||||
| """init function""" | |||||
| super(Linear, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_w_shape).astype(np.float32)), name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = passthrough_w_0 | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias) | |||||
| return opt_add_1 | |||||
| class MultiHeadAttn(nn.Cell): | |||||
| """Multi-head attention layer""" | |||||
| def __init__(self, batch_size, passthrough_w_0, passthrough_w_1, passthrough_w_2): | |||||
| """init function""" | |||||
| super(MultiHeadAttn, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.matmul_1 = nn.MatMul() | |||||
| self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.matmul_2 = nn.MatMul() | |||||
| self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.add_3 = P.Add() | |||||
| self.add_3_bias = passthrough_w_0 | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = passthrough_w_1 | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = passthrough_w_2 | |||||
| self.reshape_6 = P.Reshape() | |||||
| self.reshape_6_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.reshape_7 = P.Reshape() | |||||
| self.reshape_7_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.reshape_8 = P.Reshape() | |||||
| self.reshape_8_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.transpose_9 = P.Transpose() | |||||
| self.transpose_10 = P.Transpose() | |||||
| self.transpose_11 = P.Transpose() | |||||
| self.matmul_12 = nn.MatMul() | |||||
| self.div_13 = P.Div() | |||||
| self.div_13_w = 8.0 | |||||
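| # 8.0 = sqrt(64), the per-head dimension: standard scaled dot-product attention scaling. | |||||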
| self.add_14 = P.Add() | |||||
| self.softmax_15 = nn.Softmax(axis=3) | |||||
| self.matmul_16 = nn.MatMul() | |||||
| self.transpose_17 = P.Transpose() | |||||
| self.matmul_18 = P.MatMul() | |||||
| self.matmul_18_weight = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None) | |||||
| self.add_19 = P.Add() | |||||
| self.add_19_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_matmul_1 = self.matmul_1(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_1_w, dst_type)) | |||||
| opt_matmul_2 = self.matmul_2(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_2_w, dst_type)) | |||||
| opt_add_3 = self.add_3(ops.Cast()(opt_matmul_0, dst_type2), self.add_3_bias) | |||||
| opt_add_4 = self.add_4(ops.Cast()(opt_matmul_1, dst_type2), self.add_4_bias) | |||||
| opt_add_5 = self.add_5(ops.Cast()(opt_matmul_2, dst_type2), self.add_5_bias) | |||||
| opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape) | |||||
| opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape) | |||||
| opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape) | |||||
| opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3)) | |||||
| opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1)) | |||||
| opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3)) | |||||
| opt_matmul_12 = self.matmul_12(ops.Cast()(opt_transpose_9, dst_type), ops.Cast()(opt_transpose_10, dst_type)) | |||||
| opt_div_13 = self.div_13(ops.Cast()(opt_matmul_12, dst_type2), ops.Cast()(self.div_13_w, dst_type2)) | |||||
| opt_add_14 = self.add_14(opt_div_13, x0) | |||||
| opt_softmax_15 = self.softmax_15(opt_add_14) | |||||
| opt_matmul_16 = self.matmul_16(ops.Cast()(opt_softmax_15, dst_type), ops.Cast()(opt_transpose_11, dst_type)) | |||||
| opt_transpose_17 = self.transpose_17(ops.Cast()(opt_matmul_16, dst_type2), (0, 2, 1, 3)) | |||||
| opt_matmul_18 = self.matmul_18(ops.Cast()(opt_transpose_17, dst_type).view(self.batch_size * 512, -1), | |||||
| ops.Cast()(self.matmul_18_weight, dst_type).view(-1, 4096))\ | |||||
| .view(self.batch_size, 512, 4096) | |||||
| opt_add_19 = self.add_19(ops.Cast()(opt_matmul_18, dst_type2), self.add_19_bias) | |||||
| return opt_add_19 | |||||
| class NewGeLU(nn.Cell): | |||||
| """Gelu layer""" | |||||
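| # Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))); | |||||
| # the constant 0.7978845... below is sqrt(2/pi). | |||||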
| def __init__(self): | |||||
| """init function""" | |||||
| super(NewGeLU, self).__init__() | |||||
| self.mul_0 = P.Mul() | |||||
| self.mul_0_w = 0.5 | |||||
| self.pow_1 = P.Pow() | |||||
| self.pow_1_input_weight = 3.0 | |||||
| self.mul_2 = P.Mul() | |||||
| self.mul_2_w = 0.044714998453855515 | |||||
| self.add_3 = P.Add() | |||||
| self.mul_4 = P.Mul() | |||||
| self.mul_4_w = 0.7978845834732056 | |||||
| self.tanh_5 = nn.Tanh() | |||||
| self.add_6 = P.Add() | |||||
| self.add_6_bias = 1.0 | |||||
| self.mul_7 = P.Mul() | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_mul_0 = self.mul_0(x, self.mul_0_w) | |||||
| opt_pow_1 = self.pow_1(x, self.pow_1_input_weight) | |||||
| opt_mul_2 = self.mul_2(opt_pow_1, self.mul_2_w) | |||||
| opt_add_3 = self.add_3(x, opt_mul_2) | |||||
| opt_mul_4 = self.mul_4(opt_add_3, self.mul_4_w) | |||||
| opt_tanh_5 = self.tanh_5(opt_mul_4) | |||||
| opt_add_6 = self.add_6(opt_tanh_5, self.add_6_bias) | |||||
| opt_mul_7 = self.mul_7(opt_mul_0, opt_add_6) | |||||
| return opt_mul_7 | |||||
| class TransformerLayerWithLayerNorm(nn.Cell): | |||||
| """Transformer layer with LayerNorm""" | |||||
| def __init__(self, batch_size, linear3_0_matmul_0_w_shape, linear3_1_matmul_0_w_shape, passthrough_w_0, | |||||
| passthrough_w_1, passthrough_w_2, passthrough_w_3, passthrough_w_4, passthrough_w_5, passthrough_w_6): | |||||
| """init function""" | |||||
| super(TransformerLayerWithLayerNorm, self).__init__() | |||||
| self.multiheadattn_0 = MultiHeadAttn(batch_size=batch_size, | |||||
| passthrough_w_0=passthrough_w_0, | |||||
| passthrough_w_1=passthrough_w_1, | |||||
| passthrough_w_2=passthrough_w_2) | |||||
| self.add_0 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm(passthrough_w_0=passthrough_w_3, passthrough_w_1=passthrough_w_4) | |||||
| self.linear3_0 = Linear(matmul_0_w_shape=linear3_0_matmul_0_w_shape, passthrough_w_0=passthrough_w_5) | |||||
| self.newgelu2_0 = NewGeLU() | |||||
| self.linear3_1 = Linear(matmul_0_w_shape=linear3_1_matmul_0_w_shape, passthrough_w_0=passthrough_w_6) | |||||
| self.add_1 = P.Add() | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| multiheadattn_0_opt = self.multiheadattn_0(x, x0) | |||||
| opt_add_0 = self.add_0(x, multiheadattn_0_opt) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_0) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| newgelu2_0_opt = self.newgelu2_0(linear3_0_opt) | |||||
| linear3_1_opt = self.linear3_1(newgelu2_0_opt) | |||||
| opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt) | |||||
| return opt_add_1 | |||||
| class Rerank_Albert(nn.Cell): | |||||
| """Albert model for rerank""" | |||||
| def __init__(self, batch_size): | |||||
| """init function""" | |||||
| super(Rerank_Albert, self).__init__() | |||||
| self.passthrough_w_0 = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_1 = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_2 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_3 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_4 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_5 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_6 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_7 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_8 = Parameter(Tensor(np.random.uniform(0, 1, (16384,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_9 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_10 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_11 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.expanddims_0 = P.ExpandDims() | |||||
| self.expanddims_0_axis = 1 | |||||
| self.expanddims_3 = P.ExpandDims() | |||||
| self.expanddims_3_axis = 2 | |||||
| self.cast_5 = P.Cast() | |||||
| self.cast_5_to = mstype.float32 | |||||
| self.sub_7 = P.Sub() | |||||
| self.sub_7_bias = 1.0 | |||||
| self.mul_9 = P.Mul() | |||||
| self.mul_9_w = -10000.0 | |||||
| self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)), | |||||
| name=None) | |||||
| self.gather_1_axis = 0 | |||||
| self.gather_1 = P.Gather() | |||||
| self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None) | |||||
| self.gather_2_axis = 0 | |||||
| self.gather_2 = P.Gather() | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)), name=None) | |||||
| self.add_6 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm(passthrough_w_0=self.passthrough_w_0, passthrough_w_1=self.passthrough_w_1) | |||||
| self.linear3_0 = Linear(matmul_0_w_shape=(128, 4096), passthrough_w_0=self.passthrough_w_2) | |||||
| self.module34_0 = TransformerLayerWithLayerNorm(batch_size=batch_size, | |||||
| linear3_0_matmul_0_w_shape=(4096, 16384), | |||||
| linear3_1_matmul_0_w_shape=(16384, 4096), | |||||
| passthrough_w_0=self.passthrough_w_3, | |||||
| passthrough_w_1=self.passthrough_w_4, | |||||
| passthrough_w_2=self.passthrough_w_5, | |||||
| passthrough_w_3=self.passthrough_w_6, | |||||
| passthrough_w_4=self.passthrough_w_7, | |||||
| passthrough_w_5=self.passthrough_w_8, | |||||
| passthrough_w_6=self.passthrough_w_9) | |||||
| self.layernorm1_1 = LayerNorm(passthrough_w_0=self.passthrough_w_10, passthrough_w_1=self.passthrough_w_11) | |||||
| def construct(self, input_ids, attention_mask, token_type_ids): | |||||
| """construct function""" | |||||
| opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis) | |||||
| opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis) | |||||
| opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to) | |||||
| opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5) | |||||
| opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w) | |||||
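| # opt_mul_9 is the additive attention mask derived from attention_mask: 0 for real tokens and | |||||
| # -10000 for padding; it is added to the raw attention scores inside MultiHeadAttn. | |||||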
| opt_gather_1_axis = self.gather_1_axis | |||||
| opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis) | |||||
| opt_gather_2_axis = self.gather_2_axis | |||||
| opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis) | |||||
| opt_add_4 = self.add_4(opt_gather_1, self.add_4_bias) | |||||
| opt_add_6 = self.add_6(opt_add_4, opt_gather_2) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_6) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| opt = self.module34_0(linear3_0_opt, opt_mul_9) | |||||
| opt = self.layernorm1_1(opt) | |||||
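| # ALBERT-style cross-layer parameter sharing: the same transformer block (module34_0) and layer | |||||
| # norm are applied 12 times in total (the call above plus the 11 loop iterations below). | |||||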
| for _ in range(11): | |||||
| opt = self.module34_0(opt, opt_mul_9) | |||||
| opt = self.layernorm1_1(opt) | |||||
| return opt | |||||
| @@ -0,0 +1,183 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """define a data generator""" | |||||
| import gzip | |||||
| import pickle | |||||
| import random | |||||
| import numpy as np | |||||
| random.seed(42) | |||||
| np.random.seed(42) | |||||
| class DataGenerator: | |||||
| """data generator for reranker and reader""" | |||||
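| # Typical usage (a minimal sketch; the file names are placeholders): | |||||
| #   gen = DataGenerator(feature_file_path="reranker_feature_file.pkl.gz", | |||||
| #                       example_file_path="reranker_example_file.pkl.gz", | |||||
| #                       batch_size=32, seq_len=512, task_type="reranker") | |||||
| #   for batch in gen: | |||||
| #       context_idxs = batch["context_idxs"]  # shape (batch_size, seq_len) | |||||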
| def __init__(self, feature_file_path, example_file_path, batch_size, seq_len, | |||||
| para_limit=None, sent_limit=None, task_type=None): | |||||
| """init function""" | |||||
| self.example_ptr = 0 | |||||
| self.bsz = batch_size | |||||
| self.seq_length = seq_len | |||||
| self.para_limit = para_limit | |||||
| self.sent_limit = sent_limit | |||||
| self.task_type = task_type | |||||
| self.feature_file_path = feature_file_path | |||||
| self.example_file_path = example_file_path | |||||
| self.features = self.load_features() | |||||
| self.examples = self.load_examples() | |||||
| self.feature_dict = self.get_feature_dict() | |||||
| self.example_dict = self.get_example_dict() | |||||
| self.features = self.padding_feature(self.features, self.bsz) | |||||
| def load_features(self): | |||||
| """load features from feature file""" | |||||
| with gzip.open(self.feature_file_path, 'rb') as fin: | |||||
| features = pickle.load(fin) | |||||
| print("loaded features successfully !!!") | |||||
| return features | |||||
| def padding_feature(self, features, bsz): | |||||
| """pad the feature list so its length is a multiple of the batch size""" | |||||
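| # e.g. 10 features with batch size 4 are padded to 12 by re-appending the first 2 features; | |||||
| # when the count is already a multiple of bsz, a full extra batch is appended. | |||||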
| padding_num = ((len(features) // bsz + 1) * bsz - len(features)) | |||||
| print(f"features padding num is {padding_num}") | |||||
| new_features = features + features[:padding_num] | |||||
| return new_features | |||||
| def load_examples(self): | |||||
| """load examples from file""" | |||||
| if self.example_file_path: | |||||
| with gzip.open(self.example_file_path, 'rb') as fin: | |||||
| examples = pickle.load(fin) | |||||
| print("loaded examples successfully !!!") | |||||
| return examples | |||||
| return {} | |||||
| def get_feature_dict(self): | |||||
| """build a feature dict""" | |||||
| return {f.unique_id: f for f in self.features} | |||||
| def get_example_dict(self): | |||||
| """build an example dict""" | |||||
| if self.example_file_path: | |||||
| return {e.unique_id: e for e in self.examples} | |||||
| return {} | |||||
| def common_process_single_case(self, i, case, context_idxs, context_mask, segment_idxs, ids, path, unique_ids): | |||||
| """common process for a single case""" | |||||
| context_idxs[i] = np.array(case.doc_input_ids) | |||||
| context_mask[i] = np.array(case.doc_input_mask) | |||||
| segment_idxs[i] = np.array(case.doc_segment_ids) | |||||
| ids.append(case.qas_id) | |||||
| path.append(case.path) | |||||
| unique_ids.append(case.unique_id) | |||||
| return context_idxs, context_mask, segment_idxs, ids, path, unique_ids | |||||
| def reader_process_single_case(self, i, case, sent_names, square_mask, query_mapping, ques_start_mapping, | |||||
| para_start_mapping, sent_end_mapping): | |||||
| """process for a single case about reader""" | |||||
| sent_names.append(case.sent_names) | |||||
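| # Token ids >= 30000 are presumably the marker tokens added to the ALBERT vocabulary | |||||
| # ([q], [/q], <t>, </t>, [s]); square_mask marks the square block of ordinary tokens | |||||
| # between two consecutive markers. | |||||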
| prev_position = None | |||||
| for cur_position, token_id in enumerate(case.doc_input_ids): | |||||
| if token_id >= 30000: | |||||
| if prev_position: | |||||
| square_mask[i, prev_position + 1: cur_position, prev_position + 1: cur_position] = 1.0 | |||||
| prev_position = cur_position | |||||
| if case.sent_spans: | |||||
| for j in range(case.sent_spans[0][0] - 1): | |||||
| query_mapping[i, j] = 1 | |||||
| ques_start_mapping[i, 0, 1] = 1 | |||||
| for j, para_span in enumerate(case.para_spans[:self.para_limit]): | |||||
| start, end, _ = para_span | |||||
| if start <= end: | |||||
| para_start_mapping[i, j, start] = 1 | |||||
| for j, sent_span in enumerate(case.sent_spans[:self.sent_limit]): | |||||
| start, end = sent_span | |||||
| if start <= end: | |||||
| end = min(end, self.seq_length - 1) | |||||
| sent_end_mapping[i, j, end] = 1 | |||||
| return sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping | |||||
| def __iter__(self): | |||||
| """iteration function""" | |||||
| while True: | |||||
| if self.example_ptr >= len(self.features): | |||||
| break | |||||
| start_id = self.example_ptr | |||||
| cur_bsz = min(self.bsz, len(self.features) - start_id) | |||||
| cur_batch = self.features[start_id: start_id + cur_bsz] | |||||
| # BERT input | |||||
| context_idxs = np.zeros((cur_bsz, self.seq_length)) | |||||
| context_mask = np.zeros((cur_bsz, self.seq_length)) | |||||
| segment_idxs = np.zeros((cur_bsz, self.seq_length)) | |||||
| # others | |||||
| ids = [] | |||||
| path = [] | |||||
| unique_ids = [] | |||||
| if self.task_type == "reader": | |||||
| # Mappings | |||||
| ques_start_mapping = np.zeros((cur_bsz, 1, self.seq_length)) | |||||
| query_mapping = np.zeros((cur_bsz, self.seq_length)) | |||||
| para_start_mapping = np.zeros((cur_bsz, self.para_limit, self.seq_length)) | |||||
| sent_end_mapping = np.zeros((cur_bsz, self.sent_limit, self.seq_length)) | |||||
| square_mask = np.zeros((cur_bsz, self.seq_length, self.seq_length)) | |||||
| sent_names = [] | |||||
| for i, case in enumerate(cur_batch): | |||||
| context_idxs, context_mask, segment_idxs, ids, path, unique_ids = \ | |||||
| self.common_process_single_case(i, case, context_idxs, context_mask, segment_idxs, ids, path, | |||||
| unique_ids) | |||||
| if self.task_type == "reader": | |||||
| sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping = \ | |||||
| self.reader_process_single_case(i, case, sent_names, square_mask, query_mapping, | |||||
| ques_start_mapping, para_start_mapping, sent_end_mapping) | |||||
| self.example_ptr += cur_bsz | |||||
| if self.task_type == "reranker": | |||||
| yield { | |||||
| "context_idxs": context_idxs, | |||||
| "context_mask": context_mask, | |||||
| "segment_idxs": segment_idxs, | |||||
| "ids": ids, | |||||
| "unique_ids": unique_ids, | |||||
| "path": path | |||||
| } | |||||
| elif self.task_type == "reader": | |||||
| yield { | |||||
| "context_idxs": context_idxs, | |||||
| "context_mask": context_mask, | |||||
| "segment_idxs": segment_idxs, | |||||
| "query_mapping": query_mapping, | |||||
| "para_start_mapping": para_start_mapping, | |||||
| "sent_end_mapping": sent_end_mapping, | |||||
| "square_mask": square_mask, | |||||
| "ques_start_mapping": ques_start_mapping, | |||||
| "ids": ids, | |||||
| "unique_ids": unique_ids, | |||||
| "sent_names": sent_names, | |||||
| "path": path | |||||
| } | |||||
| else: | |||||
| print(f"data generator received an unknown task type: {self.task_type} !!!") | |||||
| @@ -0,0 +1,656 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """utils""" | |||||
| import re | |||||
| import argparse | |||||
| from urllib.parse import unquote | |||||
| from collections import defaultdict | |||||
| import collections | |||||
| import logging | |||||
| import unicodedata | |||||
| import json | |||||
| import gzip | |||||
| import string | |||||
| import pickle | |||||
| import sqlite3 | |||||
| from tqdm import tqdm | |||||
| import numpy as np | |||||
| from transformers import BasicTokenizer | |||||
| logger = logging.getLogger(__name__) | |||||
| class Example: | |||||
| """A single example of data""" | |||||
| def __init__(self, | |||||
| qas_id, | |||||
| path, | |||||
| unique_id, | |||||
| question_tokens, | |||||
| doc_tokens, | |||||
| sent_names, | |||||
| sup_fact_id, | |||||
| sup_para_id, | |||||
| para_start_end_position, | |||||
| sent_start_end_position, | |||||
| question_text, | |||||
| title_start_end_position=None): | |||||
| """init function""" | |||||
| self.qas_id = qas_id | |||||
| self.path = path | |||||
| self.unique_id = unique_id | |||||
| self.question_tokens = question_tokens | |||||
| self.doc_tokens = doc_tokens | |||||
| self.question_text = question_text | |||||
| self.sent_names = sent_names | |||||
| self.sup_fact_id = sup_fact_id | |||||
| self.sup_para_id = sup_para_id | |||||
| self.para_start_end_position = para_start_end_position | |||||
| self.sent_start_end_position = sent_start_end_position | |||||
| self.title_start_end_position = title_start_end_position | |||||
| class InputFeatures: | |||||
| """A single set of features of data.""" | |||||
| def __init__(self, | |||||
| unique_id, | |||||
| qas_id, | |||||
| path, | |||||
| sent_names, | |||||
| doc_tokens, | |||||
| doc_input_ids, | |||||
| doc_input_mask, | |||||
| doc_segment_ids, | |||||
| query_tokens, | |||||
| query_input_ids, | |||||
| query_input_mask, | |||||
| query_segment_ids, | |||||
| para_spans, | |||||
| sent_spans, | |||||
| token_to_orig_map): | |||||
| """init function""" | |||||
| self.qas_id = qas_id | |||||
| self.doc_tokens = doc_tokens | |||||
| self.doc_input_ids = doc_input_ids | |||||
| self.doc_input_mask = doc_input_mask | |||||
| self.doc_segment_ids = doc_segment_ids | |||||
| self.path = path | |||||
| self.unique_id = unique_id | |||||
| self.sent_names = sent_names | |||||
| self.query_tokens = query_tokens | |||||
| self.query_input_ids = query_input_ids | |||||
| self.query_input_mask = query_input_mask | |||||
| self.query_segment_ids = query_segment_ids | |||||
| self.para_spans = para_spans | |||||
| self.sent_spans = sent_spans | |||||
| self.token_to_orig_map = token_to_orig_map | |||||
| class DocDB: | |||||
| """ | |||||
| Sqlite backed document storage. | |||||
| Implements get_doc_ids() and get_doc_info(doc_id). | |||||
| """ | |||||
| def __init__(self, db_path): | |||||
| """init function""" | |||||
| self.path = db_path | |||||
| self.connection = sqlite3.connect(self.path, check_same_thread=False) | |||||
| def __enter__(self): | |||||
| """enter function""" | |||||
| return self | |||||
| def __exit__(self, *args): | |||||
| """exit function""" | |||||
| self.close() | |||||
| def close(self): | |||||
| """Close the connection to the database.""" | |||||
| self.connection.close() | |||||
| def get_doc_ids(self): | |||||
| """Fetch all ids of docs stored in the db.""" | |||||
| cursor = self.connection.cursor() | |||||
| cursor.execute("SELECT id FROM documents") | |||||
| results = [r[0] for r in cursor.fetchall()] | |||||
| cursor.close() | |||||
| return results | |||||
| def get_doc_info(self, doc_id): | |||||
| """get document information""" | |||||
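| # Document ids are stored as "<normalized title>_<paragraph index>" (see make_wiki_id below); | |||||
| # ids without the suffix default to paragraph 0. | |||||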
| if not doc_id.endswith('_0'): | |||||
| doc_id += '_0' | |||||
| cursor = self.connection.cursor() | |||||
| cursor.execute( | |||||
| "SELECT * FROM documents WHERE id = ?", | |||||
| (normalize_title(doc_id),) | |||||
| ) | |||||
| result = cursor.fetchall() | |||||
| cursor.close() | |||||
| return result[0] if result else None  # fetchall() returns a list; guard against empty results | |||||
| def get_parse(): | |||||
| """get parse function""" | |||||
| parser = argparse.ArgumentParser() | |||||
| # Environment | |||||
| parser.add_argument('--seed', type=int, default=42, | |||||
| help="random seed for initialization") | |||||
| parser.add_argument('--seq_len', type=int, default=512, | |||||
| help="max sequence length") | |||||
| parser.add_argument("--get_reranker_data", | |||||
| action='store_true', | |||||
| help="Set this flag if you want to get reranker data from retrieved result") | |||||
| parser.add_argument("--run_reranker", | |||||
| action='store_true', | |||||
| help="Set this flag if you want to run reranker") | |||||
| parser.add_argument("--cal_reranker_metrics", | |||||
| action='store_true', | |||||
| help="Set this flag if you want to calculate rerank metrics") | |||||
| parser.add_argument("--select_reader_data", | |||||
| action='store_true', | |||||
| help="Set this flag if you want to select reader data") | |||||
| parser.add_argument("--run_reader", | |||||
| action='store_true', | |||||
| help="Set this flag if you want to run reader") | |||||
| parser.add_argument("--cal_reader_metrics", | |||||
| action='store_true', | |||||
| help="Set this flag if you want to calculate reader metrics") | |||||
| parser.add_argument('--dev_gold_file', | |||||
| type=str, | |||||
| default="../hotpot_dev_fullwiki_v1.json", | |||||
| help='file of dev ground truth') | |||||
| parser.add_argument('--wiki_db_file', | |||||
| type=str, | |||||
| default="../enwiki_offset.db", | |||||
| help='wiki database file') | |||||
| parser.add_argument('--albert_model_path', | |||||
| type=str, | |||||
| default="../albert-xxlarge/", | |||||
| help='model path of huggingface albert-xxlarge') | |||||
| # Retriever | |||||
| parser.add_argument('--retriever_result_file', | |||||
| type=str, | |||||
| default="../doc_path", | |||||
| help='file of retriever result') | |||||
| # Rerank | |||||
| parser.add_argument('--rerank_batch_size', type=int, default=32, | |||||
| help="rerank batch size for evaluation") | |||||
| parser.add_argument('--rerank_feature_file', | |||||
| type=str, | |||||
| default="../reranker_feature_file.pkl.gz", | |||||
| help='file of rerank feature') | |||||
| parser.add_argument('--rerank_example_file', | |||||
| type=str, | |||||
| default="../reranker_example_file.pkl.gz", | |||||
| help='file of rerank example') | |||||
| parser.add_argument('--rerank_result_file', | |||||
| type=str, | |||||
| default="../rerank_result.json", | |||||
| help='file of rerank result') | |||||
| parser.add_argument('--rerank_encoder_ck_file', | |||||
| type=str, | |||||
| default="../rerank_albert_12.ckpt", | |||||
| help='checkpoint of rerank albert-xxlarge') | |||||
| parser.add_argument('--rerank_downstream_ck_file', | |||||
| type=str, | |||||
| default="../rerank_downstream.ckpt", | |||||
| help='checkpoint of rerank downstream') | |||||
| # Reader | |||||
| parser.add_argument('--reader_batch_size', type=int, default=32, | |||||
| help="reader batch size for evaluation") | |||||
| parser.add_argument('--reader_feature_file', | |||||
| type=str, | |||||
| default="../reader_feature_file.pkl.gz", | |||||
| help='file of reader feature') | |||||
| parser.add_argument('--reader_example_file', | |||||
| type=str, | |||||
| default="../reader_example_file.pkl.gz", | |||||
| help='file of reader example') | |||||
| parser.add_argument('--reader_encoder_ck_file', | |||||
| type=str, | |||||
| default="../albert_12_layer.ckpt", | |||||
| help='checkpoint of reader albert-xxlarge') | |||||
| parser.add_argument('--reader_downstream_ck_file', | |||||
| type=str, | |||||
| default="../reader_downstream.ckpt", | |||||
| help='checkpoint of reader downstream') | |||||
| parser.add_argument('--reader_result_file', | |||||
| type=str, | |||||
| default="../reader_result_file.json", | |||||
| help='file of reader result') | |||||
| parser.add_argument('--sp_threshold', type=float, default=0.65, | |||||
| help="threshold for selecting supporting sentences") | |||||
| parser.add_argument("--max_para_num", default=2, type=int) | |||||
| parser.add_argument("--max_sent_num", default=40, type=int) | |||||
| return parser | |||||
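| # Typical usage (a minimal sketch): args = get_parse().parse_args(); the boolean flags such as | |||||
| # --run_reranker / --run_reader then select which stages of the pipeline to execute. | |||||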
| def select_reader_dev_data(args): | |||||
| """select reader dev data from result of retriever based on result of reranker""" | |||||
| rerank_result_file = args.rerank_result_file | |||||
| rerank_feature_file = args.rerank_feature_file | |||||
| rerank_example_file = args.rerank_example_file | |||||
| reader_feature_file = args.reader_feature_file | |||||
| reader_example_file = args.reader_example_file | |||||
| with gzip.open(rerank_example_file, "rb") as f: | |||||
| dev_examples = pickle.load(f) | |||||
| with gzip.open(rerank_feature_file, "rb") as f: | |||||
| dev_features = pickle.load(f) | |||||
| with open(rerank_result_file, "r") as f: | |||||
| rerank_result = json.load(f) | |||||
| new_dev_examples = [] | |||||
| new_dev_features = [] | |||||
| rerank_unique_ids = defaultdict(int) | |||||
| feature_unique_ids = defaultdict(int) | |||||
| for _, res in tqdm(rerank_result.items(), desc="get rerank unique ids"): | |||||
| rerank_unique_ids[res[0]] = True | |||||
| print(f"rerank result num is {len(rerank_unique_ids)}") | |||||
| for feature in tqdm(dev_features, desc="select rerank top1 feature"): | |||||
| if feature.unique_id in rerank_unique_ids: | |||||
| feature_unique_ids[feature.unique_id] = True | |||||
| new_dev_features.append(feature) | |||||
| print(f"new feature num is {len(new_dev_features)}") | |||||
| for example in tqdm(dev_examples, desc="select rerank top1 example"): | |||||
| if example.unique_id in rerank_unique_ids and example.unique_id in feature_unique_ids: | |||||
| new_dev_examples.append(example) | |||||
| print(f"new examples num is {len(new_dev_examples)}") | |||||
| print("start saving new examples ......") | |||||
| with gzip.open(reader_example_file, "wb") as f: | |||||
| pickle.dump(new_dev_examples, f) | |||||
| print("start saving new features ......") | |||||
| with gzip.open(reader_feature_file, "wb") as f: | |||||
| pickle.dump(new_dev_features, f) | |||||
| print("finish selecting reader data !!!") | |||||
| def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): | |||||
| """Project the tokenized prediction back to the original text.""" | |||||
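| # Same heuristic as the standard BERT SQuAD post-processing: strip spaces from both the original | |||||
| # and the re-tokenized text, then map the predicted span back via the character alignments. | |||||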
| def _strip_spaces(text): | |||||
| ns_chars = [] | |||||
| ns_to_s_map = collections.OrderedDict() | |||||
| for (i, c) in enumerate(text): | |||||
| if c == " ": | |||||
| continue | |||||
| ns_to_s_map[len(ns_chars)] = i | |||||
| ns_chars.append(c) | |||||
| ns_text = "".join(ns_chars) | |||||
| return (ns_text, ns_to_s_map) | |||||
| tokenizer = BasicTokenizer(do_lower_case=do_lower_case) | |||||
| tok_text = " ".join(tokenizer.tokenize(orig_text)) | |||||
| start_position = tok_text.find(pred_text) | |||||
| if start_position == -1: | |||||
| if verbose_logging: | |||||
| print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) | |||||
| return orig_text | |||||
| end_position = start_position + len(pred_text) - 1 | |||||
| (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) | |||||
| (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) | |||||
| if len(orig_ns_text) != len(tok_ns_text): | |||||
| if verbose_logging: | |||||
| logger.info("Length not equal after stripping spaces: '%s' vs '%s'", | |||||
| orig_ns_text, tok_ns_text) | |||||
| return orig_text | |||||
| # We then project the characters in `pred_text` back to `orig_text` using | |||||
| # the character-to-character alignment. | |||||
| tok_s_to_ns_map = {} | |||||
| for (i, tok_index) in tok_ns_to_s_map.items(): | |||||
| tok_s_to_ns_map[tok_index] = i | |||||
| orig_start_position = None | |||||
| if start_position in tok_s_to_ns_map: | |||||
| ns_start_position = tok_s_to_ns_map[start_position] | |||||
| if ns_start_position in orig_ns_to_s_map: | |||||
| orig_start_position = orig_ns_to_s_map[ns_start_position] | |||||
| if orig_start_position is None: | |||||
| if verbose_logging: | |||||
| print("Couldn't map start position") | |||||
| return orig_text | |||||
| orig_end_position = None | |||||
| if end_position in tok_s_to_ns_map: | |||||
| ns_end_position = tok_s_to_ns_map[end_position] | |||||
| if ns_end_position in orig_ns_to_s_map: | |||||
| orig_end_position = orig_ns_to_s_map[ns_end_position] | |||||
| if orig_end_position is None: | |||||
| if verbose_logging: | |||||
| print("Couldn't map end position") | |||||
| return orig_text | |||||
| output_text = orig_text[orig_start_position:(orig_end_position + 1)] | |||||
| return output_text | |||||
| def get_ans_from_pos(tokenizer, examples, features, y1, y2, unique_id): | |||||
| """get answer text from predicted position""" | |||||
| feature = features[unique_id] | |||||
| example = examples[unique_id] | |||||
| tok_to_orig_map = feature.token_to_orig_map | |||||
| orig_all_tokens = example.question_tokens + example.doc_tokens | |||||
| final_text = " " | |||||
| if y1 < len(tok_to_orig_map) and y2 < len(tok_to_orig_map): | |||||
| orig_tok_start = tok_to_orig_map[y1] | |||||
| orig_tok_end = tok_to_orig_map[y2] | |||||
| # -----------------orig all tokens----------------------------------- | |||||
| orig_tokens = orig_all_tokens[orig_tok_start: (orig_tok_end + 1)] | |||||
| tok_tokens = feature.doc_tokens[y1: y2 + 1] | |||||
| tok_text = tokenizer.convert_tokens_to_string(tok_tokens) | |||||
| # Clean whitespace | |||||
| tok_text = tok_text.strip() | |||||
| tok_text = " ".join(tok_text.split()) | |||||
| orig_text = " ".join(orig_tokens) | |||||
| final_text = get_final_text(tok_text, orig_text, True, False) | |||||
| # print("final_text: " + final_text) | |||||
| return final_text | |||||
| def convert_to_tokens(examples, features, ids, y1, y2, q_type_prob, tokenizer, sent, sent_names, | |||||
| unique_ids): | |||||
| """get raw answer text and supporting sentences""" | |||||
| answer_dict = defaultdict(list) | |||||
| q_type = np.argmax(q_type_prob, 1) | |||||
| for i, qid in enumerate(ids): | |||||
| unique_id = unique_ids[i] | |||||
| if q_type[i] == 0: | |||||
| answer_text = 'yes' | |||||
| elif q_type[i] == 1: | |||||
| answer_text = 'no' | |||||
| elif q_type[i] == 2: | |||||
| answer_text = get_ans_from_pos(tokenizer, examples, features, y1[i], y2[i], unique_id) | |||||
| else: | |||||
| raise ValueError("question type error") | |||||
| answer_dict[qid].append(answer_text) | |||||
| answer_dict[qid].append(sent[i]) | |||||
| answer_dict[qid].append(sent_names[i]) | |||||
| return answer_dict | |||||
| def normalize_title(text): | |||||
| """Resolve different types of unicode encodings / capitalization in HotpotQA data.""" | |||||
| text = unicodedata.normalize('NFD', text) | |||||
| return text[0].capitalize() + text[1:] | |||||
| def make_wiki_id(title, para_index): | |||||
| """make wiki id""" | |||||
| title_id = "{0}_{1}".format(normalize_title(title), para_index) | |||||
| return title_id | |||||
| def cal_reranker_metrics(dev_gold_file, rerank_result_file): | |||||
| """function for calculating reranker's metrics""" | |||||
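| # Returns the fraction of questions whose gold supporting-fact titles are all contained in the | |||||
| # re-ranked top prediction, overall and separately for comparison and bridge questions. | |||||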
| with open(dev_gold_file, 'rb') as f: | |||||
| gt = json.load(f) | |||||
| with open(rerank_result_file, 'rb') as f: | |||||
| rerank_result = json.load(f) | |||||
| cnt = 0 | |||||
| all_ = len(gt) | |||||
| cnt_c = 0 | |||||
| cnt_b = 0 | |||||
| all_c = 0 | |||||
| all_b = 0 | |||||
| for item in tqdm(gt, desc="get com and bridge "): | |||||
| q_type = item["type"] | |||||
| if q_type == "comparison": | |||||
| all_c += 1 | |||||
| elif q_type == "bridge": | |||||
| all_b += 1 | |||||
| else: | |||||
| print(f"{q_type} is an unexpected question type!!!") | |||||
| for item in tqdm(gt, desc="cal pem"): | |||||
| _id = item["_id"] | |||||
| if _id in rerank_result: | |||||
| pred = rerank_result[_id][1] | |||||
| sps = item["supporting_facts"] | |||||
| q_type = item["type"] | |||||
| gold = [] | |||||
| for t in sps: | |||||
| gold.append(normalize_title(t[0])) | |||||
| gold = set(gold) | |||||
| flag = True | |||||
| for t in gold: | |||||
| if t not in pred: | |||||
| flag = False | |||||
| break | |||||
| if flag: | |||||
| cnt += 1 | |||||
| if q_type == "comparison": | |||||
| cnt_c += 1 | |||||
| elif q_type == "bridge": | |||||
| cnt_b += 1 | |||||
| else: | |||||
| print(f"{q_type} is an unexpected question type!!!") | |||||
| return cnt/all_, cnt_c/all_c, cnt_b/all_b | |||||
| def whitespace_tokenize(text): | |||||
| """Runs basic whitespace cleaning and splitting on a piece of text.""" | |||||
| text = text.strip() | |||||
| if not text: | |||||
| return [] | |||||
| tokens = text.split() | |||||
| return tokens | |||||
| def find_hyper_linked_titles(text_w_links): | |||||
| """find hyperlinked titles""" | |||||
| titles = re.findall(r'href=[\'"]?([^\'" >]+)', text_w_links) | |||||
| titles = [unquote(title) for title in titles] | |||||
| titles = [title[0].capitalize() + title[1:] for title in titles] | |||||
| return titles | |||||
| def normalize_text(text): | |||||
| """Resolve different types of unicode encodings / capitalization in HotpotQA data.""" | |||||
| text = unicodedata.normalize('NFD', text) | |||||
| return text | |||||
| def convert_char_to_token_offset(orig_text, start_offset, end_offset, char_to_word_offset, doc_tokens): | |||||
| """convert a character-level span to a token-level span""" | |||||
| length = len(orig_text) | |||||
| assert start_offset + length == end_offset | |||||
| assert end_offset <= len(char_to_word_offset) | |||||
| start_position = char_to_word_offset[start_offset] | |||||
| end_position = char_to_word_offset[start_offset + length - 1] | |||||
| actual_text = " ".join( | |||||
| doc_tokens[start_position:(end_position + 1)]) | |||||
| assert actual_text.lower().find(orig_text.lower()) != -1 | |||||
| return start_position, end_position | |||||
| def _is_whitespace(c): | |||||
| """check whitespace""" | |||||
| if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: | |||||
| return True | |||||
| return False | |||||
| def convert_text_to_tokens(context_text, return_word_start=False): | |||||
| """convert text to tokens""" | |||||
| doc_tokens = [] | |||||
| char_to_word_offset = [] | |||||
| words_start_idx = [] | |||||
| prev_is_whitespace = True | |||||
| for idx, c in enumerate(context_text): | |||||
| if _is_whitespace(c): | |||||
| prev_is_whitespace = True | |||||
| else: | |||||
| if prev_is_whitespace: | |||||
| doc_tokens.append(c) | |||||
| words_start_idx.append(idx) | |||||
| else: | |||||
| doc_tokens[-1] += c | |||||
| prev_is_whitespace = False | |||||
| char_to_word_offset.append(len(doc_tokens) - 1) | |||||
| if not return_word_start: | |||||
| return doc_tokens, char_to_word_offset | |||||
| return doc_tokens, char_to_word_offset, words_start_idx | |||||
| def read_json(eval_file_name): | |||||
| """read a json file""" | |||||
| print("loading examples from {0}".format(eval_file_name)) | |||||
| with open(eval_file_name) as reader: | |||||
| lines = json.load(reader) | |||||
| return lines | |||||
| def write_json(data, out_file_name): | |||||
| """write json files""" | |||||
| print("writing {0} examples to {1}".format(len(data), out_file_name)) | |||||
| with open(out_file_name, 'w') as writer: | |||||
| json.dump(data, writer, indent=4) | |||||
| def get_edges(sentence): | |||||
| """get edges""" | |||||
| EDGE_XY = re.compile(r'<a href="(?!http|<a)(.*?)">(.*?)<\/a>') | |||||
| ret = EDGE_XY.findall(sentence) | |||||
| return [(unquote(x), y) for x, y in ret] | |||||
| def relocate_tok_span(orig_to_tok_index, orig_to_tok_back_index, word_tokens, subword_tokens, | |||||
| orig_start_position, orig_end_position, orig_text, tokenizer, tok_to_orig_index=None): | |||||
| """relocate tokens' span""" | |||||
| if orig_start_position is None: | |||||
| return 0, 0 | |||||
| tok_start_position = orig_to_tok_index[orig_start_position] | |||||
| if tok_start_position >= len(subword_tokens): | |||||
| return 0, 0 | |||||
| if orig_end_position < len(word_tokens) - 1: | |||||
| tok_end_position = orig_to_tok_back_index[orig_end_position] | |||||
| if tok_to_orig_index and tok_to_orig_index[tok_end_position + 1] == -1: | |||||
| assert tok_end_position <= orig_to_tok_index[orig_end_position + 1] - 2 | |||||
| else: | |||||
| assert tok_end_position == orig_to_tok_index[orig_end_position + 1] - 1 | |||||
| else: | |||||
| tok_end_position = orig_to_tok_back_index[orig_end_position] | |||||
| return _improve_answer_span( | |||||
| subword_tokens, tok_start_position, tok_end_position, tokenizer, orig_text) | |||||
| def generate_mapping(length, positions): | |||||
| """generate mapping""" | |||||
| start_mapping = [0] * length | |||||
| end_mapping = [0] * length | |||||
| for _, (start, end) in enumerate(positions): | |||||
| start_mapping[start] = 1 | |||||
| end_mapping[end] = 1 | |||||
| return start_mapping, end_mapping | |||||
| def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): | |||||
| """Returns tokenized answer spans that better match the annotated answer.""" | |||||
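| # For example, if the annotated answer is "1895" but the document tokens contain "( 1895 )", | |||||
| # scanning all sub-spans recovers the tighter span "1895". | |||||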
| tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=True)) | |||||
| for new_start in range(input_start, input_end + 1): | |||||
| for new_end in range(input_end, new_start - 1, -1): | |||||
| text_span = " ".join(doc_tokens[new_start: (new_end + 1)]) | |||||
| if text_span == tok_answer_text: | |||||
| return new_start, new_end | |||||
| return input_start, input_end | |||||
| def _largest_valid_index(spans, limit): | |||||
| """return largest valid index""" | |||||
| for idx, _ in enumerate(spans): | |||||
| if spans[idx][1] >= limit: | |||||
| return idx | |||||
| return len(spans) | |||||
| def remove_punc(text): | |||||
| """remove punctuation""" | |||||
| if text == " ": | |||||
| return '' | |||||
| exclude = set(string.punctuation) | |||||
| return ''.join(ch for ch in text if ch not in exclude) | |||||
| def check_text_include_ans(ans, text): | |||||
| """check whether text include answer""" | |||||
| if normalize_answer(ans) in normalize_answer(text): | |||||
| return True | |||||
| return False | |||||
| def remove_articles(text): | |||||
| """remove articles""" | |||||
| return re.sub(r'\b(a|an|the)\b', ' ', text) | |||||
| def white_space_fix(text): | |||||
| """fix whitespace""" | |||||
| return ' '.join(text.split()) | |||||
| def lower(text): | |||||
| """lower text""" | |||||
| return text.lower() | |||||
| def normalize_answer(s): | |||||
| """Lower text and remove punctuation, articles and extra whitespace.""" | |||||
| return white_space_fix(remove_articles(remove_punc(lower(s)))) | |||||
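For orientation, here is a minimal usage sketch of the answer-matching helpers above (the input strings are illustrative values only, not taken from the dataset):

```python
# Illustrative only: normalize_answer lower-cases, strips punctuation and
# articles, and collapses whitespace before the containment test.
print(normalize_answer("The Eiffel Tower!"))
# -> "eiffel tower"
print(check_text_include_ans("Eiffel Tower", "It is near the Eiffel Tower in Paris."))
# -> True
```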
@@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Downstream model for re-ranker"""

import numpy as np

from mindspore import nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class Rerank_Downstream(nn.Cell):
    """Downstream scoring head of the re-ranker"""

    def __init__(self):
        """init function"""
        super(Rerank_Downstream, self).__init__()
        self.dense_0 = nn.Dense(in_channels=4096, out_channels=8192, has_bias=True)
        self.relu_1 = nn.ReLU()
        # The operators below implement layer normalization over the last axis.
        self.reducemean_2 = P.ReduceMean(keep_dims=True)
        self.sub_3 = P.Sub()
        self.sub_4 = P.Sub()
        self.pow_5 = P.Pow()
        self.pow_5_input_weight = 2.0
        self.reducemean_6 = P.ReduceMean(keep_dims=True)
        self.add_7 = P.Add()
        self.add_7_bias = 9.999999960041972e-13
        self.sqrt_8 = P.Sqrt()
        self.div_9 = P.Div()
        self.mul_10 = P.Mul()
        # Scale/shift parameters of the normalization; the real values are loaded from the checkpoint.
        self.mul_10_w = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
        self.add_11 = P.Add()
        self.add_11_bias = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
        self.dense_12 = nn.Dense(in_channels=8192, out_channels=2, has_bias=True)

    def construct(self, x):
        """construct function"""
        opt_dense_0 = self.dense_0(x)                                # (batch, 4096) -> (batch, 8192)
        opt_relu_1 = self.relu_1(opt_dense_0)
        opt_reducemean_2 = self.reducemean_2(opt_relu_1, -1)         # mean over the last axis
        opt_sub_3 = self.sub_3(opt_relu_1, opt_reducemean_2)
        opt_sub_4 = self.sub_4(opt_relu_1, opt_reducemean_2)
        opt_pow_5 = self.pow_5(opt_sub_3, self.pow_5_input_weight)   # squared deviation
        opt_reducemean_6 = self.reducemean_6(opt_pow_5, -1)          # variance
        opt_add_7 = self.add_7(opt_reducemean_6, self.add_7_bias)    # + epsilon
        opt_sqrt_8 = self.sqrt_8(opt_add_7)
        opt_div_9 = self.div_9(opt_sub_4, opt_sqrt_8)                # normalize
        opt_mul_10 = self.mul_10(self.mul_10_w, opt_div_9)           # scale (gamma)
        opt_add_11 = self.add_11(opt_mul_10, self.add_11_bias)       # shift (beta)
        opt_dense_12 = self.dense_12(opt_add_11)                     # (batch, 8192) -> (batch, 2) logits
        return opt_dense_12
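The op-by-op graph above was exported from a trained network, so the operator names are opaque; functionally it is a Dense(4096→8192) projection, a ReLU, a layer normalization over the 8192 features, and a final Dense(8192→2) scoring layer. The sketch below is a readable equivalent offered only for illustration; its parameter names do not match the released checkpoint, so it cannot be loaded with `load_param_into_net` as-is:

```python
import mindspore.nn as nn


class RerankDownstreamSketch(nn.Cell):
    """Readable equivalent of Rerank_Downstream (illustration only, not checkpoint-compatible)."""

    def __init__(self):
        super().__init__()
        self.head = nn.SequentialCell([
            nn.Dense(4096, 8192),                  # project the encoder feature
            nn.ReLU(),
            nn.LayerNorm((8192,), epsilon=1e-12),  # same normalization the explicit ops compute
            nn.Dense(8192, 2),                     # two logits scored by the re-ranker
        ])

    def construct(self, x):
        return self.head(x)
```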
@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Re-ranker model"""

import mindspore.nn as nn
from mindspore import load_checkpoint, load_param_into_net

from src.rerank_albert_xxlarge import Rerank_Albert
from src.rerank_downstream import Rerank_Downstream


class Reranker(nn.Cell):
    """Re-ranker model"""

    def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
        """init function"""
        super(Reranker, self).__init__(auto_prefix=False)

        self.encoder = Rerank_Albert(batch_size)
        param_dict = load_checkpoint(encoder_ck_file)
        not_load_params_1 = load_param_into_net(self.encoder, param_dict)
        print(f"parameters not loaded into albert encoder: {not_load_params_1}")

        self.no_answer_mlp = Rerank_Downstream()
        param_dict = load_checkpoint(downstream_ck_file)
        not_load_params_2 = load_param_into_net(self.no_answer_mlp, param_dict)
        print(f"parameters not loaded into downstream head: {not_load_params_2}")

    def construct(self, input_ids, attn_mask, token_type_ids):
        """construct function"""
        state = self.encoder(input_ids, attn_mask, token_type_ids)
        # Score each candidate path from the [CLS] representation of the sequence.
        state = state[:, 0, :]
        no_answer = self.no_answer_mlp(state)
        return no_answer
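`construct` returns two logits per candidate reasoning path. The evaluation code below converts them into a "no answer" probability (softmax column 0) and, for each question, keeps the path that is least likely to have no answer. A tiny sketch of that selection step with made-up numbers:

```python
# Hypothetical no-answer probabilities for three candidate paths of one question.
no_answer_prob = {"path_a": 0.12, "path_b": 0.57, "path_c": 0.08}

# rerank() below sorts candidates by this probability and keeps the smallest one.
best_path = min(no_answer_prob, key=no_answer_prob.get)
print(best_path)  # -> "path_c"
```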
@@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""execute re-ranker"""

import json
import random
from collections import defaultdict
from time import time

from tqdm import tqdm
import numpy as np

from mindspore import Tensor, ops
from mindspore import dtype as mstype

from src.rerank_and_reader_data_generator import DataGenerator
from src.reranker import Reranker


def rerank(args):
    """re-rank retrieved paths and keep the top-1 path for each question"""
    rerank_feature_file = args.rerank_feature_file
    rerank_result_file = args.rerank_result_file
    encoder_ck_file = args.rerank_encoder_ck_file
    downstream_ck_file = args.rerank_downstream_ck_file
    seed = args.seed
    seq_len = args.seq_len
    batch_size = args.rerank_batch_size

    random.seed(seed)
    np.random.seed(seed)

    t1 = time()

    generator = DataGenerator(feature_file_path=rerank_feature_file,
                              example_file_path=None,
                              batch_size=batch_size, seq_len=seq_len,
                              task_type="reranker")
    gather_dict = defaultdict(lambda: defaultdict(list))

    reranker = Reranker(batch_size=batch_size,
                        encoder_ck_file=encoder_ck_file,
                        downstream_ck_file=downstream_ck_file)

    print("start re-ranking ...")

    for batch in tqdm(generator):
        input_ids = Tensor(batch["context_idxs"], mstype.int32)
        attn_mask = Tensor(batch["context_mask"], mstype.int32)
        token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)

        no_answer = reranker(input_ids, attn_mask, token_type_ids)
        no_answer_prob = ops.Softmax()(no_answer).asnumpy()
        no_answer_prob = no_answer_prob[:, 0]

        # Group candidate paths by question id, keyed by their no-answer probability.
        for i in range(len(batch['ids'])):
            qas_id = batch['ids'][i]
            gather_dict[qas_id][no_answer_prob[i]].append(batch['unique_ids'][i])
            gather_dict[qas_id][no_answer_prob[i]].append(batch['path'][i])

    rerank_result = {}
    for qas_id in tqdm(gather_dict, desc="get top1 path from re-rank result"):
        all_paths = gather_dict[qas_id]
        # Sort ascending by no-answer probability and keep the most confident path.
        all_paths = sorted(all_paths.items(), key=lambda item: item[0])
        assert qas_id not in rerank_result
        rerank_result[qas_id] = all_paths[0][1]

    with open(rerank_result_file, 'w') as f:
        json.dump(rerank_result, f)

    t2 = time()
    print(f"re-rank cost time: {t2 - t1} s")