diff --git a/model_zoo/research/nlp/tprr/README.md b/model_zoo/research/nlp/tprr/README.md
index 25ffac30a8..d2a02a295b 100644
--- a/model_zoo/research/nlp/tprr/README.md
+++ b/model_zoo/research/nlp/tprr/README.md
@@ -38,6 +38,9 @@ Wikipedia data: the 2017 English Wikipedia dump version with bidirectional hyper
dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs.
dev tf-idf data: the candidates for each question in dev data which is originated from top-500 retrieved from 5M paragraphs of Wikipedia
through TF-IDF.
+The dataset of the re-ranker consists of two parts:
+Wikipedia data: the 2017 English Wikipedia dump version.
+dev data: HotPotQA full wiki setting dev data with 7398 question-answer pairs.
# [Features](#contents)
@@ -64,6 +67,7 @@ After installing MindSpore via the official website and Dataset is correctly gen
```python
# run evaluation example with HotPotQA dev dataset
sh run_eval_ascend.sh
+ sh run_eval_ascend_reranker_reader.sh
```
# [Script Description](#contents)
@@ -75,25 +79,39 @@ After installing MindSpore via the official website and Dataset is correctly gen
└─tprr
├─README.md
├─scripts
- | ├─run_eval_ascend.sh # Launch evaluation in ascend
+ | ├─run_eval_ascend.sh # Launch retriever evaluation in ascend
+  | └─run_eval_ascend_reranker_reader.sh # Launch re-ranker and reader evaluation in ascend
|
├─src
- | ├─config.py # Evaluation configurations
- | ├─onehop.py # Onehop model
- | ├─onehop_bert.py # Onehop bert model
- | ├─process_data.py # Data preprocessing
- | ├─twohop.py # Twohop model
- | ├─twohop_bert.py # Twohop bert model
- | └─utils.py # Utils for evaluation
+  | ├─build_reranker_data.py # Build data for re-ranker from retriever result
+ | ├─config.py # Evaluation configurations for retriever
+ | ├─hotpot_evaluate_v1.py # Hotpotqa evaluation script
+ | ├─onehop.py # Onehop model of retriever
+ | ├─onehop_bert.py # Onehop bert model of retriever
+ | ├─process_data.py # Data preprocessing for retriever
+ | ├─reader.py # Reader model
+ | ├─reader_albert_xxlarge.py # Albert-xxlarge module of reader model
+ | ├─reader_downstream.py # Downstream module of reader model
+ | ├─reader_eval.py # Reader evaluation script
+ | ├─rerank_albert_xxlarge.py # Albert-xxlarge module of re-ranker model
+ | ├─rerank_and_reader_data_generator.py # Data generator for re-ranker and reader
+ | ├─rerank_and_reader_utils.py # Utils for re-ranker and reader
+ | ├─rerank_downstream.py # Downstream module of re-ranker model
+ | ├─reranker.py # Re-ranker model
+ | ├─reranker_eval.py # Re-ranker evaluation script
+ | ├─twohop.py # Twohop model of retriever
+ | ├─twohop_bert.py # Twohop bert model of retriever
+ | └─utils.py # Utils for retriever
|
- └─retriever_eval.py # Evaluation net for retriever
+ ├─retriever_eval.py # Evaluation net for retriever
+ └─reranker_and_reader_eval.py # Evaluation net for re-ranker and reader
```
## [Script Parameters](#contents)
-Parameters for evaluation can be set in config.py.
+Parameters for retriever evaluation can be set in config.py.
-- config for TPRR retriever dataset
+- config for TPRR retriever
```python
"q_len": 64, # Max query length
@@ -108,17 +126,30 @@ Parameters for evaluation can be set in config.py.
config.py for more configuration.
+Parameters for re-ranker and reader evaluation can be passed directly at execution time.
+
+- parameters for TPRR re-ranker and reader
+
+ ```python
+ "seq_len": 512, # sequence length
+ "rerank_batch_size": 32, # batch size for re-ranker evaluation
+ "reader_batch_size": 448, # batch size for reader evaluation
+    "sp_threshold": 8 # threshold for picking supporting sentences
+ ```
+
+  See get_parse() in src/rerank_and_reader_utils.py for more configuration.
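+
+  For example, the whole pipeline can be launched with these values overridden on the command line (an illustrative invocation; the override flag name is assumed to follow get_parse() in src/rerank_and_reader_utils.py):
+
+  ```python
+  python reranker_and_reader_eval.py --get_reranker_data --run_reranker --cal_reranker_metrics \
+      --select_reader_data --run_reader --cal_reader_metrics --sp_threshold 8
+  ```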
+
## [Evaluation Process](#contents)
### Evaluation
-- Evaluation on Ascend
+- Retriever evaluation on Ascend
```python
sh run_eval_ascend.sh
```
- Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the
+ Evaluation result will be stored in the scripts path, whose folder name begins with "eval_tr". You can find the result like the
followings in log.
```python
@@ -138,6 +169,35 @@ Parameters for evaluation can be set in config.py.
evaluation time (h): 20.155506462653477
```
+- Re-ranker and reader evaluation on Ascend
+
+  Use the output of the retriever as the input of the re-ranker.
+
+ ```python
+ sh run_eval_ascend_reranker_reader.sh
+ ```
+
+  Evaluation results will be stored in the scripts path, in a folder whose name begins with "eval". You can find results like the
+  following in the log.
+
+ ```python
+ total top1 pem: 0.8803511141120864
+
+ ...
+ em: 0.67440918298447
+ f1: 0.8025625656569652
+ prec: 0.8292800393689271
+ recall: 0.8136908451841731
+ sp_em: 0.6009453072248481
+ sp_f1: 0.844555664157302
+ sp_prec: 0.8640844345841021
+ sp_recall: 0.8446123918845106
+ joint_em: 0.4537474679270763
+ joint_f1: 0.715119580346802
+ joint_prec: 0.7540052057184267
+ joint_recall: 0.7250240424067661
+ ```
+
# [Model Description](#contents)
## [Performance](#contents)
@@ -154,6 +214,8 @@ Parameters for evaluation can be set in config.py.
| Batch_size | 1 |
| Output | inference path |
| PEM | 0.9188 |
+| total top1 pem | 0.88 |
+| joint_f1 | 0.7151 |
# [Description of random situation](#contents)
diff --git a/model_zoo/research/nlp/tprr/reranker_and_reader_eval.py b/model_zoo/research/nlp/tprr/reranker_and_reader_eval.py
new file mode 100644
index 0000000000..9724a52fa4
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/reranker_and_reader_eval.py
@@ -0,0 +1,55 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation entry for re-ranker and reader"""
+
+from mindspore import context
+from src.rerank_and_reader_utils import get_parse, cal_reranker_metrics, select_reader_dev_data
+from src.reranker_eval import rerank
+from src.reader_eval import read
+from src.hotpot_evaluate_v1 import hotpotqa_eval
+from src.build_reranker_data import get_rerank_data
+
+
+def rerank_and_reader_eval():
+ """main function"""
+ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+ parser = get_parse()
+ args = parser.parse_args()
+
+ if args.get_reranker_data:
+ get_rerank_data(args)
+
+ if args.run_reranker:
+ rerank(args)
+
+ if args.cal_reranker_metrics:
+ total_top1_pem, _, _ = \
+ cal_reranker_metrics(dev_gold_file=args.dev_gold_file, rerank_result_file=args.rerank_result_file)
+ print(f"total top1 pem: {total_top1_pem}")
+
+ if args.select_reader_data:
+ select_reader_dev_data(args)
+
+ if args.run_reader:
+ read(args)
+
+ if args.cal_reader_metrics:
+ metrics = hotpotqa_eval(args.reader_result_file, args.dev_gold_file)
+ for k in metrics:
+ print(f"{k}: {metrics[k]}")
+
+
+if __name__ == "__main__":
+    rerank_and_reader_eval()
diff --git a/model_zoo/research/nlp/tprr/scripts/run_eval_ascend.sh b/model_zoo/research/nlp/tprr/scripts/run_eval_ascend.sh
index 4d1f2e5c1f..7d5d46b062 100644
--- a/model_zoo/research/nlp/tprr/scripts/run_eval_ascend.sh
+++ b/model_zoo/research/nlp/tprr/scripts/run_eval_ascend.sh
@@ -21,16 +21,16 @@ export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
-if [ -d "eval" ];
+if [ -d "eval_tr" ];
then
- rm -rf ./eval
+ rm -rf ./eval_tr
fi
-mkdir ./eval
+mkdir ./eval_tr
-cp ../*.py ./eval
-cp *.sh ./eval
-cp -r ../src ./eval
-cd ./eval || exit
+cp ../*.py ./eval_tr
+cp *.sh ./eval_tr
+cp -r ../src ./eval_tr
+cd ./eval_tr || exit
env > env.log
echo "start evaluation"
diff --git a/model_zoo/research/nlp/tprr/scripts/run_eval_ascend_reranker_reader.sh b/model_zoo/research/nlp/tprr/scripts/run_eval_ascend_reranker_reader.sh
new file mode 100644
index 0000000000..02435d5389
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/scripts/run_eval_ascend_reranker_reader.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# eval script for re-ranker and reader
+
+ulimit -u unlimited
+export DEVICE_NUM=1
+export RANK_SIZE=$DEVICE_NUM
+export RANK_ID=0
+
+if [ -d "eval" ];
+then
+ rm -rf ./eval
+fi
+mkdir ./eval
+
+cp ../*.py ./eval
+cp *.sh ./eval
+cp -r ../src ./eval
+cd ./eval || exit
+env > env.log
+echo "start evaluation"
+
+python reranker_and_reader_eval.py --get_reranker_data --run_reranker --cal_reranker_metrics --select_reader_data --run_reader --cal_reader_metrics > log_reranker_and_reader.txt 2>&1 &
+
+cd ..
diff --git a/model_zoo/research/nlp/tprr/src/build_reranker_data.py b/model_zoo/research/nlp/tprr/src/build_reranker_data.py
new file mode 100644
index 0000000000..47bc2ed6be
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/build_reranker_data.py
@@ -0,0 +1,430 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""build reranker data from retriever result"""
+import pickle
+import gzip
+from tqdm import tqdm
+
+from src.rerank_and_reader_utils import read_json, make_wiki_id, convert_text_to_tokens, normalize_title, \
+ whitespace_tokenize, DocDB, _largest_valid_index, generate_mapping, InputFeatures, Example
+
+from transformers import AutoTokenizer
+
+
+def judge_para(data):
+    """check whether all paragraphs are valid (contain more than one token)"""
+ for _, para_tokens in data["context"].items():
+ if len(para_tokens) == 1:
+ return False
+ return True
+
+
+def judge_sp(data, sent_name2id, para2id):
+    """check whether the supporting facts are valid"""
+ for sp in data['sp']:
+ title = normalize_title(sp[0])
+ name = normalize_title(sp[0]) + '_{}'.format(sp[1])
+ if title in para2id and name not in sent_name2id:
+ return False
+ return True
+
+
+def judge(path, path_set, reverse=False, golds=None, mode='or'):
+    """check whether a candidate path is valid: no self-loop, not a duplicate, and optionally filtered against gold paragraphs"""
+ if path[0] == path[-1]:
+ return False
+ if path in path_set:
+ return False
+ if reverse and path[::-1] in path_set:
+ return False
+ if not golds:
+ return True
+ if mode == 'or':
+ return any(gold not in path for gold in golds)
+ if mode == 'and':
+ return all(gold not in path for gold in golds)
+ return False
+
+
+def get_context_and_sents(path, doc_db):
+    """get context and sentences"""
+ context = {}
+ sents = {}
+ for title in path:
+ para_info = doc_db.get_doc_info(title)
+ if title.endswith('_0'):
+ title = title[:-2]
+ context[title] = pickle.loads(para_info[1])
+ sents[title] = pickle.loads(para_info[2])
+ return context, sents
+
+
+def gen_dev_data(dev_file, db_path, topk_file):
+ """generate dev data"""
+ # ----------------------------------------db info-----------------------------------------------
+ topk_data = read_json(topk_file) # path
+ doc_db = DocDB(db_path) # db get offset
+ print('load db successfully!')
+
+ # ---------------------------------------------supervision ------------------------------------------
+ dev_data = read_json(dev_file)
+ qid2sp = {}
+ qid2ans = {}
+ qid2type = {}
+ qid2path = {}
+ for _, data in enumerate(dev_data):
+ sp_facts = data['supporting_facts'] if 'supporting_facts' in data else None
+ qid2sp[data['_id']] = sp_facts
+ qid2ans[data['_id']] = data['answer'] if 'answer' in data else None
+ qid2type[data['_id']] = data['type'] if 'type' in data else None
+ qid2path[data['_id']] = list(set(list(zip(*sp_facts))[0])) if sp_facts else None
+
+ new_dev_data = []
+
+ for _, data in enumerate(tqdm(topk_data)):
+ qid = data['q_id']
+ question = data['question']
+ topk_titles = data['topk_titles']
+ gold_path = list(map(normalize_title, qid2path[qid])) if qid2path[qid] else None
+
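+        # build candidate two-hop paragraph paths from the retriever's top-k title chains,
+        # skipping self-loop and duplicate paths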
+ all_titles = []
+ for titles in topk_titles:
+ titles = list(map(normalize_title, titles))
+ if len(titles) == 1:
+ continue
+ path = titles[:2]
+ if judge(path, all_titles):
+ all_titles.append(titles[:2])
+ if len(titles) == 3:
+ path = titles[1:]
+ if judge(path, all_titles):
+ all_titles.append(titles[1:])
+
+ # --------------------------------------------------process query-----------------------------------
+
+ question = " ".join(whitespace_tokenize(question))
+ question = question.strip()
+ q_tokens, _ = convert_text_to_tokens(question)
+
+ gold_path = list(map(lambda x: make_wiki_id(x, 0), gold_path)) if gold_path else None
+ for path in all_titles:
+ context, sents = get_context_and_sents(path, doc_db)
+ ans_label = int(gold_path[0] in path and gold_path[1] in path) if gold_path else None
+
+ new_dev_data.append({
+ 'qid': qid,
+ 'type': qid2type[qid],
+ 'question': question,
+ 'q_tokens': q_tokens,
+ 'context': context,
+ 'sents': sents,
+ 'answer': qid2ans[qid],
+ 'sp': qid2sp[qid],
+ 'ans_para': None,
+ 'is_impossible': not ans_label == 1
+ })
+
+ return new_dev_data
+
+
+def read_hotpot_examples(path_data):
+ """reader examples"""
+    """read hotpot examples from path data"""
+ max_sent_cnt = 0
+ failed = 0
+
+ for _, data in enumerate(path_data):
+ if not judge_para(data):
+ failed += 1
+ continue
+ question = data['question']
+ question = " ".join(whitespace_tokenize(question))
+ question = question.strip()
+ path = list(map(normalize_title, data["context"].keys()))
+ qid = data['qid']
+ q_tokens = data['q_tokens']
+
+ # -------------------------------------add para------------------------------------------------------------
+ doc_tokens = []
+
+ para_start_end_position = []
+ title_start_end_position = []
+ sent_start_end_position = []
+
+ sent_names = []
+ sent_name2id = {}
+ para2id = {}
+
+ for para, para_tokens in data["context"].items():
+ sents = data["sents"][para]
+
+ para = normalize_title(para)
+ title_tokens = convert_text_to_tokens(para)[0]
+ para_node_id = len(para_start_end_position)
+ para2id[para] = para_node_id
+
+ doc_offset = len(doc_tokens)
+ doc_tokens += title_tokens
+ doc_tokens += para_tokens
+
+ title_start_end_position.append((doc_offset, doc_offset + len(title_tokens) - 1))
+
+ doc_offset += len(title_tokens)
+ para_start_end_position.append((doc_offset, doc_offset + len(para_tokens) - 1, para))
+
+ for idx, sent in enumerate(sents):
+ if sent[0] == -1 and sent[1] == -1:
+ continue
+ sent_names.append([para, idx]) # local name
+ sent_node_id = len(sent_start_end_position)
+ sent_name2id[normalize_title(para) + '_{}'.format(str(idx))] = sent_node_id
+ sent_start_end_position.append((doc_offset + sent[0],
+ doc_offset + sent[1]))
+
+ # add sp and ans
+ sp_facts = []
+ sup_fact_id = []
+ for sp in sp_facts:
+ name = normalize_title(sp[0]) + '_{}'.format(sp[1])
+ if name in sent_name2id:
+ sup_fact_id.append(sent_name2id[name])
+
+ sup_para_id = set() # use set
+ if sp_facts:
+ for para in list(zip(*sp_facts))[0]:
+ para = normalize_title(para)
+ if para in para2id:
+ sup_para_id.add(para2id[para])
+ sup_para_id = list(sup_para_id)
+
+ example = Example(
+ qas_id=qid,
+ path=path,
+ unique_id=qid + '_' + '_'.join(path),
+ question_tokens=q_tokens,
+ doc_tokens=doc_tokens, # multi-para tokens w/o query
+ sent_names=sent_names,
+ sup_fact_id=sup_fact_id, # global sent id
+ sup_para_id=sup_para_id, # global para id
+ para_start_end_position=para_start_end_position,
+ sent_start_end_position=sent_start_end_position,
+ title_start_end_position=title_start_end_position,
+ question_text=question)
+
+ examples.append(example)
+ max_sent_cnt = max(max_sent_cnt, len(sent_start_end_position))
+
+ print(f"Maximum sentence cnt: {max_sent_cnt}")
+ print(f'failed examples: {failed}')
+ print(f'convert {len(examples)} examples successfully!')
+
+ return examples
+
+
+def add_sub_token(sub_tokens, idx, tok_to_orig_index, all_query_tokens):
+ """add sub tokens"""
+ for sub_token in sub_tokens:
+ tok_to_orig_index.append(idx)
+ all_query_tokens.append(sub_token)
+ return tok_to_orig_index, all_query_tokens
+
+
+def get_sent_spans(example, orig_to_tok_index, orig_to_tok_back_index):
+ """get sentences' spans"""
+ sentence_spans = []
+ for sent_span in example.sent_start_end_position:
+ sent_start_position = orig_to_tok_index[sent_span[0]]
+ sent_end_position = orig_to_tok_back_index[sent_span[1]]
+ sentence_spans.append((sent_start_position, sent_end_position + 1))
+ return sentence_spans
+
+
+def get_para_spans(example, orig_to_tok_index, orig_to_tok_back_index, all_doc_tokens, marker):
+ """get paragraphs' spans"""
+ para_spans = []
+ for title_span, para_span in zip(example.title_start_end_position, example.para_start_end_position):
+ para_start_position = orig_to_tok_index[title_span[0]]
+ para_end_position = orig_to_tok_back_index[para_span[1]]
+ if para_end_position + 1 < len(all_doc_tokens) and all_doc_tokens[para_end_position + 1] == \
+ marker['sent'][0]:
+ para_spans.append((para_start_position - 1, para_end_position + 1, para_span[2]))
+ else:
+ para_spans.append((para_start_position - 1, para_end_position, para_span[2]))
+ return para_spans
+
+
+def build_feature(example, all_doc_tokens, doc_input_ids, doc_input_mask, doc_segment_ids, all_query_tokens,
+ query_input_ids, query_input_mask, query_segment_ids, para_spans, sentence_spans, tok_to_orig_index):
+ """build a input feature"""
+    """build an input feature"""
+ qas_id=example.qas_id,
+ path=example.path,
+ unique_id=example.qas_id + '_' + '_'.join(example.path),
+ sent_names=example.sent_names,
+ doc_tokens=all_doc_tokens,
+ doc_input_ids=doc_input_ids,
+ doc_input_mask=doc_input_mask,
+ doc_segment_ids=doc_segment_ids,
+ query_tokens=all_query_tokens,
+ query_input_ids=query_input_ids,
+ query_input_mask=query_input_mask,
+ query_segment_ids=query_segment_ids,
+ para_spans=para_spans,
+ sent_spans=sentence_spans,
+ token_to_orig_map=tok_to_orig_index)
+ return feature
+
+
+def convert_example_to_features(tokenizer, args, examples):
+ """convert examples to features"""
+ features = []
+ failed = 0
+ marker = {'q': ['[q]', '[/q]'], 'para': ['', ''], 'sent': ['[s]']}
+
+ for (_, example) in enumerate(tqdm(examples)):
+
+ all_query_tokens = [tokenizer.cls_token, marker['q'][0]]
+ tok_to_orig_index = [-1, -1] # orig: query + doc tokens
+ ques_orig_to_tok_index = [] # start position
+ ques_orig_to_tok_back_index = [] # end position
+ q_spans = []
+
+ # -------------------------------------------for query---------------------------------------------
+ for (idx, token) in enumerate(example.question_tokens):
+ sub_tokens = tokenizer.tokenize(token)
+
+ ques_orig_to_tok_index.append(len(all_query_tokens))
+ tok_to_orig_index, all_query_tokens = add_sub_token(sub_tokens, idx, tok_to_orig_index, all_query_tokens)
+ ques_orig_to_tok_back_index.append(len(all_query_tokens) - 1)
+
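+        # truncate the query to 63 sub-tokens and append the closing [/q] marker,
+        # so the tokenized query occupies at most 64 positions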
+ all_query_tokens = all_query_tokens[:63]
+ tok_to_orig_index = tok_to_orig_index[:63]
+ all_query_tokens.append(marker['q'][-1])
+ tok_to_orig_index.append(-1)
+ q_spans.append((1, len(all_query_tokens) - 1))
+
+ # ---------------------------------------add doc tokens------------------------------------------------
+ all_doc_tokens = []
+ orig_to_tok_index = [] # orig: token in doc
+ orig_to_tok_back_index = []
+ title_start_mapping, title_end_mapping = generate_mapping(len(example.doc_tokens),
+ example.title_start_end_position)
+ _, sent_end_mapping = generate_mapping(len(example.doc_tokens),
+ example.sent_start_end_position)
+ all_doc_tokens += all_query_tokens
+
+ for (idx, token) in enumerate(example.doc_tokens):
+ sub_tokens = tokenizer.tokenize(token)
+
+ if title_start_mapping[idx] == 1:
+ all_doc_tokens.append(marker['para'][0])
+ tok_to_orig_index.append(-1)
+
+            # orig: index in example.doc_tokens; tok: start index in the global tokenized sequence
+ orig_to_tok_index.append(len(all_doc_tokens))
+ tok_to_orig_index, all_doc_tokens = add_sub_token(sub_tokens, idx + len(example.question_tokens),
+ tok_to_orig_index, all_doc_tokens)
+ orig_to_tok_back_index.append(len(all_doc_tokens) - 1)
+
+ if title_end_mapping[idx] == 1:
+ all_doc_tokens.append(marker['para'][1])
+ tok_to_orig_index.append(-1)
+
+ if sent_end_mapping[idx] == 1:
+ all_doc_tokens.append(marker['sent'][0])
+ tok_to_orig_index.append(-1)
+
+ # -----------------------------------for sentence-------------------------------------------------
+ sentence_spans = get_sent_spans(example, orig_to_tok_index, orig_to_tok_back_index)
+
+ # -----------------------------------for para-------------------------------------------------------
+ para_spans = get_para_spans(example, orig_to_tok_index, orig_to_tok_back_index, all_doc_tokens, marker)
+
+ # -----------------------------------remove sent > max seq length-----------------------------------------
+ sent_max_index = _largest_valid_index(sentence_spans, args.seq_len)
+ max_sent_cnt = len(sentence_spans)
+
+ if sent_max_index != len(sentence_spans):
+ if sent_max_index == 0:
+                failed += 1
+ continue
+ sentence_spans = sentence_spans[:sent_max_index]
+            max_tok_length = sentence_spans[-1][1]  # end of the last kept sentence (its [s] marker position)
+
+ # max end index: max_tok_length
+ para_max_index = _largest_valid_index(para_spans, max_tok_length + 1)
+ if para_max_index == 0: # only one para
+                failed += 1
+ continue
+ if orig_to_tok_back_index[example.title_start_end_position[1][1]] + 1 >= max_tok_length:
+                failed += 1
+ continue
+ max_para_span = para_spans[para_max_index]
+ para_spans = para_spans[:para_max_index]
+ para_spans.append((max_para_span[0], max_tok_length, max_para_span[2]))
+
+ all_doc_tokens = all_doc_tokens[:max_tok_length + 1]
+
+ sentence_spans = sentence_spans[:min(max_sent_cnt, args.max_sent_num)]
+
+ # ----------------------------------------Padding Document-----------------------------------------------------
+ if len(all_doc_tokens) > args.seq_len:
+ st, _, title = para_spans[-1]
+ para_spans[-1] = (st, args.seq_len - 1, title)
+ all_doc_tokens = all_doc_tokens[:args.seq_len - 1] + [marker['sent'][0]]
+
+ doc_input_ids = tokenizer.convert_tokens_to_ids(all_doc_tokens)
+ query_input_ids = tokenizer.convert_tokens_to_ids(all_query_tokens)
+
+ doc_input_mask = [1] * len(doc_input_ids)
+ doc_segment_ids = [0] * len(query_input_ids) + [1] * (len(doc_input_ids) - len(query_input_ids))
+
+ doc_pad_length = args.seq_len - len(doc_input_ids)
+ doc_input_ids += [0] * doc_pad_length
+ doc_input_mask += [0] * doc_pad_length
+ doc_segment_ids += [0] * doc_pad_length
+
+ # Padding Question
+ query_input_mask = [1] * len(query_input_ids)
+ query_segment_ids = [0] * len(query_input_ids)
+
+ query_pad_length = 64 - len(query_input_ids)
+ query_input_ids += [0] * query_pad_length
+ query_input_mask += [0] * query_pad_length
+ query_segment_ids += [0] * query_pad_length
+
+ feature = build_feature(example, all_doc_tokens, doc_input_ids, doc_input_mask, doc_segment_ids,
+ all_query_tokens, query_input_ids, query_input_mask, query_segment_ids, para_spans,
+ sentence_spans, tok_to_orig_index)
+ features.append(feature)
+ return features
+
+
+def get_rerank_data(args):
+ """function for generating reranker's data"""
+ new_dev_data = gen_dev_data(dev_file=args.dev_gold_file,
+ db_path=args.wiki_db_file,
+ topk_file=args.retriever_result_file)
+ tokenizer = AutoTokenizer.from_pretrained(args.albert_model_path)
+ new_tokens = ['[q]', '[/q]', '', '', '[s]']
+ tokenizer.add_tokens(new_tokens)
+
+ examples = read_hotpot_examples(new_dev_data)
+ features = convert_example_to_features(tokenizer=tokenizer, args=args, examples=examples)
+
+ with gzip.open(args.rerank_example_file, "wb") as f:
+ pickle.dump(examples, f)
+ with gzip.open(args.rerank_feature_file, "wb") as f:
+ pickle.dump(features, f)
diff --git a/model_zoo/research/nlp/tprr/src/config.py b/model_zoo/research/nlp/tprr/src/config.py
index 9af4e5d46f..edbe493968 100644
--- a/model_zoo/research/nlp/tprr/src/config.py
+++ b/model_zoo/research/nlp/tprr/src/config.py
@@ -33,14 +33,14 @@ def ThinkRetrieverConfig():
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--save_name", type=str, default='doc_path', help='name of output')
- parser.add_argument("--save_path", type=str, default='./', help='path of output')
- parser.add_argument("--vocab_path", type=str, default='./scripts/vocab.txt', help="vocab path")
- parser.add_argument("--wiki_path", type=str, default='./scripts/db_docs_bidirection_new.pkl', help="wiki path")
- parser.add_argument("--dev_path", type=str, default='./scripts/hotpot_dev_fullwiki_v1_for_retriever.json',
+ parser.add_argument("--save_path", type=str, default='../', help='path of output')
+ parser.add_argument("--vocab_path", type=str, default='../vocab.txt', help="vocab path")
+ parser.add_argument("--wiki_path", type=str, default='../db_docs_bidirection_new.pkl', help="wiki path")
+ parser.add_argument("--dev_path", type=str, default='../hotpot_dev_fullwiki_v1_for_retriever.json',
help="dev path")
- parser.add_argument("--dev_data_path", type=str, default='./scripts/dev_tf_idf_data_raw.pkl', help="dev data path")
- parser.add_argument("--onehop_bert_path", type=str, default='./scripts/onehop.ckpt', help="onehop bert ckpt path")
- parser.add_argument("--onehop_mlp_path", type=str, default='./scripts/onehop_mlp.ckpt', help="onehop mlp ckpt path")
- parser.add_argument("--twohop_bert_path", type=str, default='./scripts/twohop.ckpt', help="twohop bert ckpt path")
- parser.add_argument("--twohop_mlp_path", type=str, default='./scripts/twohop_mlp.ckpt', help="twohop mlp ckpt path")
+ parser.add_argument("--dev_data_path", type=str, default='../dev_tf_idf_data_raw.pkl', help="dev data path")
+ parser.add_argument("--onehop_bert_path", type=str, default='../onehop.ckpt', help="onehop bert ckpt path")
+ parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path")
+ parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path")
+ parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path")
return parser.parse_args()
diff --git a/model_zoo/research/nlp/tprr/src/hotpot_evaluate_v1.py b/model_zoo/research/nlp/tprr/src/hotpot_evaluate_v1.py
new file mode 100644
index 0000000000..ecb25a0250
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/hotpot_evaluate_v1.py
@@ -0,0 +1,153 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""hotpotqa evaluate script"""
+
+import re
+import string
+from collections import Counter
+import ujson as json
+
+
+def normalize_answer(s):
+ """normalize answer"""
+ def remove_articles(text):
+ """remove articles"""
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+ def white_space_fix(text):
+ """fix whitespace"""
+ return ' '.join(text.split())
+
+ def remove_punc(text):
+ """remove punctuation from text"""
+ exclude = set(string.punctuation)
+ return ''.join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ """lower text"""
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+ """calculate f1 score"""
+ normalized_prediction = normalize_answer(prediction)
+ normalized_ground_truth = normalize_answer(ground_truth)
+
+ ZERO_METRIC = (0, 0, 0)
+
+ if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+ return ZERO_METRIC
+ if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+ return ZERO_METRIC
+
+ prediction_tokens = normalized_prediction.split()
+ ground_truth_tokens = normalized_ground_truth.split()
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return ZERO_METRIC
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1, precision, recall
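+
+# For example, f1_score("the cat sat", "a cat sat") returns (1.0, 1.0, 1.0):
+# normalize_answer drops articles, so both strings normalize to "cat sat".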
+
+
+def exact_match_score(prediction, ground_truth):
+ """calculate exact match score"""
+ return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def update_answer(metrics, prediction, gold):
+ """update answer"""
+ em = exact_match_score(prediction, gold)
+ f1, prec, recall = f1_score(prediction, gold)
+ metrics['em'] += float(em)
+ metrics['f1'] += f1
+ metrics['prec'] += prec
+ metrics['recall'] += recall
+ return em, prec, recall
+
+
+def update_sp(metrics, prediction, gold):
+ """update supporting sentences"""
+ cur_sp_pred = set(map(tuple, prediction))
+ gold_sp_pred = set(map(tuple, gold))
+ tp, fp, fn = 0, 0, 0
+ for e in cur_sp_pred:
+ if e in gold_sp_pred:
+ tp += 1
+ else:
+ fp += 1
+ for e in gold_sp_pred:
+ if e not in cur_sp_pred:
+ fn += 1
+ prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
+ recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
+ f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
+ em = 1.0 if fp + fn == 0 else 0.0
+ metrics['sp_em'] += em
+ metrics['sp_f1'] += f1
+ metrics['sp_prec'] += prec
+ metrics['sp_recall'] += recall
+ return em, prec, recall
+
+
+def hotpotqa_eval(prediction_file, gold_file):
+ """hotpotqa evaluate function"""
+ with open(prediction_file) as f:
+ prediction = json.load(f)
+ with open(gold_file) as f:
+ gold = json.load(f)
+
+ metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
+ 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
+ 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
+ for dp in gold:
+ cur_id = dp['_id']
+ can_eval_joint = True
+ if cur_id not in prediction['answer']:
+ print('missing answer {}'.format(cur_id))
+ can_eval_joint = False
+ else:
+ em, prec, recall = update_answer(
+ metrics, prediction['answer'][cur_id], dp['answer'])
+ if cur_id not in prediction['sp']:
+ print('missing sp fact {}'.format(cur_id))
+ can_eval_joint = False
+ else:
+ sp_em, sp_prec, sp_recall = update_sp(
+ metrics, prediction['sp'][cur_id], dp['supporting_facts'])
+
+ if can_eval_joint:
+ joint_prec = prec * sp_prec
+ joint_recall = recall * sp_recall
+ if joint_prec + joint_recall > 0:
+ joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
+ else:
+ joint_f1 = 0.
+ joint_em = em * sp_em
+
+ metrics['joint_em'] += joint_em
+ metrics['joint_f1'] += joint_f1
+ metrics['joint_prec'] += joint_prec
+ metrics['joint_recall'] += joint_recall
+
+ num = len(gold)
+ for k in metrics:
+ metrics[k] /= num
+
+ return metrics
diff --git a/model_zoo/research/nlp/tprr/src/reader.py b/model_zoo/research/nlp/tprr/src/reader.py
new file mode 100644
index 0000000000..73d6fe8ac7
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/reader.py
@@ -0,0 +1,73 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Reader model"""
+
+import mindspore.nn as nn
+from mindspore import load_checkpoint, load_param_into_net
+from mindspore.ops import BatchMatMul
+from mindspore import ops
+from mindspore import dtype as mstype
+from src.reader_albert_xxlarge import Reader_Albert
+from src.reader_downstream import Reader_Downstream
+
+
+dst_type = mstype.float16
+dst_type2 = mstype.float32
+
+
+class Reader(nn.Cell):
+ """Reader model"""
+ def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
+ """init function"""
+ super(Reader, self).__init__(auto_prefix=False)
+
+ self.encoder = Reader_Albert(batch_size)
+ param_dict = load_checkpoint(encoder_ck_file)
+ not_load_params = load_param_into_net(self.encoder, param_dict)
+ print(f"not loaded: {not_load_params}")
+
+ self.downstream = Reader_Downstream()
+ param_dict = load_checkpoint(downstream_ck_file)
+ not_load_params = load_param_into_net(self.downstream, param_dict)
+ print(f"not loaded: {not_load_params}")
+
+ self.bmm = BatchMatMul()
+
+ def construct(self, input_ids, attn_mask, token_type_ids,
+ context_mask, square_mask, packing_mask, cache_mask,
+ para_start_mapping, sent_end_mapping):
+ """construct function"""
+ state = self.encoder(attn_mask, input_ids, token_type_ids)
+
+ para_state = self.bmm(ops.Cast()(para_start_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, 2, D]
+ sent_state = self.bmm(ops.Cast()(sent_end_mapping, dst_type), ops.Cast()(state, dst_type)) # [B, max_sent, D]
+
+ q_type, start, end, para_logit, sent_logit = self.downstream(ops.Cast()(para_state, dst_type2),
+ ops.Cast()(sent_state, dst_type2),
+ state,
+ context_mask)
+
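+        # combine start/end logits into a span score matrix, mask out invalid spans
+        # (over-long, outside the context, or starting inside the query), then take
+        # the arg-max as the predicted answer span (y1, y2)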
+ outer = start[:, :, None] + end[:, None]
+
+ outer_mask = cache_mask
+ outer_mask = square_mask * outer_mask[None]
+ outer = outer - 1e30 * (1 - outer_mask)
+ outer = outer - 1e30 * packing_mask[:, :, None]
+ max_row = ops.ReduceMax()(outer, 2)
+ y1 = ops.Argmax()(max_row)
+ max_col = ops.ReduceMax()(outer, 1)
+ y2 = ops.Argmax()(max_col)
+
+ return start, end, q_type, para_logit, sent_logit, y1, y2
diff --git a/model_zoo/research/nlp/tprr/src/reader_albert_xxlarge.py b/model_zoo/research/nlp/tprr/src/reader_albert_xxlarge.py
new file mode 100644
index 0000000000..8eb7f75850
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/reader_albert_xxlarge.py
@@ -0,0 +1,263 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""albert-xxlarge Model for reader"""
+
+import numpy as np
+from mindspore import nn, ops
+from mindspore import Tensor, Parameter
+from mindspore.ops import operations as P
+from mindspore import dtype as mstype
+
+dst_type = mstype.float16
+dst_type2 = mstype.float32
+
+
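+# Note: the Parameters below are created with random placeholder values; the actual
+# ALBERT weights are loaded from a checkpoint in src/reader.py via load_param_into_net.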
+class LayerNorm(nn.Cell):
+ """LayerNorm layer"""
+ def __init__(self, mul_7_w_shape, add_8_bias_shape):
+ """init function"""
+ super(LayerNorm, self).__init__()
+ self.reducemean_0 = P.ReduceMean(keep_dims=True)
+ self.sub_1 = P.Sub()
+ self.pow_2 = P.Pow()
+ self.pow_2_input_weight = 2.0
+ self.reducemean_3 = P.ReduceMean(keep_dims=True)
+ self.add_4 = P.Add()
+ self.add_4_bias = 9.999999960041972e-13
+ self.sqrt_5 = P.Sqrt()
+ self.div_6 = P.Div()
+ self.mul_7 = P.Mul()
+ self.mul_7_w = Parameter(Tensor(np.random.uniform(0, 1, mul_7_w_shape).astype(np.float32)), name=None)
+ self.add_8 = P.Add()
+ self.add_8_bias = Parameter(Tensor(np.random.uniform(0, 1, add_8_bias_shape).astype(np.float32)), name=None)
+
+ def construct(self, x):
+ """construct function"""
+ opt_reducemean_0 = self.reducemean_0(x, -1)
+ opt_sub_1 = self.sub_1(x, opt_reducemean_0)
+ opt_pow_2 = self.pow_2(opt_sub_1, self.pow_2_input_weight)
+ opt_reducemean_3 = self.reducemean_3(opt_pow_2, -1)
+ opt_add_4 = self.add_4(opt_reducemean_3, self.add_4_bias)
+ opt_sqrt_5 = self.sqrt_5(opt_add_4)
+ opt_div_6 = self.div_6(opt_sub_1, opt_sqrt_5)
+ opt_mul_7 = self.mul_7(opt_div_6, self.mul_7_w)
+ opt_add_8 = self.add_8(opt_mul_7, self.add_8_bias)
+ return opt_add_8
+
+
+class Linear(nn.Cell):
+ """Linear layer"""
+ def __init__(self, matmul_0_weight_shape, add_1_bias_shape):
+ """init function"""
+ super(Linear, self).__init__()
+ self.matmul_0 = nn.MatMul()
+ self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)),
+ name=None)
+ self.add_1 = P.Add()
+ self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None)
+
+ def construct(self, x):
+ """construct function"""
+ opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
+ opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
+ return opt_add_1
+
+
+class MultiHeadAttn(nn.Cell):
+ """Multi-head attention layer"""
+ def __init__(self, batch_size):
+ """init function"""
+ super(MultiHeadAttn, self).__init__()
+ self.batch_size = batch_size
+ self.matmul_0 = nn.MatMul()
+ self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
+ self.matmul_1 = nn.MatMul()
+ self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
+ self.matmul_2 = nn.MatMul()
+ self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
+ self.add_3 = P.Add()
+ self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.add_4 = P.Add()
+ self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.add_5 = P.Add()
+ self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.reshape_6 = P.Reshape()
+ self.reshape_6_shape = tuple([batch_size, 512, 64, 64])
+ self.reshape_7 = P.Reshape()
+ self.reshape_7_shape = tuple([batch_size, 512, 64, 64])
+ self.reshape_8 = P.Reshape()
+ self.reshape_8_shape = tuple([batch_size, 512, 64, 64])
+ self.transpose_9 = P.Transpose()
+ self.transpose_10 = P.Transpose()
+ self.transpose_11 = P.Transpose()
+ self.matmul_12 = nn.MatMul()
+ self.div_13 = P.Div()
+ self.div_13_w = 8.0
+ self.add_14 = P.Add()
+ self.softmax_15 = nn.Softmax(axis=3)
+ self.matmul_16 = nn.MatMul()
+ self.transpose_17 = P.Transpose()
+ self.matmul_18 = P.MatMul()
+ self.matmul_18_weight = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None)
+ self.add_19 = P.Add()
+ self.add_19_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+
+ def construct(self, x, x0):
+ """construct function"""
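+        # project x to Q/K/V, reshape to (batch, 512, 64 heads, 64 dims), apply scaled
+        # dot-product attention with additive mask x0, then project back to 4096 dims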
+ opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
+ opt_matmul_1 = self.matmul_1(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_1_w, dst_type))
+ opt_matmul_2 = self.matmul_2(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_2_w, dst_type))
+ opt_add_3 = self.add_3(ops.Cast()(opt_matmul_0, dst_type2), self.add_3_bias)
+ opt_add_4 = self.add_4(ops.Cast()(opt_matmul_1, dst_type2), self.add_4_bias)
+ opt_add_5 = self.add_5(ops.Cast()(opt_matmul_2, dst_type2), self.add_5_bias)
+ opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape)
+ opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape)
+ opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape)
+ opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3))
+ opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1))
+ opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3))
+ opt_matmul_12 = self.matmul_12(ops.Cast()(opt_transpose_9, dst_type), ops.Cast()(opt_transpose_10, dst_type))
+ opt_div_13 = self.div_13(ops.Cast()(opt_matmul_12, dst_type2), ops.Cast()(self.div_13_w, dst_type2))
+ opt_add_14 = self.add_14(opt_div_13, x0)
+ opt_softmax_15 = self.softmax_15(opt_add_14)
+ opt_matmul_16 = self.matmul_16(ops.Cast()(opt_softmax_15, dst_type), ops.Cast()(opt_transpose_11, dst_type))
+ opt_transpose_17 = self.transpose_17(ops.Cast()(opt_matmul_16, dst_type2), (0, 2, 1, 3))
+ opt_matmul_18 = self.matmul_18(ops.Cast()(opt_transpose_17, dst_type).view(self.batch_size * 512, -1),
+ ops.Cast()(self.matmul_18_weight, dst_type).view(-1, 4096))\
+ .view(self.batch_size, 512, 4096)
+
+ opt_add_19 = self.add_19(ops.Cast()(opt_matmul_18, dst_type2), self.add_19_bias)
+ return opt_add_19
+
+
+class NewGeLU(nn.Cell):
+ """new gelu layer"""
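+    # tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))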
+ def __init__(self):
+ """init function"""
+ super(NewGeLU, self).__init__()
+ self.mul_0 = P.Mul()
+ self.mul_0_w = 0.5
+ self.pow_1 = P.Pow()
+ self.pow_1_input_weight = 3.0
+ self.mul_2 = P.Mul()
+ self.mul_2_w = 0.044714998453855515
+ self.add_3 = P.Add()
+ self.mul_4 = P.Mul()
+ self.mul_4_w = 0.7978845834732056
+ self.tanh_5 = nn.Tanh()
+ self.add_6 = P.Add()
+ self.add_6_bias = 1.0
+ self.mul_7 = P.Mul()
+
+ def construct(self, x):
+ """construct function"""
+ opt_mul_0 = self.mul_0(x, self.mul_0_w)
+ opt_pow_1 = self.pow_1(x, self.pow_1_input_weight)
+ opt_mul_2 = self.mul_2(opt_pow_1, self.mul_2_w)
+ opt_add_3 = self.add_3(x, opt_mul_2)
+ opt_mul_4 = self.mul_4(opt_add_3, self.mul_4_w)
+ opt_tanh_5 = self.tanh_5(opt_mul_4)
+ opt_add_6 = self.add_6(opt_tanh_5, self.add_6_bias)
+ opt_mul_7 = self.mul_7(opt_mul_0, opt_add_6)
+ return opt_mul_7
+
+
+class TransformerLayer(nn.Cell):
+ """Transformer layer"""
+ def __init__(self, batch_size, layernorm1_0_mul_7_w_shape, layernorm1_0_add_8_bias_shape,
+ linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape,
+ linear3_1_add_1_bias_shape):
+ """init function"""
+ super(TransformerLayer, self).__init__()
+ self.multiheadattn_0 = MultiHeadAttn(batch_size)
+ self.add_0 = P.Add()
+ self.layernorm1_0 = LayerNorm(mul_7_w_shape=layernorm1_0_mul_7_w_shape,
+ add_8_bias_shape=layernorm1_0_add_8_bias_shape)
+ self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape,
+ add_1_bias_shape=linear3_0_add_1_bias_shape)
+ self.newgelu2_0 = NewGeLU()
+ self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape,
+ add_1_bias_shape=linear3_1_add_1_bias_shape)
+ self.add_1 = P.Add()
+
+ def construct(self, x, x0):
+ """construct function"""
+ multiheadattn_0_opt = self.multiheadattn_0(x, x0)
+ opt_add_0 = self.add_0(x, multiheadattn_0_opt)
+ layernorm1_0_opt = self.layernorm1_0(opt_add_0)
+ linear3_0_opt = self.linear3_0(layernorm1_0_opt)
+ newgelu2_0_opt = self.newgelu2_0(linear3_0_opt)
+ linear3_1_opt = self.linear3_1(newgelu2_0_opt)
+ opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt)
+ return opt_add_1
+
+
+class Reader_Albert(nn.Cell):
+ """Albert model for reader"""
+ def __init__(self, batch_size):
+ """init function"""
+ super(Reader_Albert, self).__init__()
+ self.expanddims_0 = P.ExpandDims()
+ self.expanddims_0_axis = 1
+ self.expanddims_3 = P.ExpandDims()
+ self.expanddims_3_axis = 2
+ self.cast_5 = P.Cast()
+ self.cast_5_to = mstype.float32
+ self.sub_7 = P.Sub()
+ self.sub_7_bias = 1.0
+ self.mul_9 = P.Mul()
+ self.mul_9_w = -10000.0
+ self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)),
+ name=None)
+ self.gather_1_axis = 0
+ self.gather_1 = P.Gather()
+ self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None)
+ self.gather_2_axis = 0
+ self.gather_2 = P.Gather()
+ self.add_4 = P.Add()
+ self.add_6 = P.Add()
+ self.add_6_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)), name=None)
+ self.layernorm1_0 = LayerNorm(mul_7_w_shape=(128,), add_8_bias_shape=(128,))
+ self.linear3_0 = Linear(matmul_0_weight_shape=(128, 4096), add_1_bias_shape=(4096,))
+ self.module34_0 = TransformerLayer(batch_size,
+ layernorm1_0_mul_7_w_shape=(4096,),
+ layernorm1_0_add_8_bias_shape=(4096,),
+ linear3_0_matmul_0_weight_shape=(4096, 16384),
+ linear3_0_add_1_bias_shape=(16384,),
+ linear3_1_matmul_0_weight_shape=(16384, 4096),
+ linear3_1_add_1_bias_shape=(4096,))
+ self.layernorm1_1 = LayerNorm(mul_7_w_shape=(4096,), add_8_bias_shape=(4096,))
+
+ def construct(self, x, x0, x1):
+ """construct function"""
+ opt_expanddims_0 = self.expanddims_0(x, self.expanddims_0_axis)
+ opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis)
+ opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to)
+ opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5)
+ opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w)
+ opt_gather_1_axis = self.gather_1_axis
+ opt_gather_1 = self.gather_1(self.gather_1_input_weight, x0, opt_gather_1_axis)
+ opt_gather_2_axis = self.gather_2_axis
+ opt_gather_2 = self.gather_2(self.gather_2_input_weight, x1, opt_gather_2_axis)
+ opt_add_4 = self.add_4(opt_gather_1, opt_gather_2)
+ opt_add_6 = self.add_6(opt_add_4, self.add_6_bias)
+ layernorm1_0_opt = self.layernorm1_0(opt_add_6)
+ linear3_0_opt = self.linear3_0(layernorm1_0_opt)
+ module34_0_opt = self.module34_0(linear3_0_opt, opt_mul_9)
+ out = self.layernorm1_1(module34_0_opt)
+ for _ in range(11):
+ out = self.module34_0(out, opt_mul_9)
+ out = self.layernorm1_1(out)
+ return out
diff --git a/model_zoo/research/nlp/tprr/src/reader_downstream.py b/model_zoo/research/nlp/tprr/src/reader_downstream.py
new file mode 100644
index 0000000000..f971b0f522
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/reader_downstream.py
@@ -0,0 +1,213 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""downstream Model for reader"""
+
+import numpy as np
+from mindspore import nn, ops
+from mindspore import Tensor, Parameter
+from mindspore.ops import operations as P
+from mindspore import dtype as mstype
+
+
+dst_type = mstype.float16
+dst_type2 = mstype.float32
+
+
+class Module15(nn.Cell):
+ """module of reader downstream"""
+ def __init__(self, matmul_0_weight_shape, add_1_bias_shape):
+ """init function"""
+ super(Module15, self).__init__()
+ self.matmul_0 = nn.MatMul()
+ self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)),
+ name=None)
+ self.add_1 = P.Add()
+ self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None)
+ self.relu_2 = nn.ReLU()
+
+ def construct(self, x):
+ """construct function"""
+ opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
+ opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
+ opt_relu_2 = self.relu_2(opt_add_1)
+ return opt_relu_2
+
+
+class NormModule(nn.Cell):
+ """Normalization module of reader downstream"""
+ def __init__(self, mul_8_w_shape, add_9_bias_shape):
+ """init function"""
+ super(NormModule, self).__init__()
+ self.reducemean_0 = P.ReduceMean(keep_dims=True)
+ self.sub_1 = P.Sub()
+ self.sub_2 = P.Sub()
+ self.pow_3 = P.Pow()
+ self.pow_3_input_weight = 2.0
+ self.reducemean_4 = P.ReduceMean(keep_dims=True)
+ self.add_5 = P.Add()
+ self.add_5_bias = 9.999999960041972e-13
+ self.sqrt_6 = P.Sqrt()
+ self.div_7 = P.Div()
+ self.mul_8 = P.Mul()
+ self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, mul_8_w_shape).astype(np.float32)), name=None)
+ self.add_9 = P.Add()
+ self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, add_9_bias_shape).astype(np.float32)), name=None)
+
+ def construct(self, x):
+ """construct function"""
+ opt_reducemean_0 = self.reducemean_0(x, -1)
+ opt_sub_1 = self.sub_1(x, opt_reducemean_0)
+ opt_sub_2 = self.sub_2(x, opt_reducemean_0)
+ opt_pow_3 = self.pow_3(opt_sub_1, self.pow_3_input_weight)
+ opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1)
+ opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias)
+ opt_sqrt_6 = self.sqrt_6(opt_add_5)
+ opt_div_7 = self.div_7(opt_sub_2, opt_sqrt_6)
+ opt_mul_8 = self.mul_8(self.mul_8_w, opt_div_7)
+ opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias)
+ return opt_add_9
+
+
+class Module16(nn.Cell):
+ """module of reader downstream"""
+ def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape,
+ normmodule_0_add_9_bias_shape):
+ """init function"""
+ super(Module16, self).__init__()
+ self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape,
+ add_1_bias_shape=module15_0_add_1_bias_shape)
+ self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape,
+ add_9_bias_shape=normmodule_0_add_9_bias_shape)
+ self.matmul_0 = nn.MatMul()
+ self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (8192, 1)).astype(np.float32)), name=None)
+
+ def construct(self, x):
+ """construct function"""
+ module15_0_opt = self.module15_0(x)
+ normmodule_0_opt = self.normmodule_0(module15_0_opt)
+ opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
+ return ops.Cast()(opt_matmul_0, dst_type2)
+
+
+class Module17(nn.Cell):
+ """module of reader downstream"""
+ def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape,
+ normmodule_0_add_9_bias_shape):
+ """init function"""
+ super(Module17, self).__init__()
+ self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape,
+ add_1_bias_shape=module15_0_add_1_bias_shape)
+ self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape,
+ add_9_bias_shape=normmodule_0_add_9_bias_shape)
+ self.matmul_0 = nn.MatMul()
+ self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 1)).astype(np.float32)), name=None)
+ self.add_1 = P.Add()
+ self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
+
+ def construct(self, x):
+ """construct function"""
+ module15_0_opt = self.module15_0(x)
+ normmodule_0_opt = self.normmodule_0(module15_0_opt)
+ opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
+ opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
+ return opt_add_1
+
+
+class Module5(nn.Cell):
+ """module of reader downstream"""
+ def __init__(self):
+ """init function"""
+ super(Module5, self).__init__()
+ self.sub_0 = P.Sub()
+ self.sub_0_bias = 1.0
+ self.mul_1 = P.Mul()
+ self.mul_1_w = 1.0000000150474662e+30
+
+ def construct(self, x):
+ """construct function"""
+ opt_sub_0 = self.sub_0(self.sub_0_bias, x)
+ opt_mul_1 = self.mul_1(opt_sub_0, self.mul_1_w)
+ return opt_mul_1
+
+
+class Module10(nn.Cell):
+ """module of reader downstream"""
+ def __init__(self):
+ """init function"""
+ super(Module10, self).__init__()
+ self.squeeze_0 = P.Squeeze(2)
+ self.module5_0 = Module5()
+ self.sub_1 = P.Sub()
+
+ def construct(self, x, x0):
+ """construct function"""
+ opt_squeeze_0 = self.squeeze_0(x)
+ module5_0_opt = self.module5_0(x0)
+ opt_sub_1 = self.sub_1(opt_squeeze_0, module5_0_opt)
+ return opt_sub_1
+
+
+class Reader_Downstream(nn.Cell):
+ """Downstream model for reader"""
+ def __init__(self):
+ """init function"""
+ super(Reader_Downstream, self).__init__()
+ self.module16_0 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192),
+ module15_0_add_1_bias_shape=(8192,),
+ normmodule_0_mul_8_w_shape=(8192,),
+ normmodule_0_add_9_bias_shape=(8192,))
+ self.add_74 = P.Add()
+ self.add_74_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
+ self.module16_1 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192),
+ module15_0_add_1_bias_shape=(8192,),
+ normmodule_0_mul_8_w_shape=(8192,),
+ normmodule_0_add_9_bias_shape=(8192,))
+ self.add_75 = P.Add()
+ self.add_75_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None)
+ self.module17_0 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096),
+ module15_0_add_1_bias_shape=(4096,),
+ normmodule_0_mul_8_w_shape=(4096,),
+ normmodule_0_add_9_bias_shape=(4096,))
+ self.module10_0 = Module10()
+ self.module17_1 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096),
+ module15_0_add_1_bias_shape=(4096,),
+ normmodule_0_mul_8_w_shape=(4096,),
+ normmodule_0_add_9_bias_shape=(4096,))
+ self.module10_1 = Module10()
+ self.gather_6_input_weight = Tensor(np.array(0))
+ self.gather_6_axis = 1
+ self.gather_6 = P.Gather()
+ self.dense_13 = nn.Dense(in_channels=4096, out_channels=4096, has_bias=True)
+ self.relu_18 = nn.ReLU()
+ self.normmodule_0 = NormModule(mul_8_w_shape=(4096,), add_9_bias_shape=(4096,))
+ self.dense_73 = nn.Dense(in_channels=4096, out_channels=3, has_bias=True)
+
+ def construct(self, x, x0, x1, x2):
+ """construct function"""
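+        # x: paragraph states, x0: sentence states, x1: full sequence states, x2: context mask;
+        # returns (answer-type logits, start logits, end logits, paragraph logits, sentence logits)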
+ module16_0_opt = self.module16_0(x)
+ opt_add_74 = self.add_74(module16_0_opt, self.add_74_bias)
+ module16_1_opt = self.module16_1(x0)
+ opt_add_75 = self.add_75(module16_1_opt, self.add_75_bias)
+ module17_0_opt = self.module17_0(x1)
+ opt_module10_0 = self.module10_0(module17_0_opt, x2)
+ module17_1_opt = self.module17_1(x1)
+ opt_module10_1 = self.module10_1(module17_1_opt, x2)
+ opt_gather_6_axis = self.gather_6_axis
+ opt_gather_6 = self.gather_6(x1, self.gather_6_input_weight, opt_gather_6_axis)
+ opt_dense_13 = self.dense_13(opt_gather_6)
+ opt_relu_18 = self.relu_18(opt_dense_13)
+ normmodule_0_opt = self.normmodule_0(opt_relu_18)
+ opt_dense_73 = self.dense_73(normmodule_0_opt)
+ return opt_dense_73, opt_module10_0, opt_module10_1, opt_add_74, opt_add_75
diff --git a/model_zoo/research/nlp/tprr/src/reader_eval.py b/model_zoo/research/nlp/tprr/src/reader_eval.py
new file mode 100644
index 0000000000..e39a9055fb
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/reader_eval.py
@@ -0,0 +1,142 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""execute reader"""
+
+from collections import defaultdict
+import random
+from time import time
+import json
+from tqdm import tqdm
+import numpy as np
+
+from transformers import AlbertTokenizer
+
+from mindspore import Tensor, ops
+from mindspore import dtype as mstype
+
+from src.rerank_and_reader_data_generator import DataGenerator
+from src.rerank_and_reader_utils import convert_to_tokens, make_wiki_id, DocDB
+from src.reader import Reader
+
+
+def read(args):
+ """reader function"""
+ db_file = args.wiki_db_file
+ reader_feature_file = args.reader_feature_file
+ reader_example_file = args.reader_example_file
+ encoder_ck_file = args.reader_encoder_ck_file
+ downstream_ck_file = args.reader_downstream_ck_file
+ albert_model_path = args.albert_model_path
+ reader_result_file = args.reader_result_file
+ seed = args.seed
+ sp_threshold = args.sp_threshold
+ seq_len = args.seq_len
+ batch_size = args.reader_batch_size
+ para_limit = args.max_para_num
+ sent_limit = args.max_sent_num
+
+ random.seed(seed)
+ np.random.seed(seed)
+
+ t1 = time()
+
+ doc_db = DocDB(db_file)
+
+ generator = DataGenerator(feature_file_path=reader_feature_file,
+ example_file_path=reader_example_file,
+ batch_size=batch_size, seq_len=seq_len,
+ para_limit=para_limit, sent_limit=sent_limit,
+ task_type="reader")
+ example_dict = generator.example_dict
+ feature_dict = generator.feature_dict
+ answer_dict = defaultdict(lambda: defaultdict(list))
+ new_answer_dict = {}
+ total_sp_dict = defaultdict(list)
+ new_total_sp_dict = defaultdict(list)
+
+ tokenizer = AlbertTokenizer.from_pretrained(albert_model_path)
+    new_tokens = ['[q]', '[/q]', '<t>', '</t>', '[s]']
+ tokenizer.add_tokens(new_tokens)
+
+ reader = Reader(batch_size=batch_size,
+ encoder_ck_file=encoder_ck_file,
+ downstream_ck_file=downstream_ck_file)
+
+ print("start reading ...")
+
+ for _, batch in tqdm(enumerate(generator)):
+ input_ids = Tensor(batch["context_idxs"], mstype.int32)
+ attn_mask = Tensor(batch["context_mask"], mstype.int32)
+ token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)
+ context_mask = Tensor(batch["context_mask"], mstype.float32)
+ square_mask = Tensor(batch["square_mask"], mstype.float32)
+ packing_mask = Tensor(batch["query_mapping"], mstype.float32)
+ para_start_mapping = Tensor(batch["para_start_mapping"], mstype.float32)
+ sent_end_mapping = Tensor(batch["sent_end_mapping"], mstype.float32)
+ unique_ids = batch["unique_ids"]
+ sent_names = batch["sent_names"]
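+        # cache_mask is a band matrix with ones where 0 <= j - i <= 30, presumably limiting
+        # predicted answer spans to at most 30 tokens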
+ cache_mask = Tensor(np.tril(np.triu(np.ones((seq_len, seq_len)), 0), 30), mstype.float32)
+
+ _, _, q_type, _, sent_logit, y1, y2 = reader(input_ids, attn_mask, token_type_ids,
+ context_mask, square_mask, packing_mask, cache_mask,
+ para_start_mapping, sent_end_mapping)
+
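+        # q_type classifies the answer as yes / no / span, y1 and y2 are the span start/end
+        # positions, and sent_logit scores each candidate supporting sentence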
+ type_prob = ops.Softmax()(q_type).asnumpy()
+
+ answer_dict_ = convert_to_tokens(example_dict,
+ feature_dict,
+ batch['ids'],
+ y1.asnumpy().tolist(),
+ y2.asnumpy().tolist(),
+ type_prob,
+ tokenizer,
+ sent_logit.asnumpy(),
+ sent_names,
+ unique_ids)
+ for q_id in answer_dict_:
+ answer_dict[q_id] = answer_dict_[q_id]
+
+ for q_id in answer_dict:
+ res = answer_dict[q_id]
+ answer_text_ = res[0]
+ sent_ = res[1]
+ sent_names_ = res[2]
+ new_answer_dict[q_id] = answer_text_
+
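+        # sentences whose sigmoid score exceeds sp_threshold are kept as predicted supporting facts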
+ predict_support_np = ops.Sigmoid()(Tensor(sent_, mstype.float32)).asnumpy()
+
+ for j in range(predict_support_np.shape[0]):
+ if j >= len(sent_names_):
+ break
+ if predict_support_np[j] > sp_threshold:
+ total_sp_dict[q_id].append(sent_names_[j])
+
+ for _id in total_sp_dict:
+ _sent_names = total_sp_dict[_id]
+ for para in _sent_names:
+ title = make_wiki_id(para[0], 0)
+ para_original_title = doc_db.get_doc_info(title)[-1]
+ para[0] = para_original_title
+ new_total_sp_dict[_id].append(para)
+
+ prediction = {'answer': new_answer_dict,
+ 'sp': new_total_sp_dict}
+
+ with open(reader_result_file, 'w') as f:
+ json.dump(prediction, f, indent=4)
+
+ t2 = time()
+
+ print(f"reader cost time: {t2-t1} s")
diff --git a/model_zoo/research/nlp/tprr/src/rerank_albert_xxlarge.py b/model_zoo/research/nlp/tprr/src/rerank_albert_xxlarge.py
new file mode 100644
index 0000000000..965b9b352a
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/rerank_albert_xxlarge.py
@@ -0,0 +1,276 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""albert-xxlarge Model for reranker"""
+
+import numpy as np
+from mindspore import nn, ops
+from mindspore import Tensor, Parameter
+from mindspore.ops import operations as P
+from mindspore import dtype as mstype
+
+
+dst_type = mstype.float16
+dst_type2 = mstype.float32
+
+
+class LayerNorm(nn.Cell):
+ """LayerNorm layer"""
+ def __init__(self, passthrough_w_0, passthrough_w_1):
+ """init function"""
+ super(LayerNorm, self).__init__()
+ self.reducemean_0 = P.ReduceMean(keep_dims=True)
+ self.sub_1 = P.Sub()
+ self.pow_2 = P.Pow()
+ self.pow_2_input_weight = 2.0
+ self.reducemean_3 = P.ReduceMean(keep_dims=True)
+ self.add_4 = P.Add()
+ self.add_4_bias = 9.999999960041972e-13
+ self.sqrt_5 = P.Sqrt()
+ self.div_6 = P.Div()
+ self.mul_7 = P.Mul()
+ self.mul_7_w = passthrough_w_0
+ self.add_8 = P.Add()
+ self.add_8_bias = passthrough_w_1
+
+ def construct(self, x):
+ """construct function"""
+ opt_reducemean_0 = self.reducemean_0(x, -1)
+ opt_sub_1 = self.sub_1(x, opt_reducemean_0)
+ opt_pow_2 = self.pow_2(opt_sub_1, self.pow_2_input_weight)
+ opt_reducemean_3 = self.reducemean_3(opt_pow_2, -1)
+ opt_add_4 = self.add_4(opt_reducemean_3, self.add_4_bias)
+ opt_sqrt_5 = self.sqrt_5(opt_add_4)
+ opt_div_6 = self.div_6(opt_sub_1, opt_sqrt_5)
+ opt_mul_7 = self.mul_7(opt_div_6, self.mul_7_w)
+ opt_add_8 = self.add_8(opt_mul_7, self.add_8_bias)
+ return opt_add_8
+
+
+class Linear(nn.Cell):
+ """Linear layer"""
+ def __init__(self, matmul_0_w_shape, passthrough_w_0):
+ """init function"""
+ super(Linear, self).__init__()
+ self.matmul_0 = nn.MatMul()
+ self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_w_shape).astype(np.float32)), name=None)
+ self.add_1 = P.Add()
+ self.add_1_bias = passthrough_w_0
+
+ def construct(self, x):
+ """construct function"""
+ opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
+ opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias)
+ return opt_add_1
+
+
+class MultiHeadAttn(nn.Cell):
+ """Multi-head attention layer"""
+ def __init__(self, batch_size, passthrough_w_0, passthrough_w_1, passthrough_w_2):
+ """init function"""
+ super(MultiHeadAttn, self).__init__()
+ self.batch_size = batch_size
+ self.matmul_0 = nn.MatMul()
+ self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
+ self.matmul_1 = nn.MatMul()
+ self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
+ self.matmul_2 = nn.MatMul()
+ self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None)
+ self.add_3 = P.Add()
+ self.add_3_bias = passthrough_w_0
+ self.add_4 = P.Add()
+ self.add_4_bias = passthrough_w_1
+ self.add_5 = P.Add()
+ self.add_5_bias = passthrough_w_2
+ self.reshape_6 = P.Reshape()
+ self.reshape_6_shape = tuple([batch_size, 512, 64, 64])
+ self.reshape_7 = P.Reshape()
+ self.reshape_7_shape = tuple([batch_size, 512, 64, 64])
+ self.reshape_8 = P.Reshape()
+ self.reshape_8_shape = tuple([batch_size, 512, 64, 64])
+ self.transpose_9 = P.Transpose()
+ self.transpose_10 = P.Transpose()
+ self.transpose_11 = P.Transpose()
+ self.matmul_12 = nn.MatMul()
+ self.div_13 = P.Div()
+ self.div_13_w = 8.0
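+        # 8.0 = sqrt(64), the per-head dimension, used to scale the attention scores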
+ self.add_14 = P.Add()
+ self.softmax_15 = nn.Softmax(axis=3)
+ self.matmul_16 = nn.MatMul()
+ self.transpose_17 = P.Transpose()
+ self.matmul_18 = P.MatMul()
+ self.matmul_18_weight = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None)
+ self.add_19 = P.Add()
+ self.add_19_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+
+ def construct(self, x, x0):
+ """construct function"""
+ opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type))
+ opt_matmul_1 = self.matmul_1(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_1_w, dst_type))
+ opt_matmul_2 = self.matmul_2(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_2_w, dst_type))
+ opt_add_3 = self.add_3(ops.Cast()(opt_matmul_0, dst_type2), self.add_3_bias)
+ opt_add_4 = self.add_4(ops.Cast()(opt_matmul_1, dst_type2), self.add_4_bias)
+ opt_add_5 = self.add_5(ops.Cast()(opt_matmul_2, dst_type2), self.add_5_bias)
+ opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape)
+ opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape)
+ opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape)
+ opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3))
+ opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1))
+ opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3))
+ opt_matmul_12 = self.matmul_12(ops.Cast()(opt_transpose_9, dst_type), ops.Cast()(opt_transpose_10, dst_type))
+ opt_div_13 = self.div_13(ops.Cast()(opt_matmul_12, dst_type2), ops.Cast()(self.div_13_w, dst_type2))
+ opt_add_14 = self.add_14(opt_div_13, x0)
+ opt_softmax_15 = self.softmax_15(opt_add_14)
+ opt_matmul_16 = self.matmul_16(ops.Cast()(opt_softmax_15, dst_type), ops.Cast()(opt_transpose_11, dst_type))
+ opt_transpose_17 = self.transpose_17(ops.Cast()(opt_matmul_16, dst_type2), (0, 2, 1, 3))
+ opt_matmul_18 = self.matmul_18(ops.Cast()(opt_transpose_17, dst_type).view(self.batch_size * 512, -1),
+ ops.Cast()(self.matmul_18_weight, dst_type).view(-1, 4096))\
+ .view(self.batch_size, 512, 4096)
+ opt_add_19 = self.add_19(ops.Cast()(opt_matmul_18, dst_type2), self.add_19_bias)
+ return opt_add_19
+
+
+class NewGeLU(nn.Cell):
+ """Gelu layer"""
+ def __init__(self):
+ """init function"""
+ super(NewGeLU, self).__init__()
+ self.mul_0 = P.Mul()
+ self.mul_0_w = 0.5
+ self.pow_1 = P.Pow()
+ self.pow_1_input_weight = 3.0
+ self.mul_2 = P.Mul()
+ self.mul_2_w = 0.044714998453855515
+ self.add_3 = P.Add()
+ self.mul_4 = P.Mul()
+ self.mul_4_w = 0.7978845834732056
+ self.tanh_5 = nn.Tanh()
+ self.add_6 = P.Add()
+ self.add_6_bias = 1.0
+ self.mul_7 = P.Mul()
+
+ def construct(self, x):
+ """construct function"""
+ opt_mul_0 = self.mul_0(x, self.mul_0_w)
+ opt_pow_1 = self.pow_1(x, self.pow_1_input_weight)
+ opt_mul_2 = self.mul_2(opt_pow_1, self.mul_2_w)
+ opt_add_3 = self.add_3(x, opt_mul_2)
+ opt_mul_4 = self.mul_4(opt_add_3, self.mul_4_w)
+ opt_tanh_5 = self.tanh_5(opt_mul_4)
+ opt_add_6 = self.add_6(opt_tanh_5, self.add_6_bias)
+ opt_mul_7 = self.mul_7(opt_mul_0, opt_add_6)
+ return opt_mul_7
+
+
+class TransformerLayerWithLayerNorm(nn.Cell):
+ """Transformer layer with LayerNOrm"""
+ def __init__(self, batch_size, linear3_0_matmul_0_w_shape, linear3_1_matmul_0_w_shape, passthrough_w_0,
+ passthrough_w_1, passthrough_w_2, passthrough_w_3, passthrough_w_4, passthrough_w_5, passthrough_w_6):
+ """init function"""
+ super(TransformerLayerWithLayerNorm, self).__init__()
+ self.multiheadattn_0 = MultiHeadAttn(batch_size=batch_size,
+ passthrough_w_0=passthrough_w_0,
+ passthrough_w_1=passthrough_w_1,
+ passthrough_w_2=passthrough_w_2)
+ self.add_0 = P.Add()
+ self.layernorm1_0 = LayerNorm(passthrough_w_0=passthrough_w_3, passthrough_w_1=passthrough_w_4)
+ self.linear3_0 = Linear(matmul_0_w_shape=linear3_0_matmul_0_w_shape, passthrough_w_0=passthrough_w_5)
+ self.newgelu2_0 = NewGeLU()
+ self.linear3_1 = Linear(matmul_0_w_shape=linear3_1_matmul_0_w_shape, passthrough_w_0=passthrough_w_6)
+ self.add_1 = P.Add()
+
+ def construct(self, x, x0):
+ """construct function"""
+ multiheadattn_0_opt = self.multiheadattn_0(x, x0)
+ opt_add_0 = self.add_0(x, multiheadattn_0_opt)
+ layernorm1_0_opt = self.layernorm1_0(opt_add_0)
+ linear3_0_opt = self.linear3_0(layernorm1_0_opt)
+ newgelu2_0_opt = self.newgelu2_0(linear3_0_opt)
+ linear3_1_opt = self.linear3_1(newgelu2_0_opt)
+ opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt)
+ return opt_add_1
+
+
+class Rerank_Albert(nn.Cell):
+ """Albert model for rerank"""
+ def __init__(self, batch_size):
+ """init function"""
+ super(Rerank_Albert, self).__init__()
+ self.passthrough_w_0 = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None)
+ self.passthrough_w_1 = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None)
+ self.passthrough_w_2 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.passthrough_w_3 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.passthrough_w_4 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.passthrough_w_5 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.passthrough_w_6 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.passthrough_w_7 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.passthrough_w_8 = Parameter(Tensor(np.random.uniform(0, 1, (16384,)).astype(np.float32)), name=None)
+ self.passthrough_w_9 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.passthrough_w_10 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.passthrough_w_11 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
+ self.expanddims_0 = P.ExpandDims()
+ self.expanddims_0_axis = 1
+ self.expanddims_3 = P.ExpandDims()
+ self.expanddims_3_axis = 2
+ self.cast_5 = P.Cast()
+ self.cast_5_to = mstype.float32
+ self.sub_7 = P.Sub()
+ self.sub_7_bias = 1.0
+ self.mul_9 = P.Mul()
+ self.mul_9_w = -10000.0
+ self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)),
+ name=None)
+ self.gather_1_axis = 0
+ self.gather_1 = P.Gather()
+ self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None)
+ self.gather_2_axis = 0
+ self.gather_2 = P.Gather()
+ self.add_4 = P.Add()
+ self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)), name=None)
+ self.add_6 = P.Add()
+ self.layernorm1_0 = LayerNorm(passthrough_w_0=self.passthrough_w_0, passthrough_w_1=self.passthrough_w_1)
+ self.linear3_0 = Linear(matmul_0_w_shape=(128, 4096), passthrough_w_0=self.passthrough_w_2)
+ self.module34_0 = TransformerLayerWithLayerNorm(batch_size=batch_size,
+ linear3_0_matmul_0_w_shape=(4096, 16384),
+ linear3_1_matmul_0_w_shape=(16384, 4096),
+ passthrough_w_0=self.passthrough_w_3,
+ passthrough_w_1=self.passthrough_w_4,
+ passthrough_w_2=self.passthrough_w_5,
+ passthrough_w_3=self.passthrough_w_6,
+ passthrough_w_4=self.passthrough_w_7,
+ passthrough_w_5=self.passthrough_w_8,
+ passthrough_w_6=self.passthrough_w_9)
+ self.layernorm1_1 = LayerNorm(passthrough_w_0=self.passthrough_w_10, passthrough_w_1=self.passthrough_w_11)
+
+ def construct(self, input_ids, attention_mask, token_type_ids):
+ """construct function"""
+ opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis)
+ opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis)
+ opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to)
+ opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5)
+ opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w)
+ opt_gather_1_axis = self.gather_1_axis
+ opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis)
+ opt_gather_2_axis = self.gather_2_axis
+ opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis)
+ opt_add_4 = self.add_4(opt_gather_1, self.add_4_bias)
+ opt_add_6 = self.add_6(opt_add_4, opt_gather_2)
+ layernorm1_0_opt = self.layernorm1_0(opt_add_6)
+ linear3_0_opt = self.linear3_0(layernorm1_0_opt)
+ opt = self.module34_0(linear3_0_opt, opt_mul_9)
+ opt = self.layernorm1_1(opt)
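+        # ALBERT-style parameter sharing: the same transformer block and LayerNorm are reused for
+        # all 12 layers (one call above plus eleven more below)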
+ for _ in range(11):
+ opt = self.module34_0(opt, opt_mul_9)
+ opt = self.layernorm1_1(opt)
+ return opt
diff --git a/model_zoo/research/nlp/tprr/src/rerank_and_reader_data_generator.py b/model_zoo/research/nlp/tprr/src/rerank_and_reader_data_generator.py
new file mode 100644
index 0000000000..fd9d83921e
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/rerank_and_reader_data_generator.py
@@ -0,0 +1,183 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""define a data generator"""
+
+import gzip
+import pickle
+import random
+import numpy as np
+
+
+random.seed(42)
+np.random.seed(42)
+
+
+class DataGenerator:
+ """data generator for reranker and reader"""
+ def __init__(self, feature_file_path, example_file_path, batch_size, seq_len,
+ para_limit=None, sent_limit=None, task_type=None):
+ """init function"""
+ self.example_ptr = 0
+ self.bsz = batch_size
+ self.seq_length = seq_len
+ self.para_limit = para_limit
+ self.sent_limit = sent_limit
+ self.task_type = task_type
+
+ self.feature_file_path = feature_file_path
+ self.example_file_path = example_file_path
+ self.features = self.load_features()
+ self.examples = self.load_examples()
+ self.feature_dict = self.get_feature_dict()
+ self.example_dict = self.get_example_dict()
+
+ self.features = self.padding_feature(self.features, self.bsz)
+
+ def load_features(self):
+ """load features from feature file"""
+ with gzip.open(self.feature_file_path, 'rb') as fin:
+ features = pickle.load(fin)
+ print("load features successful !!!")
+ return features
+
+ def padding_feature(self, features, bsz):
+ """padding features as multiples of batch size"""
+ padding_num = ((len(features) // bsz + 1) * bsz - len(features))
+ print(f"features padding num is {padding_num}")
+ new_features = features + features[:padding_num]
+ return new_features
+
+ def load_examples(self):
+ """laod examples from file"""
+ if self.example_file_path:
+ with gzip.open(self.example_file_path, 'rb') as fin:
+ examples = pickle.load(fin)
+ print("load examples successful !!!")
+ return examples
+ return {}
+
+ def get_feature_dict(self):
+ """build a feature dict"""
+ return {f.unique_id: f for f in self.features}
+
+ def get_example_dict(self):
+ """build a example dict"""
+ if self.example_file_path:
+ return {e.unique_id: e for e in self.examples}
+ return {}
+
+ def common_process_single_case(self, i, case, context_idxs, context_mask, segment_idxs, ids, path, unique_ids):
+ """common process for a single case"""
+ context_idxs[i] = np.array(case.doc_input_ids)
+ context_mask[i] = np.array(case.doc_input_mask)
+ segment_idxs[i] = np.array(case.doc_segment_ids)
+
+ ids.append(case.qas_id)
+ path.append(case.path)
+ unique_ids.append(case.unique_id)
+
+ return context_idxs, context_mask, segment_idxs, ids, path, unique_ids
+
+ def reader_process_single_case(self, i, case, sent_names, square_mask, query_mapping, ques_start_mapping,
+ para_start_mapping, sent_end_mapping):
+ """process for a single case about reader"""
+ sent_names.append(case.sent_names)
+ prev_position = None
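+        # ids >= 30000 are the marker tokens added beyond the original ALBERT vocabulary; square_mask
+        # gets a block of ones between consecutive markers, presumably so each segment attends to itself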
+ for cur_position, token_id in enumerate(case.doc_input_ids):
+ if token_id >= 30000:
+ if prev_position:
+ square_mask[i, prev_position + 1: cur_position, prev_position + 1: cur_position] = 1.0
+ prev_position = cur_position
+ if case.sent_spans:
+ for j in range(case.sent_spans[0][0] - 1):
+ query_mapping[i, j] = 1
+ ques_start_mapping[i, 0, 1] = 1
+ for j, para_span in enumerate(case.para_spans[:self.para_limit]):
+ start, end, _ = para_span
+ if start <= end:
+ para_start_mapping[i, j, start] = 1
+ for j, sent_span in enumerate(case.sent_spans[:self.sent_limit]):
+ start, end = sent_span
+ if start <= end:
+ end = min(end, self.seq_length - 1)
+ sent_end_mapping[i, j, end] = 1
+ return sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping
+
+ def __iter__(self):
+ """iteration function"""
+ while True:
+ if self.example_ptr >= len(self.features):
+ break
+ start_id = self.example_ptr
+ cur_bsz = min(self.bsz, len(self.features) - start_id)
+ cur_batch = self.features[start_id: start_id + cur_bsz]
+ # BERT input
+ context_idxs = np.zeros((cur_bsz, self.seq_length))
+ context_mask = np.zeros((cur_bsz, self.seq_length))
+ segment_idxs = np.zeros((cur_bsz, self.seq_length))
+
+ # others
+ ids = []
+ path = []
+ unique_ids = []
+
+ if self.task_type == "reader":
+ # Mappings
+ ques_start_mapping = np.zeros((cur_bsz, 1, self.seq_length))
+ query_mapping = np.zeros((cur_bsz, self.seq_length))
+ para_start_mapping = np.zeros((cur_bsz, self.para_limit, self.seq_length))
+ sent_end_mapping = np.zeros((cur_bsz, self.sent_limit, self.seq_length))
+ square_mask = np.zeros((cur_bsz, self.seq_length, self.seq_length))
+ sent_names = []
+
+ for i, case in enumerate(cur_batch):
+ context_idxs, context_mask, segment_idxs, ids, path, unique_ids = \
+ self.common_process_single_case(i, case, context_idxs, context_mask, segment_idxs, ids, path,
+ unique_ids)
+ if self.task_type == "reader":
+ sent_names, square_mask, query_mapping, ques_start_mapping, para_start_mapping, sent_end_mapping = \
+ self.reader_process_single_case(i, case, sent_names, square_mask, query_mapping,
+ ques_start_mapping, para_start_mapping, sent_end_mapping)
+
+ self.example_ptr += cur_bsz
+
+ if self.task_type == "reranker":
+ yield {
+ "context_idxs": context_idxs,
+ "context_mask": context_mask,
+ "segment_idxs": segment_idxs,
+
+ "ids": ids,
+ "unique_ids": unique_ids,
+ "path": path
+ }
+ elif self.task_type == "reader":
+ yield {
+ "context_idxs": context_idxs,
+ "context_mask": context_mask,
+ "segment_idxs": segment_idxs,
+ "query_mapping": query_mapping,
+ "para_start_mapping": para_start_mapping,
+ "sent_end_mapping": sent_end_mapping,
+ "square_mask": square_mask,
+ "ques_start_mapping": ques_start_mapping,
+
+ "ids": ids,
+ "unique_ids": unique_ids,
+ "sent_names": sent_names,
+ "path": path
+ }
+ else:
+ print(f"data generator received a error type: {self.task_type} !!!")
diff --git a/model_zoo/research/nlp/tprr/src/rerank_and_reader_utils.py b/model_zoo/research/nlp/tprr/src/rerank_and_reader_utils.py
new file mode 100644
index 0000000000..bd4dba3762
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/rerank_and_reader_utils.py
@@ -0,0 +1,656 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""utils"""
+
+import re
+import argparse
+from urllib.parse import unquote
+from collections import defaultdict
+import collections
+import logging
+import unicodedata
+import json
+import gzip
+import string
+import pickle
+import sqlite3
+from tqdm import tqdm
+
+import numpy as np
+from transformers import BasicTokenizer
+
+
+logger = logging.getLogger(__name__)
+
+
+class Example:
+ """A single example of data"""
+ def __init__(self,
+ qas_id,
+ path,
+ unique_id,
+ question_tokens,
+ doc_tokens,
+ sent_names,
+ sup_fact_id,
+ sup_para_id,
+ para_start_end_position,
+ sent_start_end_position,
+ question_text,
+ title_start_end_position=None):
+ """init function"""
+ self.qas_id = qas_id
+ self.path = path
+ self.unique_id = unique_id
+ self.question_tokens = question_tokens
+ self.doc_tokens = doc_tokens
+ self.question_text = question_text
+ self.sent_names = sent_names
+ self.sup_fact_id = sup_fact_id
+ self.sup_para_id = sup_para_id
+ self.para_start_end_position = para_start_end_position
+ self.sent_start_end_position = sent_start_end_position
+ self.title_start_end_position = title_start_end_position
+
+
+class InputFeatures:
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ qas_id,
+ path,
+ sent_names,
+ doc_tokens,
+ doc_input_ids,
+ doc_input_mask,
+ doc_segment_ids,
+ query_tokens,
+ query_input_ids,
+ query_input_mask,
+ query_segment_ids,
+ para_spans,
+ sent_spans,
+ token_to_orig_map):
+ """init function"""
+ self.qas_id = qas_id
+ self.doc_tokens = doc_tokens
+ self.doc_input_ids = doc_input_ids
+ self.doc_input_mask = doc_input_mask
+ self.doc_segment_ids = doc_segment_ids
+ self.path = path
+ self.unique_id = unique_id
+ self.sent_names = sent_names
+
+ self.query_tokens = query_tokens
+ self.query_input_ids = query_input_ids
+ self.query_input_mask = query_input_mask
+ self.query_segment_ids = query_segment_ids
+
+ self.para_spans = para_spans
+ self.sent_spans = sent_spans
+
+ self.token_to_orig_map = token_to_orig_map
+
+
+class DocDB:
+ """
+ Sqlite backed document storage.
+    Implements get_doc_ids() and get_doc_info(doc_id).
+ """
+
+ def __init__(self, db_path):
+ """init function"""
+ self.path = db_path
+ self.connection = sqlite3.connect(self.path, check_same_thread=False)
+
+ def __enter__(self):
+ """enter function"""
+ return self
+
+ def __exit__(self, *args):
+ """exit function"""
+ self.close()
+
+ def close(self):
+ """Close the connection to the database."""
+ self.connection.close()
+
+ def get_doc_ids(self):
+ """Fetch all ids of docs stored in the db."""
+ cursor = self.connection.cursor()
+ cursor.execute("SELECT id FROM documents")
+ results = [r[0] for r in cursor.fetchall()]
+ cursor.close()
+ return results
+
+ def get_doc_info(self, doc_id):
+ """get docment information"""
+ if not doc_id.endswith('_0'):
+ doc_id += '_0'
+ cursor = self.connection.cursor()
+ cursor.execute(
+ "SELECT * FROM documents WHERE id = ?",
+ (normalize_title(doc_id),)
+ )
+ result = cursor.fetchall()
+ cursor.close()
+ return result if result is None else result[0]
+
+
+def get_parse():
+ """get parse function"""
+ parser = argparse.ArgumentParser()
+
+ # Environment
+ parser.add_argument('--seed', type=int, default=42,
+ help="random seed for initialization")
+ parser.add_argument('--seq_len', type=int, default=512,
+ help="max sentence length")
+ parser.add_argument("--get_reranker_data",
+ action='store_true',
+ help="Set this flag if you want to get reranker data from retrieved result")
+ parser.add_argument("--run_reranker",
+ action='store_true',
+ help="Set this flag if you want to run reranker")
+ parser.add_argument("--cal_reranker_metrics",
+ action='store_true',
+ help="Set this flag if you want to calculate rerank metrics")
+ parser.add_argument("--select_reader_data",
+ action='store_true',
+ help="Set this flag if you want to select reader data")
+ parser.add_argument("--run_reader",
+ action='store_true',
+ help="Set this flag if you want to run reader")
+ parser.add_argument("--cal_reader_metrics",
+ action='store_true',
+ help="Set this flag if you want to calculate reader metrics")
+ parser.add_argument('--dev_gold_file',
+ type=str,
+ default="../hotpot_dev_fullwiki_v1.json",
+ help='file of dev ground truth')
+ parser.add_argument('--wiki_db_file',
+ type=str,
+ default="../enwiki_offset.db",
+ help='wiki_database_file')
+ parser.add_argument('--albert_model_path',
+ type=str,
+ default="../albert-xxlarge/",
+ help='model path of huggingface albert-xxlarge')
+
+ # Retriever
+ parser.add_argument('--retriever_result_file',
+ type=str,
+ default="../doc_path",
+ help='file of retriever result')
+
+ # Rerank
+ parser.add_argument('--rerank_batch_size', type=int, default=32,
+ help="rerank batchsize for evaluating")
+ parser.add_argument('--rerank_feature_file',
+ type=str,
+ default="../reranker_feature_file.pkl.gz",
+ help='file of rerank feature')
+ parser.add_argument('--rerank_example_file',
+ type=str,
+ default="../reranker_example_file.pkl.gz",
+ help='file of rerank example')
+ parser.add_argument('--rerank_result_file',
+ type=str,
+ default="../rerank_result.json",
+ help='file of rerank result')
+ parser.add_argument('--rerank_encoder_ck_file',
+ type=str,
+ default="../rerank_albert_12.ckpt",
+ help='checkpoint of rerank albert-xxlarge')
+ parser.add_argument('--rerank_downstream_ck_file',
+ type=str,
+ default="../rerank_downstream.ckpt",
+ help='checkpoint of rerank downstream')
+
+ # Reader
+ parser.add_argument('--reader_batch_size', type=int, default=32,
+ help="reader batchsize for evaluating")
+ parser.add_argument('--reader_feature_file',
+ type=str,
+ default="../reader_feature_file.pkl.gz",
+ help='file of reader feature')
+ parser.add_argument('--reader_example_file',
+ type=str,
+ default="../reader_example_file.pkl.gz",
+ help='file of reader example')
+ parser.add_argument('--reader_encoder_ck_file',
+ type=str,
+ default="../albert_12_layer.ckpt",
+ help='checkpoint of reader albert-xxlarge')
+ parser.add_argument('--reader_downstream_ck_file',
+ type=str,
+ default="../reader_downstream.ckpt",
+ help='checkpoint of reader downstream')
+ parser.add_argument('--reader_result_file',
+ type=str,
+ default="../reader_result_file.json",
+ help='file of reader result')
+ parser.add_argument('--sp_threshold', type=float, default=0.65,
+ help="threshold for selecting supporting sentences")
+ parser.add_argument("--max_para_num", default=2, type=int)
+ parser.add_argument("--max_sent_num", default=40, type=int)
+
+ return parser
+
+
+def select_reader_dev_data(args):
+ """select reader dev data from result of retriever based on result of reranker"""
+ rerank_result_file = args.rerank_result_file
+ rerank_feature_file = args.rerank_feature_file
+ rerank_example_file = args.rerank_example_file
+ reader_feature_file = args.reader_feature_file
+ reader_example_file = args.reader_example_file
+
+ with gzip.open(rerank_example_file, "rb") as f:
+ dev_examples = pickle.load(f)
+ with gzip.open(rerank_feature_file, "rb") as f:
+ dev_features = pickle.load(f)
+ with open(rerank_result_file, "r") as f:
+ rerank_result = json.load(f)
+
+ new_dev_examples = []
+ new_dev_features = []
+
+ rerank_unique_ids = defaultdict(int)
+ feature_unique_ids = defaultdict(int)
+
+ for _, res in tqdm(rerank_result.items(), desc="get rerank unique ids"):
+ rerank_unique_ids[res[0]] = True
+ print(f"rerank result num is {len(rerank_unique_ids)}")
+
+ for feature in tqdm(dev_features, desc="select rerank top1 feature"):
+ if feature.unique_id in rerank_unique_ids:
+ feature_unique_ids[feature.unique_id] = True
+ new_dev_features.append(feature)
+ print(f"new feature num is {len(new_dev_features)}")
+
+ for example in tqdm(dev_examples, desc="select rerank top1 example"):
+ if example.unique_id in rerank_unique_ids and example.unique_id in feature_unique_ids:
+ new_dev_examples.append(example)
+ print(f"new examples num is {len(new_dev_examples)}")
+
+ print("start save new examples ......")
+ with gzip.open(reader_example_file, "wb") as f:
+ pickle.dump(new_dev_examples, f)
+
+ print("start save new features ......")
+ with gzip.open(reader_feature_file, "wb") as f:
+ pickle.dump(new_dev_features, f)
+ print("finish selecting reader data !!!")
+
+
+def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
+ """Project the tokenized prediction back to the original text."""
+ def _strip_spaces(text):
+ ns_chars = []
+ ns_to_s_map = collections.OrderedDict()
+ for (i, c) in enumerate(text):
+ if c == " ":
+ continue
+ ns_to_s_map[len(ns_chars)] = i
+ ns_chars.append(c)
+ ns_text = "".join(ns_chars)
+ return (ns_text, ns_to_s_map)
+
+ tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
+
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
+
+ start_position = tok_text.find(pred_text)
+ if start_position == -1:
+ if verbose_logging:
+ print("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
+ return orig_text
+ end_position = start_position + len(pred_text) - 1
+
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
+
+ if len(orig_ns_text) != len(tok_ns_text):
+ if verbose_logging:
+ logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
+ orig_ns_text, tok_ns_text)
+ return orig_text
+
+ # We then project the characters in `pred_text` back to `orig_text` using
+ # the character-to-character alignment.
+ tok_s_to_ns_map = {}
+ for (i, tok_index) in tok_ns_to_s_map.items():
+ tok_s_to_ns_map[tok_index] = i
+
+ orig_start_position = None
+ if start_position in tok_s_to_ns_map:
+ ns_start_position = tok_s_to_ns_map[start_position]
+ if ns_start_position in orig_ns_to_s_map:
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
+
+ if orig_start_position is None:
+ if verbose_logging:
+ print("Couldn't map start position")
+ return orig_text
+
+ orig_end_position = None
+ if end_position in tok_s_to_ns_map:
+ ns_end_position = tok_s_to_ns_map[end_position]
+ if ns_end_position in orig_ns_to_s_map:
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
+
+ if orig_end_position is None:
+ if verbose_logging:
+ print("Couldn't map end position")
+ return orig_text
+
+ output_text = orig_text[orig_start_position:(orig_end_position + 1)]
+ return output_text
+
+
+def get_ans_from_pos(tokenizer, examples, features, y1, y2, unique_id):
+ """get answer text from predicted position"""
+ feature = features[unique_id]
+ example = examples[unique_id]
+ tok_to_orig_map = feature.token_to_orig_map
+ orig_all_tokens = example.question_tokens + example.doc_tokens
+
+ final_text = " "
+ if y1 < len(tok_to_orig_map) and y2 < len(tok_to_orig_map):
+ orig_tok_start = tok_to_orig_map[y1]
+ orig_tok_end = tok_to_orig_map[y2]
+ # -----------------orig all tokens-----------------------------------
+ orig_tokens = orig_all_tokens[orig_tok_start: (orig_tok_end + 1)]
+ tok_tokens = feature.doc_tokens[y1: y2 + 1]
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+ # Clean whitespace
+ tok_text = tok_text.strip()
+ tok_text = " ".join(tok_text.split())
+ orig_text = " ".join(orig_tokens)
+ final_text = get_final_text(tok_text, orig_text, True, False)
+ # print("final_text: " + final_text)
+ return final_text
+
+
+def convert_to_tokens(examples, features, ids, y1, y2, q_type_prob, tokenizer, sent, sent_names,
+ unique_ids):
+ """get raw answer text and supporting sentences"""
+ answer_dict = defaultdict(list)
+
+ q_type = np.argmax(q_type_prob, 1)
+
+ for i, qid in enumerate(ids):
+ unique_id = unique_ids[i]
+
+ if q_type[i] == 0:
+ answer_text = 'yes'
+ elif q_type[i] == 1:
+ answer_text = 'no'
+ elif q_type[i] == 2:
+ answer_text = get_ans_from_pos(tokenizer, examples, features, y1[i], y2[i], unique_id)
+ else:
+ raise ValueError("question type error")
+
+ answer_dict[qid].append(answer_text)
+ answer_dict[qid].append(sent[i])
+ answer_dict[qid].append(sent_names[i])
+
+ return answer_dict
+
+
+def normalize_title(text):
+ """Resolve different type of unicode encodings / capitarization in HotpotQA data."""
+ text = unicodedata.normalize('NFD', text)
+ return text[0].capitalize() + text[1:]
+
+
+def make_wiki_id(title, para_index):
+ """make wiki id"""
+ title_id = "{0}_{1}".format(normalize_title(title), para_index)
+ return title_id
+
+
+def cal_reranker_metrics(dev_gold_file, rerank_result_file):
+ """function for calculating reranker's metrics"""
+ with open(dev_gold_file, 'rb') as f:
+ gt = json.load(f)
+ with open(rerank_result_file, 'rb') as f:
+ rerank_result = json.load(f)
+
+ cnt = 0
+ all_ = len(gt)
+
+ cnt_c = 0
+ cnt_b = 0
+ all_c = 0
+ all_b = 0
+
+ for item in tqdm(gt, desc="get com and bridge "):
+ q_type = item["type"]
+ if q_type == "comparison":
+ all_c += 1
+ elif q_type == "bridge":
+ all_b += 1
+ else:
+ print(f"{q_type} is a error question type!!!")
+
+ for item in tqdm(gt, desc="cal pem"):
+ _id = item["_id"]
+
+ if _id in rerank_result:
+ pred = rerank_result[_id][1]
+ sps = item["supporting_facts"]
+ q_type = item["type"]
+ gold = []
+ for t in sps:
+ gold.append(normalize_title(t[0]))
+ gold = set(gold)
+ flag = True
+ for t in gold:
+ if t not in pred:
+ flag = False
+ break
+ if flag:
+ cnt += 1
+ if q_type == "comparison":
+ cnt_c += 1
+ elif q_type == "bridge":
+ cnt_b += 1
+ else:
+ print(f"{q_type} is a error question type!!!")
+
+ return cnt/all_, cnt_c/all_c, cnt_b/all_b
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+def find_hyper_linked_titles(text_w_links):
+ """find hyperlinked titles"""
+ titles = re.findall(r'href=[\'"]?([^\'" >]+)', text_w_links)
+ titles = [unquote(title) for title in titles]
+ titles = [title[0].capitalize() + title[1:] for title in titles]
+ return titles
+
+
+def normalize_text(text):
+ """Resolve different type of unicode encodings / capitarization in HotpotQA data."""
+ text = unicodedata.normalize('NFD', text)
+ return text
+
+
+def convert_char_to_token_offset(orig_text, start_offset, end_offset, char_to_word_offset, doc_tokens):
+ """build characters' offset"""
+ length = len(orig_text)
+ assert start_offset + length == end_offset
+ assert end_offset <= len(char_to_word_offset)
+
+ start_position = char_to_word_offset[start_offset]
+ end_position = char_to_word_offset[start_offset + length - 1]
+
+ actual_text = " ".join(
+ doc_tokens[start_position:(end_position + 1)])
+
+ assert actual_text.lower().find(orig_text.lower()) != -1
+ return start_position, end_position
+
+
+def _is_whitespace(c):
+ """check whitespace"""
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
+ return True
+ return False
+
+
+def convert_text_to_tokens(context_text, return_word_start=False):
+ """convert text to tokens"""
+ doc_tokens = []
+ char_to_word_offset = []
+ words_start_idx = []
+ prev_is_whitespace = True
+
+ for idx, c in enumerate(context_text):
+ if _is_whitespace(c):
+ prev_is_whitespace = True
+ else:
+ if prev_is_whitespace:
+ doc_tokens.append(c)
+ words_start_idx.append(idx)
+ else:
+ doc_tokens[-1] += c
+ prev_is_whitespace = False
+ char_to_word_offset.append(len(doc_tokens) - 1)
+
+ if not return_word_start:
+ return doc_tokens, char_to_word_offset
+ return doc_tokens, char_to_word_offset, words_start_idx
+
+
+def read_json(eval_file_name):
+ """reader json files"""
+ print("loading examples from {0}".format(eval_file_name))
+ with open(eval_file_name) as reader:
+ lines = json.load(reader)
+ return lines
+
+
+def write_json(data, out_file_name):
+ """write json files"""
+ print("writing {0} examples to {1}".format(len(data), out_file_name))
+ with open(out_file_name, 'w') as writer:
+ json.dump(data, writer, indent=4)
+
+
+def get_edges(sentence):
+ """get edges"""
+    EDGE_XY = re.compile(r'<a href="(.*?)">(.*?)<\/a>')
+ ret = EDGE_XY.findall(sentence)
+ return [(unquote(x), y) for x, y in ret]
+
+
+def relocate_tok_span(orig_to_tok_index, orig_to_tok_back_index, word_tokens, subword_tokens,
+ orig_start_position, orig_end_position, orig_text, tokenizer, tok_to_orig_index=None):
+ """relocate tokens' span"""
+ if orig_start_position is None:
+ return 0, 0
+
+ tok_start_position = orig_to_tok_index[orig_start_position]
+ if tok_start_position >= len(subword_tokens):
+ return 0, 0
+
+ if orig_end_position < len(word_tokens) - 1:
+ tok_end_position = orig_to_tok_back_index[orig_end_position]
+ if tok_to_orig_index and tok_to_orig_index[tok_end_position + 1] == -1:
+ assert tok_end_position <= orig_to_tok_index[orig_end_position + 1] - 2
+ else:
+ assert tok_end_position == orig_to_tok_index[orig_end_position + 1] - 1
+ else:
+ tok_end_position = orig_to_tok_back_index[orig_end_position]
+ return _improve_answer_span(
+ subword_tokens, tok_start_position, tok_end_position, tokenizer, orig_text)
+
+
+def generate_mapping(length, positions):
+ """generate mapping"""
+ start_mapping = [0] * length
+ end_mapping = [0] * length
+ for _, (start, end) in enumerate(positions):
+ start_mapping[start] = 1
+ end_mapping[end] = 1
+ return start_mapping, end_mapping
+
+
+def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
+ """Returns tokenized answer spans that better match the annotated answer."""
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=True))
+
+ for new_start in range(input_start, input_end + 1):
+ for new_end in range(input_end, new_start - 1, -1):
+ text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
+ if text_span == tok_answer_text:
+ return new_start, new_end
+
+ return input_start, input_end
+
+
+def _largest_valid_index(spans, limit):
+ """return largest valid index"""
+ for idx, _ in enumerate(spans):
+ if spans[idx][1] >= limit:
+ return idx
+ return len(spans)
+
+
+def remove_punc(text):
+ """remove punctuation"""
+ if text == " ":
+ return ''
+ exclude = set(string.punctuation)
+ return ''.join(ch for ch in text if ch not in exclude)
+
+
+def check_text_include_ans(ans, text):
+ """check whether text include answer"""
+ if normalize_answer(ans) in normalize_answer(text):
+ return True
+ return False
+
+
+def remove_articles(text):
+ """remove articles"""
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+
+def white_space_fix(text):
+ """fix whitespace"""
+ return ' '.join(text.split())
+
+
+def lower(text):
+ """lower text"""
+ return text.lower()
+
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
diff --git a/model_zoo/research/nlp/tprr/src/rerank_downstream.py b/model_zoo/research/nlp/tprr/src/rerank_downstream.py
new file mode 100644
index 0000000000..29bc41a7a5
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/rerank_downstream.py
@@ -0,0 +1,61 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""downstream Model for reranker"""
+
+import numpy as np
+from mindspore import nn
+from mindspore import Tensor, Parameter
+from mindspore.ops import operations as P
+
+
+class Rerank_Downstream(nn.Cell):
+ """Downstream model for rerank"""
+ def __init__(self):
+ """init function"""
+ super(Rerank_Downstream, self).__init__()
+ self.dense_0 = nn.Dense(in_channels=4096, out_channels=8192, has_bias=True)
+ self.relu_1 = nn.ReLU()
+ self.reducemean_2 = P.ReduceMean(keep_dims=True)
+ self.sub_3 = P.Sub()
+ self.sub_4 = P.Sub()
+ self.pow_5 = P.Pow()
+ self.pow_5_input_weight = 2.0
+ self.reducemean_6 = P.ReduceMean(keep_dims=True)
+ self.add_7 = P.Add()
+ self.add_7_bias = 9.999999960041972e-13
+ self.sqrt_8 = P.Sqrt()
+ self.div_9 = P.Div()
+ self.mul_10 = P.Mul()
+ self.mul_10_w = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
+ self.add_11 = P.Add()
+ self.add_11_bias = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None)
+ self.dense_12 = nn.Dense(in_channels=8192, out_channels=2, has_bias=True)
+
+ def construct(self, x):
+ """construct function"""
+ opt_dense_0 = self.dense_0(x)
+ opt_relu_1 = self.relu_1(opt_dense_0)
+ opt_reducemean_2 = self.reducemean_2(opt_relu_1, -1)
+ opt_sub_3 = self.sub_3(opt_relu_1, opt_reducemean_2)
+ opt_sub_4 = self.sub_4(opt_relu_1, opt_reducemean_2)
+ opt_pow_5 = self.pow_5(opt_sub_3, self.pow_5_input_weight)
+ opt_reducemean_6 = self.reducemean_6(opt_pow_5, -1)
+ opt_add_7 = self.add_7(opt_reducemean_6, self.add_7_bias)
+ opt_sqrt_8 = self.sqrt_8(opt_add_7)
+ opt_div_9 = self.div_9(opt_sub_4, opt_sqrt_8)
+ opt_mul_10 = self.mul_10(self.mul_10_w, opt_div_9)
+ opt_add_11 = self.add_11(opt_mul_10, self.add_11_bias)
+ opt_dense_12 = self.dense_12(opt_add_11)
+ return opt_dense_12
diff --git a/model_zoo/research/nlp/tprr/src/reranker.py b/model_zoo/research/nlp/tprr/src/reranker.py
new file mode 100644
index 0000000000..ef732b696d
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/reranker.py
@@ -0,0 +1,45 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Reranker Model"""
+
+import mindspore.nn as nn
+from mindspore import load_checkpoint, load_param_into_net
+from src.rerank_albert_xxlarge import Rerank_Albert
+from src.rerank_downstream import Rerank_Downstream
+
+
+class Reranker(nn.Cell):
+ """Reranker model"""
+ def __init__(self, batch_size, encoder_ck_file, downstream_ck_file):
+ """init function"""
+ super(Reranker, self).__init__(auto_prefix=False)
+
+ self.encoder = Rerank_Albert(batch_size)
+ param_dict = load_checkpoint(encoder_ck_file)
+ not_load_params_1 = load_param_into_net(self.encoder, param_dict)
+ print(f"not loaded albert: {not_load_params_1}")
+
+ self.no_answer_mlp = Rerank_Downstream()
+ param_dict = load_checkpoint(downstream_ck_file)
+ not_load_params_2 = load_param_into_net(self.no_answer_mlp, param_dict)
+ print(f"not loaded downstream: {not_load_params_2}")
+
+ def construct(self, input_ids, attn_mask, token_type_ids):
+ """construct function"""
+ state = self.encoder(input_ids, attn_mask, token_type_ids)
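+        # use the first-token ([CLS]) representation as the path embedding for scoring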
+ state = state[:, 0, :]
+
+ no_answer = self.no_answer_mlp(state)
+ return no_answer
diff --git a/model_zoo/research/nlp/tprr/src/reranker_eval.py b/model_zoo/research/nlp/tprr/src/reranker_eval.py
new file mode 100644
index 0000000000..4c2d89451d
--- /dev/null
+++ b/model_zoo/research/nlp/tprr/src/reranker_eval.py
@@ -0,0 +1,85 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""execute reranker"""
+
+import json
+import random
+from collections import defaultdict
+from time import time
+from tqdm import tqdm
+import numpy as np
+
+from mindspore import Tensor, ops
+from mindspore import dtype as mstype
+
+from src.rerank_and_reader_data_generator import DataGenerator
+from src.reranker import Reranker
+
+
+def rerank(args):
+ """rerank function"""
+ rerank_feature_file = args.rerank_feature_file
+ rerank_result_file = args.rerank_result_file
+ encoder_ck_file = args.rerank_encoder_ck_file
+ downstream_ck_file = args.rerank_downstream_ck_file
+ seed = args.seed
+ seq_len = args.seq_len
+ batch_size = args.rerank_batch_size
+
+ random.seed(seed)
+ np.random.seed(seed)
+
+ t1 = time()
+
+ generator = DataGenerator(feature_file_path=rerank_feature_file,
+ example_file_path=None,
+ batch_size=batch_size, seq_len=seq_len,
+ task_type="reranker")
+ gather_dict = defaultdict(lambda: defaultdict(list))
+
+ reranker = Reranker(batch_size=batch_size,
+ encoder_ck_file=encoder_ck_file,
+ downstream_ck_file=downstream_ck_file)
+
+ print("start re-ranking ...")
+
+ for _, batch in tqdm(enumerate(generator)):
+ input_ids = Tensor(batch["context_idxs"], mstype.int32)
+ attn_mask = Tensor(batch["context_mask"], mstype.int32)
+ token_type_ids = Tensor(batch["segment_idxs"], mstype.int32)
+
+ no_answer = reranker(input_ids, attn_mask, token_type_ids)
+
+ no_answer_prob = ops.Softmax()(no_answer).asnumpy()
+ no_answer_prob = no_answer_prob[:, 0]
+
+ for i in range(len(batch['ids'])):
+ qas_id = batch['ids'][i]
+ gather_dict[qas_id][no_answer_prob[i]].append(batch['unique_ids'][i])
+ gather_dict[qas_id][no_answer_prob[i]].append(batch['path'][i])
+
+ rerank_result = {}
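+    # for each question, keep the candidate path with the lowest no-answer probability as the top-1 result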
+ for qas_id in tqdm(gather_dict, desc="get top1 path from re-rank result"):
+ all_paths = gather_dict[qas_id]
+ all_paths = sorted(all_paths.items(), key=lambda item: item[0])
+ assert qas_id not in rerank_result
+ rerank_result[qas_id] = all_paths[0][1]
+
+ with open(rerank_result_file, 'w') as f:
+ json.dump(rerank_result, f)
+
+ t2 = time()
+
+ print(f"re-rank cost time: {t2-t1} s")