From: @zhan_ke
Reviewed-by: @oacjiewen, @c_34
Signed-off-by: @c_34
pull/14922/MERGE
@@ -66,6 +66,7 @@ After installing MindSpore via the official website and Dataset is correctly gen
 ```python
 # run evaluation example with HotPotQA dev dataset
+pip install transformers
 sh run_eval_ascend.sh
 sh run_eval_ascend_reranker_reader.sh
 ```
@@ -85,22 +86,20 @@ After installing MindSpore via the official website and Dataset is correctly gen
 ├─src
 | ├─build_reranker_data.py # build data for re-ranker from result of retriever
 | ├─config.py # Evaluation configurations for retriever
+| ├─converted_bert.py # Bert model for tprr
 | ├─hotpot_evaluate_v1.py # Hotpotqa evaluation script
 | ├─onehop.py # Onehop model of retriever
-| ├─onehop_bert.py # Onehop bert model of retriever
 | ├─process_data.py # Data preprocessing for retriever
 | ├─reader.py # Reader model
-| ├─reader_albert_xxlarge.py # Albert-xxlarge module of reader model
+| ├─albert.py # Albert-xxlarge model
 | ├─reader_downstream.py # Downstream module of reader model
 | ├─reader_eval.py # Reader evaluation script
-| ├─rerank_albert_xxlarge.py # Albert-xxlarge module of re-ranker model
 | ├─rerank_and_reader_data_generator.py # Data generator for re-ranker and reader
 | ├─rerank_and_reader_utils.py # Utils for re-ranker and reader
 | ├─rerank_downstream.py # Downstream module of re-ranker model
 | ├─reranker.py # Re-ranker model
 | ├─reranker_eval.py # Re-ranker evaluation script
 | ├─twohop.py # Twohop model of retriever
-| ├─twohop_bert.py # Twohop bert model of retriever
 | └─utils.py # Utils for retriever
 |
 ├─retriever_eval.py # Evaluation net for retriever
@@ -14,6 +14,8 @@
 # ============================================================================
 """main file"""
+import os
+from time import time
 from mindspore import context
 from src.rerank_and_reader_utils import get_parse, cal_reranker_metrics, select_reader_dev_data
 from src.reranker_eval import rerank
@@ -27,6 +29,13 @@ def rerank_and_retriever_eval():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     parser = get_parse()
     args = parser.parse_args()
+    args.dev_gold_path = os.path.join(args.data_path, args.dev_gold_file)
+    args.wiki_db_path = os.path.join(args.data_path, args.wiki_db_file)
+    args.albert_model_path = os.path.join(args.ckpt_path, args.albert_model)
+    args.rerank_encoder_ck_path = os.path.join(args.ckpt_path, args.rerank_encoder_ck_file)
+    args.rerank_downstream_ck_path = os.path.join(args.ckpt_path, args.rerank_downstream_ck_file)
+    args.reader_encoder_ck_path = os.path.join(args.ckpt_path, args.reader_encoder_ck_file)
+    args.reader_downstream_ck_path = os.path.join(args.ckpt_path, args.reader_downstream_ck_file)
     if args.get_reranker_data:
         get_rerank_data(args)
@@ -36,8 +45,7 @@ def rerank_and_retriever_eval():
     if args.cal_reranker_metrics:
         total_top1_pem, _, _ = \
-            cal_reranker_metrics(dev_gold_file=args.dev_gold_file, rerank_result_file=args.rerank_result_file)
-        print(f"total top1 pem: {total_top1_pem}")
+            cal_reranker_metrics(dev_gold_file=args.dev_gold_path, rerank_result_file=args.rerank_result_file)
     if args.select_reader_data:
         select_reader_dev_data(args)
@@ -46,10 +54,18 @@ def rerank_and_retriever_eval():
         read(args)
     if args.cal_reader_metrics:
-        metrics = hotpotqa_eval(args.reader_result_file, args.dev_gold_file)
+        metrics = hotpotqa_eval(args.reader_result_file, args.dev_gold_path)
+    if args.cal_reranker_metrics:
+        print(f"total top1 pem: {total_top1_pem}")
+    if args.cal_reader_metrics:
         for k in metrics:
             print(f"{k}: {metrics[k]}")
 if __name__ == "__main__":
+    t1 = time()
     rerank_and_retriever_eval()
+    t2 = time()
+    print(f"eval reranker and reader cost {(t2 - t1) / 3600} h")
@@ -31,8 +31,7 @@ from mindspore import load_checkpoint, load_param_into_net
 from src.onehop import OneHopBert
 from src.twohop import TwoHopBert
 from src.process_data import DataGen
-from src.onehop_bert import ModelOneHop
-from src.twohop_bert import ModelTwoHop
+from src.converted_bert import ModelOneHop
 from src.config import ThinkRetrieverConfig
 from src.utils import read_query, split_queries, get_new_title, get_raw_title, save_json
@@ -84,10 +83,10 @@ def evaluation(d_id):
     print('********************** loading model ********************** ')
     s_lm = time.time()
-    model_onehop_bert = ModelOneHop()
+    model_onehop_bert = ModelOneHop(256)
     param_dict = load_checkpoint(config.onehop_bert_path)
     load_param_into_net(model_onehop_bert, param_dict)
-    model_twohop_bert = ModelTwoHop()
+    model_twohop_bert = ModelOneHop(448)
     param_dict2 = load_checkpoint(config.twohop_bert_path)
     load_param_into_net(model_twohop_bert, param_dict2)
     onehop = OneHopBert(config, model_onehop_bert)
@@ -15,11 +15,8 @@
 # ============================================================================
 # eval script
-ulimit -u unlimited
-export DEVICE_NUM=1
-export RANK_SIZE=$DEVICE_NUM
-export RANK_ID=0
+DATAPATH="../data"
+CKPTPATH="../ckpt"
 if [ -d "eval_tr" ];
 then
@@ -34,6 +31,6 @@ cd ./eval_tr || exit
 env > env.log
 echo "start evaluation"
-python retriever_eval.py > log.txt 2>&1 &
+python -u retriever_eval.py --vocab_path=$DATAPATH/vocab.txt --wiki_path=$DATAPATH/db_docs_bidirection_new.pkl --dev_path=$DATAPATH/hotpot_dev_fullwiki_v1_for_retriever.json --dev_data_path=$DATAPATH/dev_tf_idf_data_raw.pkl --q_path=$DATAPATH/queries --onehop_bert_path=$CKPTPATH/onehop_new.ckpt --onehop_mlp_path=$CKPTPATH/onehop_mlp.ckpt --twohop_bert_path=$CKPTPATH/twohop_new.ckpt --twohop_mlp_path=$CKPTPATH/twohop_mlp.ckpt > log.txt 2>&1 &
 cd ..
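For reference, a minimal sketch of the data and checkpoint layout these arguments imply; the file names are taken verbatim from the command above, while the grouping under the two directories is an assumption, not something spelled out elsewhere in this diff:

```shell
# assumed layout, as the eval scripts see it
DATAPATH="../data"   # vocab.txt, db_docs_bidirection_new.pkl,
                     # hotpot_dev_fullwiki_v1_for_retriever.json,
                     # dev_tf_idf_data_raw.pkl, queries/
CKPTPATH="../ckpt"   # onehop_new.ckpt, onehop_mlp.ckpt, twohop_new.ckpt, twohop_mlp.ckpt
```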
@@ -16,6 +16,8 @@
 # eval script
+DATAPATH="../data"
+CKPTPATH="../ckpt"
 ulimit -u unlimited
 export DEVICE_NUM=1
 export RANK_SIZE=$DEVICE_NUM
@@ -34,6 +36,6 @@ cd ./eval || exit
 env > env.log
 echo "start evaluation"
-python reranker_and_reader_eval.py --get_reranker_data --run_reranker --cal_reranker_metrics --select_reader_data --run_reader --cal_reader_metrics > log_reranker_and_reader.txt 2>&1 &
+python reranker_and_reader_eval.py --get_reranker_data --run_reranker --cal_reranker_metrics --select_reader_data --run_reader --cal_reader_metrics --data_path $DATAPATH --ckpt_path $CKPTPATH > log_reranker_and_reader.txt 2>&1 &
 cd ..
@@ -0,0 +1,251 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""albert-xxlarge Model for reranker"""
import numpy as np
from mindspore import nn, ops
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore import dtype as mstype

dst_type = mstype.float16
dst_type2 = mstype.float32
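# dst_type / dst_type2 implement the mixed-precision pattern used throughout this file:
# matmul inputs are cast to float16 and the results are cast back to float32.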

class LayerNorm(nn.Cell):
    """LayerNorm layer"""
    def __init__(self, layer_norm_weight, layer_norm_bias):
        """init function"""
        super(LayerNorm, self).__init__()
        self.reducemean = P.ReduceMean(keep_dims=True)
        self.sub = P.Sub()
        self.pow = P.Pow()
        self.add = P.Add()
        self.sqrt = P.Sqrt()
        self.div = P.Div()
        self.mul = P.Mul()
        self.layer_norm_weight = layer_norm_weight
        self.layer_norm_bias = layer_norm_bias

    def construct(self, x):
        """construct function"""
        diff_ex = self.sub(x, self.reducemean(x, -1))
        var_x = self.reducemean(self.pow(diff_ex, 2.0), -1)
        output = self.div(diff_ex, self.sqrt(self.add(var_x, 1e-12)))
        output = self.add(self.mul(output, self.layer_norm_weight), self.layer_norm_bias)
        return output


class Linear(nn.Cell):
    """Linear layer"""
    def __init__(self, linear_weight_shape, linear_bias):
        """init function"""
        super(Linear, self).__init__()
        self.matmul = nn.MatMul()
        self.add = P.Add()
        self.weight = Parameter(Tensor(np.random.uniform(0, 1, linear_weight_shape).astype(np.float32)), name=None)
        self.bias = linear_bias

    def construct(self, input_x):
        """construct function"""
        output = self.matmul(ops.Cast()(input_x, dst_type), ops.Cast()(self.weight, dst_type))
        output = self.add(ops.Cast()(output, dst_type2), self.bias)
        return output


class MultiHeadAttn(nn.Cell):
    """Multi-head attention layer"""
    def __init__(self, batch_size, query_linear_bias, key_linear_bias, value_linear_bias):
        """init function"""
        super(MultiHeadAttn, self).__init__()
        self.batch_size = batch_size
        self.matmul = nn.MatMul()
        self.add = P.Add()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.div = P.Div()
        self.softmax = nn.Softmax(axis=3)
        self.query_linear_weight = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)),
                                             name=None)
        self.query_linear_bias = query_linear_bias
        self.key_linear_weight = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)),
                                           name=None)
        self.key_linear_bias = key_linear_bias
        self.value_linear_weight = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)),
                                             name=None)
        self.value_linear_bias = value_linear_bias
        self.reshape_shape = tuple([batch_size, 512, 64, 64])
        self.w = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None)
        self.b = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)

    def construct(self, hidden_states, extended_attention_mask):
        """construct function"""
        mixed_query_layer = self.matmul(ops.Cast()(hidden_states, dst_type),
                                        ops.Cast()(self.query_linear_weight, dst_type))
        mixed_query_layer = self.add(ops.Cast()(mixed_query_layer, dst_type2), self.query_linear_bias)
        mixed_key_layer = self.matmul(ops.Cast()(hidden_states, dst_type),
                                      ops.Cast()(self.key_linear_weight, dst_type))
        mixed_key_layer = self.add(ops.Cast()(mixed_key_layer, dst_type2), self.key_linear_bias)
        mixed_value_layer = self.matmul(ops.Cast()(hidden_states, dst_type),
                                        ops.Cast()(self.value_linear_weight, dst_type))
        mixed_value_layer = self.add(ops.Cast()(mixed_value_layer, dst_type2), self.value_linear_bias)
        query_layer = self.reshape(mixed_query_layer, self.reshape_shape)
        key_layer = self.reshape(mixed_key_layer, self.reshape_shape)
        value_layer = self.reshape(mixed_value_layer, self.reshape_shape)
        query_layer = self.transpose(query_layer, (0, 2, 1, 3))
        key_layer = self.transpose(key_layer, (0, 2, 3, 1))
        value_layer = self.transpose(value_layer, (0, 2, 1, 3))
        attention_scores = self.matmul(ops.Cast()(query_layer, dst_type), ops.Cast()(key_layer, dst_type))
        attention_scores = self.div(ops.Cast()(attention_scores, dst_type2), ops.Cast()(8.0, dst_type2))
        attention_scores = self.add(attention_scores, extended_attention_mask)
        attention_probs = self.softmax(attention_scores)
        context_layer = self.matmul(ops.Cast()(attention_probs, dst_type), ops.Cast()(value_layer, dst_type))
        context_layer = self.transpose(ops.Cast()(context_layer, dst_type2), (0, 2, 1, 3))
        projected_context_layer = self.matmul(ops.Cast()(context_layer, dst_type).view(self.batch_size * 512, -1),
                                              ops.Cast()(self.w, dst_type).view(-1, 4096))\
            .view(self.batch_size, 512, 4096)
        projected_context_layer = self.add(ops.Cast()(projected_context_layer, dst_type2), self.b)
        return projected_context_layer


class NewGeLU(nn.Cell):
    """Gelu layer"""
    def __init__(self):
        """init function"""
        super(NewGeLU, self).__init__()
        self.mul = P.Mul()
        self.pow = P.Pow()
        self.mul = P.Mul()
        self.add = P.Add()
        self.tanh = nn.Tanh()

    def construct(self, x):
        """construct function"""
        output = self.mul(self.add(x, self.mul(self.pow(x, 3.0), 0.044714998453855515)), 0.7978845834732056)
        output = self.tanh(output)
        output = self.mul(self.mul(x, 0.5), self.add(output, 1.0))
        return output


class AlbertTransformer(nn.Cell):
    """Transformer layer with LayerNorm"""
    def __init__(self, batch_size, ffn_weight_shape, ffn_output_weight_shape, query_linear_bias,
                 key_linear_bias, value_linear_bias, layernorm_weight, layernorm_bias, ffn_bias, ffn_output_bias):
        """init function"""
        super(AlbertTransformer, self).__init__()
        self.multiheadattn = MultiHeadAttn(batch_size=batch_size,
                                           query_linear_bias=query_linear_bias,
                                           key_linear_bias=key_linear_bias,
                                           value_linear_bias=value_linear_bias)
        self.add = P.Add()
        self.layernorm = LayerNorm(layer_norm_weight=layernorm_weight, layer_norm_bias=layernorm_bias)
        self.ffn = Linear(linear_weight_shape=ffn_weight_shape, linear_bias=ffn_bias)
        self.newgelu = NewGeLU()
        self.ffn_output = Linear(linear_weight_shape=ffn_output_weight_shape, linear_bias=ffn_output_bias)
        self.add_1 = P.Add()

    def construct(self, hidden_states, extended_attention_mask):
        """construct function"""
        attention_output = self.multiheadattn(hidden_states, extended_attention_mask)
        hidden_states = self.add(hidden_states, attention_output)
        hidden_states = self.layernorm(hidden_states)
        ffn_output = self.ffn(hidden_states)
        ffn_output = self.newgelu(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        hidden_states = self.add_1(ffn_output, hidden_states)
        return hidden_states


class Albert(nn.Cell):
    """Albert model for rerank"""
    def __init__(self, batch_size):
        """init function"""
        super(Albert, self).__init__()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()
        self.sub = P.Sub()
        self.mul = P.Mul()
        self.gather = P.Gather()
        self.add = P.Add()
        self.layernorm_1_weight = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None)
        self.layernorm_1_bias = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None)
        self.embedding_hidden_mapping_in_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)),
                                                          name=None)
        self.query_linear_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
        self.key_linear_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
        self.value_linear_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
        self.albert_transformer_layernorm_w = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)),
                                                        name=None)
        self.albert_transformer_layernorm_b = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)),
                                                        name=None)
        self.ffn_bias = Parameter(Tensor(np.random.uniform(0, 1, (16384,)).astype(np.float32)), name=None)
        self.ffn_output_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
        self.layernorm_2_weight = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
        self.layernorm_2_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None)
        self.word_embeddings = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)), name=None)
        self.token_type_embeddings = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None)
        self.position_embeddings = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)),
                                             name=None)
        self.layernorm_1 = LayerNorm(layer_norm_weight=self.layernorm_1_weight, layer_norm_bias=self.layernorm_1_bias)
        self.embedding_hidden_mapping_in = Linear(linear_weight_shape=(128, 4096),
                                                  linear_bias=self.embedding_hidden_mapping_in_bias)
        self.albert_transformer = AlbertTransformer(batch_size=batch_size,
                                                    ffn_weight_shape=(4096, 16384),
                                                    ffn_output_weight_shape=(16384, 4096),
                                                    query_linear_bias=self.query_linear_bias,
                                                    key_linear_bias=self.key_linear_bias,
                                                    value_linear_bias=self.value_linear_bias,
                                                    layernorm_weight=self.albert_transformer_layernorm_w,
                                                    layernorm_bias=self.albert_transformer_layernorm_b,
                                                    ffn_bias=self.ffn_bias,
                                                    ffn_output_bias=self.ffn_output_bias)
        self.layernorm_2 = LayerNorm(layer_norm_weight=self.layernorm_2_weight, layer_norm_bias=self.layernorm_2_bias)

    def construct(self, input_ids, attention_mask, token_type_ids):
        """construct function"""
        extended_attention_mask = self.expanddims(attention_mask, 1)
        extended_attention_mask = self.expanddims(extended_attention_mask, 2)
        extended_attention_mask = self.cast(extended_attention_mask, mstype.float32)
        extended_attention_mask = self.mul(self.sub(1.0, extended_attention_mask), -10000.0)
        inputs_embeds = self.gather(self.word_embeddings, input_ids, 0)
        token_type_embeddings = self.gather(self.token_type_embeddings, token_type_ids, 0)
        embeddings = self.add(self.add(inputs_embeds, self.position_embeddings), token_type_embeddings)
        embeddings = self.layernorm_1(embeddings)
        hidden_states = self.embedding_hidden_mapping_in(embeddings)
        for _ in range(12):
            hidden_states = self.albert_transformer(hidden_states, extended_attention_mask)
            hidden_states = self.layernorm_2(hidden_states)
        return hidden_states
@@ -414,8 +414,8 @@ def convert_example_to_features(tokenizer, args, examples):
 def get_rerank_data(args):
     """function for generating reranker's data"""
-    new_dev_data = gen_dev_data(dev_file=args.dev_gold_file,
-                                db_path=args.wiki_db_file,
+    new_dev_data = gen_dev_data(dev_file=args.dev_gold_path,
+                                db_path=args.wiki_db_path,
                                 topk_file=args.retriever_result_file)
     tokenizer = AutoTokenizer.from_pretrained(args.albert_model_path)
     new_tokens = ['[q]', '[/q]', '<t>', '</t>', '[s]']
@@ -39,9 +39,9 @@ def ThinkRetrieverConfig():
     parser.add_argument("--dev_path", type=str, default='../hotpot_dev_fullwiki_v1_for_retriever.json',
                         help="dev path")
     parser.add_argument("--dev_data_path", type=str, default='../dev_tf_idf_data_raw.pkl', help="dev data path")
-    parser.add_argument("--onehop_bert_path", type=str, default='../onehop.ckpt', help="onehop bert ckpt path")
+    parser.add_argument("--onehop_bert_path", type=str, default='../onehop_new.ckpt', help="onehop bert ckpt path")
     parser.add_argument("--onehop_mlp_path", type=str, default='../onehop_mlp.ckpt', help="onehop mlp ckpt path")
-    parser.add_argument("--twohop_bert_path", type=str, default='../twohop.ckpt', help="twohop bert ckpt path")
+    parser.add_argument("--twohop_bert_path", type=str, default='../twohop_new.ckpt', help="twohop bert ckpt path")
     parser.add_argument("--twohop_mlp_path", type=str, default='../twohop_mlp.ckpt', help="twohop mlp ckpt path")
     parser.add_argument("--q_path", type=str, default="../queries", help="queries data path")
     return parser.parse_args()
@@ -0,0 +1,278 @@
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| One Hop BERT. | |||||
| """ | |||||
| import numpy as np | |||||
| from mindspore import nn | |||||
| from mindspore import Tensor, Parameter | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| BATCH_SIZE = -1 | |||||
| class LayerNorm(nn.Cell): | |||||
| """layer norm""" | |||||
| def __init__(self): | |||||
| super(LayerNorm, self).__init__() | |||||
| self.reducemean = P.ReduceMean(keep_dims=True) | |||||
| self.sub = P.Sub() | |||||
| self.cast = P.Cast() | |||||
| self.cast_to = mstype.float32 | |||||
| self.pow = P.Pow() | |||||
| self.pow_weight = 2.0 | |||||
| self.add = P.Add() | |||||
| self.add_bias_0 = 9.999999960041972e-13 | |||||
| self.sqrt = P.Sqrt() | |||||
| self.div = P.Div() | |||||
| self.mul = P.Mul() | |||||
| self.mul_weight = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.add_bias_1 = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| x_mean = self.reducemean(x, -1) | |||||
| x_sub = self.sub(x, x_mean) | |||||
| x_sub = self.cast(x_sub, self.cast_to) | |||||
| x_pow = self.pow(x_sub, self.pow_weight) | |||||
| out_mean = self.reducemean(x_pow, -1) | |||||
| out_add = self.add(out_mean, self.add_bias_0) | |||||
| out_sqrt = self.sqrt(out_add) | |||||
| out_div = self.div(x_sub, out_sqrt) | |||||
| out_mul = self.mul(out_div, self.mul_weight) | |||||
| output = self.add(out_mul, self.add_bias_1) | |||||
| return output | |||||
| class MultiHeadAttn(nn.Cell): | |||||
| """multi head attention layer""" | |||||
| def __init__(self, seq_len): | |||||
| super(MultiHeadAttn, self).__init__() | |||||
| self.matmul = nn.MatMul() | |||||
| self.matmul.to_float(mstype.float16) | |||||
| self.query = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.key = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.value = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.add = P.Add() | |||||
| self.query_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.key_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.value_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.reshape = P.Reshape() | |||||
| self.to_shape_0 = tuple([BATCH_SIZE, seq_len, 12, 64]) | |||||
| self.transpose = P.Transpose() | |||||
| self.div = P.Div() | |||||
| self.div_w = 8.0 | |||||
| self.softmax = nn.Softmax(axis=3) | |||||
| self.to_shape_1 = tuple([BATCH_SIZE, seq_len, 768]) | |||||
| self.context_weight = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.context_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| def construct(self, input_tensor, attention_mask): | |||||
| """construct function""" | |||||
| query_output = self.matmul(input_tensor, self.query) | |||||
| key_output = self.matmul(input_tensor, self.key) | |||||
| value_output = self.matmul(input_tensor, self.value) | |||||
| query_output = P.Cast()(query_output, mstype.float32) | |||||
| key_output = P.Cast()(key_output, mstype.float32) | |||||
| value_output = P.Cast()(value_output, mstype.float32) | |||||
| query_output = self.add(query_output, self.query_bias) | |||||
| key_output = self.add(key_output, self.key_bias) | |||||
| value_output = self.add(value_output, self.value_bias) | |||||
| query_layer = self.reshape(query_output, self.to_shape_0) | |||||
| key_layer = self.reshape(key_output, self.to_shape_0) | |||||
| value_layer = self.reshape(value_output, self.to_shape_0) | |||||
| query_layer = self.transpose(query_layer, (0, 2, 1, 3)) | |||||
| key_layer = self.transpose(key_layer, (0, 2, 3, 1)) | |||||
| value_layer = self.transpose(value_layer, (0, 2, 1, 3)) | |||||
| attention_scores = self.matmul(query_layer, key_layer) | |||||
| attention_scores = P.Cast()(attention_scores, mstype.float32) | |||||
| attention_scores = self.div(attention_scores, self.div_w) | |||||
| attention_scores = self.add(attention_scores, attention_mask) | |||||
| attention_scores = P.Cast()(attention_scores, mstype.float32) | |||||
| attention_probs = self.softmax(attention_scores) | |||||
| context_layer = self.matmul(attention_probs, value_layer) | |||||
| context_layer = P.Cast()(context_layer, mstype.float32) | |||||
| context_layer = self.transpose(context_layer, (0, 2, 1, 3)) | |||||
| context_layer = self.reshape(context_layer, self.to_shape_1) | |||||
| context_layer = self.matmul(context_layer, self.context_weight) | |||||
| context_layer = P.Cast()(context_layer, mstype.float32) | |||||
| context_layer = self.add(context_layer, self.context_bias) | |||||
| return context_layer | |||||
| class Linear(nn.Cell): | |||||
| """linear layer""" | |||||
| def __init__(self, w_shape, b_shape): | |||||
| super(Linear, self).__init__() | |||||
| self.matmul = nn.MatMul() | |||||
| self.matmul.to_float(mstype.float16) | |||||
| self.w = Parameter(Tensor(np.random.uniform(0, 1, w_shape).astype(np.float32)), | |||||
| name=None) | |||||
| self.add = P.Add() | |||||
| self.b = Parameter(Tensor(np.random.uniform(0, 1, b_shape).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| output = self.matmul(x, self.w) | |||||
| output = P.Cast()(output, mstype.float32) | |||||
| output = self.add(output, self.b) | |||||
| return output | |||||
| class GeLU(nn.Cell): | |||||
| """gelu layer""" | |||||
| def __init__(self): | |||||
| super(GeLU, self).__init__() | |||||
| self.div = P.Div() | |||||
| self.div_w = 1.4142135381698608 | |||||
| self.erf = P.Erf() | |||||
| self.add = P.Add() | |||||
| self.add_bias = 1.0 | |||||
| self.mul = P.Mul() | |||||
| self.mul_w = 0.5 | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| output = self.div(x, self.div_w) | |||||
| output = self.erf(output) | |||||
| output = self.add(output, self.add_bias) | |||||
| output = self.mul(x, output) | |||||
| output = self.mul(output, self.mul_w) | |||||
| return output | |||||
| class TransformerLayer(nn.Cell): | |||||
| """transformer layer""" | |||||
| def __init__(self, seq_len, intermediate_size, intermediate_bias, output_size, output_bias): | |||||
| super(TransformerLayer, self).__init__() | |||||
| self.attention = MultiHeadAttn(seq_len) | |||||
| self.add = P.Add() | |||||
| self.layernorm1 = LayerNorm() | |||||
| self.intermediate = Linear(w_shape=intermediate_size, | |||||
| b_shape=intermediate_bias) | |||||
| self.gelu = GeLU() | |||||
| self.output = Linear(w_shape=output_size, | |||||
| b_shape=output_bias) | |||||
| self.layernorm2 = LayerNorm() | |||||
| def construct(self, hidden_states, attention_mask): | |||||
| """construct function""" | |||||
| attention_output = self.attention(hidden_states, attention_mask) | |||||
| attention_output = self.add(attention_output, hidden_states) | |||||
| attention_output = self.layernorm1(attention_output) | |||||
| intermediate_output = self.intermediate(attention_output) | |||||
| intermediate_output = self.gelu(intermediate_output) | |||||
| output = self.output(intermediate_output) | |||||
| output = self.add(output, attention_output) | |||||
| output = self.layernorm2(output) | |||||
| return output | |||||
| class BertEncoder(nn.Cell): | |||||
| """encoder layer""" | |||||
| def __init__(self, seq_len): | |||||
| super(BertEncoder, self).__init__() | |||||
| self.layer1 = TransformerLayer(seq_len, | |||||
| intermediate_size=(768, 3072), | |||||
| intermediate_bias=(3072,), | |||||
| output_size=(3072, 768), | |||||
| output_bias=(768,)) | |||||
| self.layer2 = TransformerLayer(seq_len, | |||||
| intermediate_size=(768, 3072), | |||||
| intermediate_bias=(3072,), | |||||
| output_size=(3072, 768), | |||||
| output_bias=(768,)) | |||||
| self.layer3 = TransformerLayer(seq_len, | |||||
| intermediate_size=(768, 3072), | |||||
| intermediate_bias=(3072,), | |||||
| output_size=(3072, 768), | |||||
| output_bias=(768,)) | |||||
| self.layer4 = TransformerLayer(seq_len, | |||||
| intermediate_size=(768, 3072), | |||||
| intermediate_bias=(3072,), | |||||
| output_size=(3072, 768), | |||||
| output_bias=(768,)) | |||||
| def construct(self, input_tensor, attention_mask): | |||||
| """construct function""" | |||||
| layer1_output = self.layer1(input_tensor, attention_mask) | |||||
| layer2_output = self.layer2(layer1_output, attention_mask) | |||||
| layer3_output = self.layer3(layer2_output, attention_mask) | |||||
| layer4_output = self.layer4(layer3_output, attention_mask) | |||||
| return layer4_output | |||||
| class ModelOneHop(nn.Cell): | |||||
| """one hop layer""" | |||||
| def __init__(self, seq_len): | |||||
| super(ModelOneHop, self).__init__() | |||||
| self.expanddims = P.ExpandDims() | |||||
| self.expanddims_axis_0 = 1 | |||||
| self.expanddims_axis_1 = 2 | |||||
| self.cast = P.Cast() | |||||
| self.cast_to = mstype.float32 | |||||
| self.sub = P.Sub() | |||||
| self.sub_bias = 1.0 | |||||
| self.mul = P.Mul() | |||||
| self.mul_w = -10000.0 | |||||
| self.input_weight_0 = Parameter(Tensor(np.random.uniform(0, 1, (30522, 768)).astype(np.float32)), | |||||
| name=None) | |||||
| self.gather_axis_0 = 0 | |||||
| self.gather = P.Gather() | |||||
| self.input_weight_1 = Parameter(Tensor(np.random.uniform(0, 1, (2, 768)).astype(np.float32)), name=None) | |||||
| self.add = P.Add() | |||||
| self.add_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, seq_len, 768)).astype(np.float32)), name=None) | |||||
| self.layernorm = LayerNorm() | |||||
| self.encoder_layer_1_4 = BertEncoder(seq_len) | |||||
| self.encoder_layer_5_8 = BertEncoder(seq_len) | |||||
| self.encoder_layer_9_12 = BertEncoder(seq_len) | |||||
| self.cls_ids = Tensor(np.array(0)) | |||||
| self.gather_axis_1 = 1 | |||||
| self.dense = nn.Dense(in_channels=768, out_channels=768, has_bias=True) | |||||
| self.tanh = nn.Tanh() | |||||
| def construct(self, input_ids, token_type_ids, attention_mask): | |||||
| """construct function""" | |||||
| input_ids = self.cast(input_ids, mstype.int32) | |||||
| token_type_ids = self.cast(token_type_ids, mstype.int32) | |||||
| attention_mask = self.cast(attention_mask, mstype.int32) | |||||
| attention_mask = self.expanddims(attention_mask, self.expanddims_axis_0) | |||||
| attention_mask = self.expanddims(attention_mask, self.expanddims_axis_1) | |||||
| attention_mask = self.cast(attention_mask, self.cast_to) | |||||
| attention_mask = self.sub(self.sub_bias, attention_mask) | |||||
| attention_mask_matrix = self.mul(attention_mask, self.mul_w) | |||||
| word_embeddings = self.gather(self.input_weight_0, input_ids, self.gather_axis_0) | |||||
| token_type_embeddings = self.gather(self.input_weight_1, token_type_ids, self.gather_axis_0) | |||||
| word_embeddings = self.add(word_embeddings, self.add_bias) | |||||
| embedding_output = self.add(word_embeddings, token_type_embeddings) | |||||
| embedding_output = self.layernorm(embedding_output) | |||||
| encoder_output = self.encoder_layer_1_4(embedding_output, attention_mask_matrix) | |||||
| encoder_output = self.encoder_layer_5_8(encoder_output, attention_mask_matrix) | |||||
| encoder_output = self.encoder_layer_9_12(encoder_output, attention_mask_matrix) | |||||
| cls_output = self.gather(encoder_output, self.cls_ids, self.gather_axis_1) | |||||
| pooled_output = self.dense(cls_output) | |||||
| pooled_output = self.tanh(pooled_output) | |||||
| return pooled_output | |||||
@@ -120,13 +120,11 @@ def hotpotqa_eval(prediction_file, gold_file):
         cur_id = dp['_id']
         can_eval_joint = True
         if cur_id not in prediction['answer']:
-            print('missing answer {}'.format(cur_id))
             can_eval_joint = False
         else:
             em, prec, recall = update_answer(
                 metrics, prediction['answer'][cur_id], dp['answer'])
         if cur_id not in prediction['sp']:
-            print('missing sp fact {}'.format(cur_id))
             can_eval_joint = False
         else:
             sp_em, sp_prec, sp_recall = update_sp(
@@ -1,302 +0,0 @@
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| One Hop BERT. | |||||
| """ | |||||
| import numpy as np | |||||
| from mindspore import nn | |||||
| from mindspore import Tensor, Parameter | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| BATCH_SIZE = -1 | |||||
| class LayerNorm(nn.Cell): | |||||
| """layer norm""" | |||||
| def __init__(self): | |||||
| super(LayerNorm, self).__init__() | |||||
| self.reducemean_0 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_1 = P.Sub() | |||||
| self.cast_2 = P.Cast() | |||||
| self.cast_2_to = mstype.float32 | |||||
| self.pow_3 = P.Pow() | |||||
| self.pow_3_input_weight = 2.0 | |||||
| self.reducemean_4 = P.ReduceMean(keep_dims=True) | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = 9.999999960041972e-13 | |||||
| self.sqrt_6 = P.Sqrt() | |||||
| self.div_7 = P.Div() | |||||
| self.mul_8 = P.Mul() | |||||
| self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.add_9 = P.Add() | |||||
| self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_reducemean_0 = self.reducemean_0(x, -1) | |||||
| opt_sub_1 = self.sub_1(x, opt_reducemean_0) | |||||
| opt_cast_2 = self.cast_2(opt_sub_1, self.cast_2_to) | |||||
| opt_pow_3 = self.pow_3(opt_cast_2, self.pow_3_input_weight) | |||||
| opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1) | |||||
| opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias) | |||||
| opt_sqrt_6 = self.sqrt_6(opt_add_5) | |||||
| opt_div_7 = self.div_7(opt_sub_1, opt_sqrt_6) | |||||
| opt_mul_8 = self.mul_8(opt_div_7, self.mul_8_w) | |||||
| opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias) | |||||
| return opt_add_9 | |||||
| class MultiHeadAttn(nn.Cell): | |||||
| """multi head attention layer""" | |||||
| def __init__(self): | |||||
| super(MultiHeadAttn, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0.to_float(mstype.float16) | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.matmul_1 = nn.MatMul() | |||||
| self.matmul_1.to_float(mstype.float16) | |||||
| self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.matmul_2 = nn.MatMul() | |||||
| self.matmul_2.to_float(mstype.float16) | |||||
| self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.add_3 = P.Add() | |||||
| self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.reshape_6 = P.Reshape() | |||||
| self.reshape_6_shape = tuple([BATCH_SIZE, 256, 12, 64]) | |||||
| self.reshape_7 = P.Reshape() | |||||
| self.reshape_7_shape = tuple([BATCH_SIZE, 256, 12, 64]) | |||||
| self.reshape_8 = P.Reshape() | |||||
| self.reshape_8_shape = tuple([BATCH_SIZE, 256, 12, 64]) | |||||
| self.transpose_9 = P.Transpose() | |||||
| self.transpose_10 = P.Transpose() | |||||
| self.transpose_11 = P.Transpose() | |||||
| self.matmul_12 = nn.MatMul() | |||||
| self.matmul_12.to_float(mstype.float16) | |||||
| self.div_13 = P.Div() | |||||
| self.div_13_w = 8.0 | |||||
| self.add_14 = P.Add() | |||||
| self.softmax_15 = nn.Softmax(axis=3) | |||||
| self.matmul_16 = nn.MatMul() | |||||
| self.matmul_16.to_float(mstype.float16) | |||||
| self.transpose_17 = P.Transpose() | |||||
| self.reshape_18 = P.Reshape() | |||||
| self.reshape_18_shape = tuple([BATCH_SIZE, 256, 768]) | |||||
| self.matmul_19 = nn.MatMul() | |||||
| self.matmul_19.to_float(mstype.float16) | |||||
| self.matmul_19_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.add_20 = P.Add() | |||||
| self.add_20_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(x, self.matmul_0_w) | |||||
| opt_matmul_1 = self.matmul_1(x, self.matmul_1_w) | |||||
| opt_matmul_2 = self.matmul_2(x, self.matmul_2_w) | |||||
| opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32) | |||||
| opt_matmul_1 = P.Cast()(opt_matmul_1, mstype.float32) | |||||
| opt_matmul_2 = P.Cast()(opt_matmul_2, mstype.float32) | |||||
| opt_add_3 = self.add_3(opt_matmul_0, self.add_3_bias) | |||||
| opt_add_4 = self.add_4(opt_matmul_1, self.add_4_bias) | |||||
| opt_add_5 = self.add_5(opt_matmul_2, self.add_5_bias) | |||||
| opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape) | |||||
| opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape) | |||||
| opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape) | |||||
| opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3)) | |||||
| opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1)) | |||||
| opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3)) | |||||
| opt_matmul_12 = self.matmul_12(opt_transpose_9, opt_transpose_10) | |||||
| opt_matmul_12 = P.Cast()(opt_matmul_12, mstype.float32) | |||||
| opt_div_13 = self.div_13(opt_matmul_12, self.div_13_w) | |||||
| opt_add_14 = self.add_14(opt_div_13, x0) | |||||
| opt_add_14 = P.Cast()(opt_add_14, mstype.float32) | |||||
| opt_softmax_15 = self.softmax_15(opt_add_14) | |||||
| opt_matmul_16 = self.matmul_16(opt_softmax_15, opt_transpose_11) | |||||
| opt_matmul_16 = P.Cast()(opt_matmul_16, mstype.float32) | |||||
| opt_transpose_17 = self.transpose_17(opt_matmul_16, (0, 2, 1, 3)) | |||||
| opt_reshape_18 = self.reshape_18(opt_transpose_17, self.reshape_18_shape) | |||||
| opt_matmul_19 = self.matmul_19(opt_reshape_18, self.matmul_19_w) | |||||
| opt_matmul_19 = P.Cast()(opt_matmul_19, mstype.float32) | |||||
| opt_add_20 = self.add_20(opt_matmul_19, self.add_20_bias) | |||||
| return opt_add_20 | |||||
| class Linear(nn.Cell): | |||||
| """linear layer""" | |||||
| def __init__(self, matmul_0_weight_shape, add_1_bias_shape): | |||||
| super(Linear, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0.to_float(mstype.float16) | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)), | |||||
| name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(x, self.matmul_0_w) | |||||
| opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32) | |||||
| opt_add_1 = self.add_1(opt_matmul_0, self.add_1_bias) | |||||
| return opt_add_1 | |||||
| class GeLU(nn.Cell): | |||||
| """gelu layer""" | |||||
| def __init__(self): | |||||
| super(GeLU, self).__init__() | |||||
| self.div_0 = P.Div() | |||||
| self.div_0_w = 1.4142135381698608 | |||||
| self.erf_1 = P.Erf() | |||||
| self.add_2 = P.Add() | |||||
| self.add_2_bias = 1.0 | |||||
| self.mul_3 = P.Mul() | |||||
| self.mul_4 = P.Mul() | |||||
| self.mul_4_w = 0.5 | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_div_0 = self.div_0(x, self.div_0_w) | |||||
| opt_erf_1 = self.erf_1(opt_div_0) | |||||
| opt_add_2 = self.add_2(opt_erf_1, self.add_2_bias) | |||||
| opt_mul_3 = self.mul_3(x, opt_add_2) | |||||
| opt_mul_4 = self.mul_4(opt_mul_3, self.mul_4_w) | |||||
| return opt_mul_4 | |||||
| class TransformerLayer(nn.Cell): | |||||
| """transformer layer""" | |||||
| def __init__(self, linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape, | |||||
| linear3_1_add_1_bias_shape): | |||||
| super(TransformerLayer, self).__init__() | |||||
| self.multiheadattn_0 = MultiHeadAttn() | |||||
| self.add_0 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm() | |||||
| self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape, | |||||
| add_1_bias_shape=linear3_0_add_1_bias_shape) | |||||
| self.gelu1_0 = GeLU() | |||||
| self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape, | |||||
| add_1_bias_shape=linear3_1_add_1_bias_shape) | |||||
| self.add_1 = P.Add() | |||||
| self.layernorm1_1 = LayerNorm() | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| multiheadattn_0_opt = self.multiheadattn_0(x, x0) | |||||
| opt_add_0 = self.add_0(multiheadattn_0_opt, x) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_0) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| gelu1_0_opt = self.gelu1_0(linear3_0_opt) | |||||
| linear3_1_opt = self.linear3_1(gelu1_0_opt) | |||||
| opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt) | |||||
| layernorm1_1_opt = self.layernorm1_1(opt_add_1) | |||||
| return layernorm1_1_opt | |||||
| class Encoder1_4(nn.Cell): | |||||
| """encoder layer""" | |||||
| def __init__(self): | |||||
| super(Encoder1_4, self).__init__() | |||||
| self.module47_0 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), | |||||
| linear3_0_add_1_bias_shape=(3072,), | |||||
| linear3_1_matmul_0_weight_shape=(3072, 768), | |||||
| linear3_1_add_1_bias_shape=(768,)) | |||||
| self.module47_1 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), | |||||
| linear3_0_add_1_bias_shape=(3072,), | |||||
| linear3_1_matmul_0_weight_shape=(3072, 768), | |||||
| linear3_1_add_1_bias_shape=(768,)) | |||||
| self.module47_2 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), | |||||
| linear3_0_add_1_bias_shape=(3072,), | |||||
| linear3_1_matmul_0_weight_shape=(3072, 768), | |||||
| linear3_1_add_1_bias_shape=(768,)) | |||||
| self.module47_3 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), | |||||
| linear3_0_add_1_bias_shape=(3072,), | |||||
| linear3_1_matmul_0_weight_shape=(3072, 768), | |||||
| linear3_1_add_1_bias_shape=(768,)) | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| module47_0_opt = self.module47_0(x, x0) | |||||
| module47_1_opt = self.module47_1(module47_0_opt, x0) | |||||
| module47_2_opt = self.module47_2(module47_1_opt, x0) | |||||
| module47_3_opt = self.module47_3(module47_2_opt, x0) | |||||
| return module47_3_opt | |||||
| class ModelOneHop(nn.Cell): | |||||
| """one hop layer""" | |||||
| def __init__(self): | |||||
| super(ModelOneHop, self).__init__() | |||||
| self.expanddims_0 = P.ExpandDims() | |||||
| self.expanddims_0_axis = 1 | |||||
| self.expanddims_3 = P.ExpandDims() | |||||
| self.expanddims_3_axis = 2 | |||||
| self.cast_5 = P.Cast() | |||||
| self.cast_5_to = mstype.float32 | |||||
| self.sub_7 = P.Sub() | |||||
| self.sub_7_bias = 1.0 | |||||
| self.mul_9 = P.Mul() | |||||
| self.mul_9_w = -10000.0 | |||||
| self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30522, 768)).astype(np.float32)), | |||||
| name=None) | |||||
| self.gather_1_axis = 0 | |||||
| self.gather_1 = P.Gather() | |||||
| self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 768)).astype(np.float32)), name=None) | |||||
| self.gather_2_axis = 0 | |||||
| self.gather_2 = P.Gather() | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 256, 768)).astype(np.float32)), name=None) | |||||
| self.add_6 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm() | |||||
| self.module51_0 = Encoder1_4() | |||||
| self.module51_1 = Encoder1_4() | |||||
| self.module51_2 = Encoder1_4() | |||||
| self.gather_643_input_weight = Tensor(np.array(0)) | |||||
| self.gather_643_axis = 1 | |||||
| self.gather_643 = P.Gather() | |||||
| self.dense_644 = nn.Dense(in_channels=768, out_channels=768, has_bias=True) | |||||
| self.tanh_645 = nn.Tanh() | |||||
| def construct(self, input_ids, token_type_ids, attention_mask): | |||||
| """construct function""" | |||||
| input_ids = self.cast_5(input_ids, mstype.int32) | |||||
| token_type_ids = self.cast_5(token_type_ids, mstype.int32) | |||||
| attention_mask = self.cast_5(attention_mask, mstype.int32) | |||||
| opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis) | |||||
| opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis) | |||||
| opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to) | |||||
| opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5) | |||||
| opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w) | |||||
| opt_gather_1_axis = self.gather_1_axis | |||||
| opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis) | |||||
| opt_gather_2_axis = self.gather_2_axis | |||||
| opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis) | |||||
| opt_add_4 = self.add_4(opt_gather_1, self.add_4_bias) | |||||
| opt_add_6 = self.add_6(opt_add_4, opt_gather_2) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_6) | |||||
| module51_0_opt = self.module51_0(layernorm1_0_opt, opt_mul_9) | |||||
| module51_1_opt = self.module51_1(module51_0_opt, opt_mul_9) | |||||
| module51_2_opt = self.module51_2(module51_1_opt, opt_mul_9) | |||||
| opt_gather_643_axis = self.gather_643_axis | |||||
| opt_gather_643 = self.gather_643(module51_2_opt, self.gather_643_input_weight, opt_gather_643_axis) | |||||
| opt_dense_644 = self.dense_644(opt_gather_643) | |||||
| opt_tanh_645 = self.tanh_645(opt_dense_644) | |||||
| return opt_tanh_645 | |||||
@@ -19,7 +19,7 @@ from mindspore import load_checkpoint, load_param_into_net
 from mindspore.ops import BatchMatMul
 from mindspore import ops
 from mindspore import dtype as mstype
-from src.reader_albert_xxlarge import Reader_Albert
+from src.albert import Albert
 from src.reader_downstream import Reader_Downstream
@@ -33,15 +33,15 @@ class Reader(nn.Cell):
         """init function"""
         super(Reader, self).__init__(auto_prefix=False)
-        self.encoder = Reader_Albert(batch_size)
+        self.encoder = Albert(batch_size)
         param_dict = load_checkpoint(encoder_ck_file)
         not_load_params = load_param_into_net(self.encoder, param_dict)
-        print(f"not loaded: {not_load_params}")
+        print(f"reader albert not loaded params: {not_load_params}")
         self.downstream = Reader_Downstream()
         param_dict = load_checkpoint(downstream_ck_file)
        not_load_params = load_param_into_net(self.downstream, param_dict)
-        print(f"not loaded: {not_load_params}")
+        print(f"reader downstream not loaded params: {not_load_params}")
         self.bmm = BatchMatMul()
@@ -49,7 +49,7 @@ class Reader(nn.Cell):
                   context_mask, square_mask, packing_mask, cache_mask,
                   para_start_mapping, sent_end_mapping):
         """construct function"""
-        state = self.encoder(attn_mask, input_ids, token_type_ids)
+        state = self.encoder(input_ids, attn_mask, token_type_ids)
         para_state = self.bmm(ops.Cast()(para_start_mapping, dst_type), ops.Cast()(state, dst_type))  # [B, 2, D]
         sent_state = self.bmm(ops.Cast()(sent_end_mapping, dst_type), ops.Cast()(state, dst_type))  # [B, max_sent, D]
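(The reordered encoder call matches the signature of Albert.construct(input_ids, attention_mask, token_type_ids) defined in the new src/albert.py above; the old Reader_Albert apparently expected the attention mask first.)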
@@ -1,263 +0,0 @@
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """albert-xxlarge Model for reader""" | |||||
| import numpy as np | |||||
| from mindspore import nn, ops | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import dtype as mstype | |||||
| dst_type = mstype.float16 | |||||
| dst_type2 = mstype.float32 | |||||
| class LayerNorm(nn.Cell): | |||||
| """LayerNorm layer""" | |||||
| def __init__(self, mul_7_w_shape, add_8_bias_shape): | |||||
| """init function""" | |||||
| super(LayerNorm, self).__init__() | |||||
| self.reducemean_0 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_1 = P.Sub() | |||||
| self.pow_2 = P.Pow() | |||||
| self.pow_2_input_weight = 2.0 | |||||
| self.reducemean_3 = P.ReduceMean(keep_dims=True) | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = 9.999999960041972e-13 | |||||
| self.sqrt_5 = P.Sqrt() | |||||
| self.div_6 = P.Div() | |||||
| self.mul_7 = P.Mul() | |||||
| self.mul_7_w = Parameter(Tensor(np.random.uniform(0, 1, mul_7_w_shape).astype(np.float32)), name=None) | |||||
| self.add_8 = P.Add() | |||||
| self.add_8_bias = Parameter(Tensor(np.random.uniform(0, 1, add_8_bias_shape).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_reducemean_0 = self.reducemean_0(x, -1) | |||||
| opt_sub_1 = self.sub_1(x, opt_reducemean_0) | |||||
| opt_pow_2 = self.pow_2(opt_sub_1, self.pow_2_input_weight) | |||||
| opt_reducemean_3 = self.reducemean_3(opt_pow_2, -1) | |||||
| opt_add_4 = self.add_4(opt_reducemean_3, self.add_4_bias) | |||||
| opt_sqrt_5 = self.sqrt_5(opt_add_4) | |||||
| opt_div_6 = self.div_6(opt_sub_1, opt_sqrt_5) | |||||
| opt_mul_7 = self.mul_7(opt_div_6, self.mul_7_w) | |||||
| opt_add_8 = self.add_8(opt_mul_7, self.add_8_bias) | |||||
| return opt_add_8 | |||||
| class Linear(nn.Cell): | |||||
| """Linear layer""" | |||||
| def __init__(self, matmul_0_weight_shape, add_1_bias_shape): | |||||
| """init function""" | |||||
| super(Linear, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)), | |||||
| name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias) | |||||
| return opt_add_1 | |||||
| class MultiHeadAttn(nn.Cell): | |||||
| """Multi-head attention layer""" | |||||
| def __init__(self, batch_size): | |||||
| """init function""" | |||||
| super(MultiHeadAttn, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.matmul_1 = nn.MatMul() | |||||
| self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.matmul_2 = nn.MatMul() | |||||
| self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.add_3 = P.Add() | |||||
| self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.reshape_6 = P.Reshape() | |||||
| self.reshape_6_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.reshape_7 = P.Reshape() | |||||
| self.reshape_7_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.reshape_8 = P.Reshape() | |||||
| self.reshape_8_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.transpose_9 = P.Transpose() | |||||
| self.transpose_10 = P.Transpose() | |||||
| self.transpose_11 = P.Transpose() | |||||
| self.matmul_12 = nn.MatMul() | |||||
| self.div_13 = P.Div() | |||||
| self.div_13_w = 8.0 | |||||
| self.add_14 = P.Add() | |||||
| self.softmax_15 = nn.Softmax(axis=3) | |||||
| self.matmul_16 = nn.MatMul() | |||||
| self.transpose_17 = P.Transpose() | |||||
| self.matmul_18 = P.MatMul() | |||||
| self.matmul_18_weight = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None) | |||||
| self.add_19 = P.Add() | |||||
| self.add_19_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_matmul_1 = self.matmul_1(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_1_w, dst_type)) | |||||
| opt_matmul_2 = self.matmul_2(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_2_w, dst_type)) | |||||
| opt_add_3 = self.add_3(ops.Cast()(opt_matmul_0, dst_type2), self.add_3_bias) | |||||
| opt_add_4 = self.add_4(ops.Cast()(opt_matmul_1, dst_type2), self.add_4_bias) | |||||
| opt_add_5 = self.add_5(ops.Cast()(opt_matmul_2, dst_type2), self.add_5_bias) | |||||
| opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape) | |||||
| opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape) | |||||
| opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape) | |||||
| opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3)) | |||||
| opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1)) | |||||
| opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3)) | |||||
| opt_matmul_12 = self.matmul_12(ops.Cast()(opt_transpose_9, dst_type), ops.Cast()(opt_transpose_10, dst_type)) | |||||
| opt_div_13 = self.div_13(ops.Cast()(opt_matmul_12, dst_type2), ops.Cast()(self.div_13_w, dst_type2)) | |||||
| opt_add_14 = self.add_14(opt_div_13, x0) | |||||
| opt_softmax_15 = self.softmax_15(opt_add_14) | |||||
| opt_matmul_16 = self.matmul_16(ops.Cast()(opt_softmax_15, dst_type), ops.Cast()(opt_transpose_11, dst_type)) | |||||
| opt_transpose_17 = self.transpose_17(ops.Cast()(opt_matmul_16, dst_type2), (0, 2, 1, 3)) | |||||
| opt_matmul_18 = self.matmul_18(ops.Cast()(opt_transpose_17, dst_type).view(self.batch_size * 512, -1), | |||||
| ops.Cast()(self.matmul_18_weight, dst_type).view(-1, 4096))\ | |||||
| .view(self.batch_size, 512, 4096) | |||||
| opt_add_19 = self.add_19(ops.Cast()(opt_matmul_18, dst_type2), self.add_19_bias) | |||||
| return opt_add_19 | |||||
| class NewGeLU(nn.Cell): | |||||
| """new gelu layer""" | |||||
| def __init__(self): | |||||
| """init function""" | |||||
| super(NewGeLU, self).__init__() | |||||
| self.mul_0 = P.Mul() | |||||
| self.mul_0_w = 0.5 | |||||
| self.pow_1 = P.Pow() | |||||
| self.pow_1_input_weight = 3.0 | |||||
| self.mul_2 = P.Mul() | |||||
| self.mul_2_w = 0.044714998453855515 | |||||
| self.add_3 = P.Add() | |||||
| self.mul_4 = P.Mul() | |||||
| self.mul_4_w = 0.7978845834732056 | |||||
| self.tanh_5 = nn.Tanh() | |||||
| self.add_6 = P.Add() | |||||
| self.add_6_bias = 1.0 | |||||
| self.mul_7 = P.Mul() | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_mul_0 = self.mul_0(x, self.mul_0_w) | |||||
| opt_pow_1 = self.pow_1(x, self.pow_1_input_weight) | |||||
| opt_mul_2 = self.mul_2(opt_pow_1, self.mul_2_w) | |||||
| opt_add_3 = self.add_3(x, opt_mul_2) | |||||
| opt_mul_4 = self.mul_4(opt_add_3, self.mul_4_w) | |||||
| opt_tanh_5 = self.tanh_5(opt_mul_4) | |||||
| opt_add_6 = self.add_6(opt_tanh_5, self.add_6_bias) | |||||
| opt_mul_7 = self.mul_7(opt_mul_0, opt_add_6) | |||||
| return opt_mul_7 | |||||
| class TransformerLayer(nn.Cell): | |||||
| """Transformer layer""" | |||||
| def __init__(self, batch_size, layernorm1_0_mul_7_w_shape, layernorm1_0_add_8_bias_shape, | |||||
| linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape, | |||||
| linear3_1_add_1_bias_shape): | |||||
| """init function""" | |||||
| super(TransformerLayer, self).__init__() | |||||
| self.multiheadattn_0 = MultiHeadAttn(batch_size) | |||||
| self.add_0 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm(mul_7_w_shape=layernorm1_0_mul_7_w_shape, | |||||
| add_8_bias_shape=layernorm1_0_add_8_bias_shape) | |||||
| self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape, | |||||
| add_1_bias_shape=linear3_0_add_1_bias_shape) | |||||
| self.newgelu2_0 = NewGeLU() | |||||
| self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape, | |||||
| add_1_bias_shape=linear3_1_add_1_bias_shape) | |||||
| self.add_1 = P.Add() | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| multiheadattn_0_opt = self.multiheadattn_0(x, x0) | |||||
| opt_add_0 = self.add_0(x, multiheadattn_0_opt) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_0) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| newgelu2_0_opt = self.newgelu2_0(linear3_0_opt) | |||||
| linear3_1_opt = self.linear3_1(newgelu2_0_opt) | |||||
| opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt) | |||||
| return opt_add_1 | |||||
| class Reader_Albert(nn.Cell): | |||||
| """Albert model for reader""" | |||||
| def __init__(self, batch_size): | |||||
| """init function""" | |||||
| super(Reader_Albert, self).__init__() | |||||
| self.expanddims_0 = P.ExpandDims() | |||||
| self.expanddims_0_axis = 1 | |||||
| self.expanddims_3 = P.ExpandDims() | |||||
| self.expanddims_3_axis = 2 | |||||
| self.cast_5 = P.Cast() | |||||
| self.cast_5_to = mstype.float32 | |||||
| self.sub_7 = P.Sub() | |||||
| self.sub_7_bias = 1.0 | |||||
| self.mul_9 = P.Mul() | |||||
| self.mul_9_w = -10000.0 | |||||
| self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)), | |||||
| name=None) | |||||
| self.gather_1_axis = 0 | |||||
| self.gather_1 = P.Gather() | |||||
| self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None) | |||||
| self.gather_2_axis = 0 | |||||
| self.gather_2 = P.Gather() | |||||
| self.add_4 = P.Add() | |||||
| self.add_6 = P.Add() | |||||
| self.add_6_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)), name=None) | |||||
| self.layernorm1_0 = LayerNorm(mul_7_w_shape=(128,), add_8_bias_shape=(128,)) | |||||
| self.linear3_0 = Linear(matmul_0_weight_shape=(128, 4096), add_1_bias_shape=(4096,)) | |||||
| self.module34_0 = TransformerLayer(batch_size, | |||||
| layernorm1_0_mul_7_w_shape=(4096,), | |||||
| layernorm1_0_add_8_bias_shape=(4096,), | |||||
| linear3_0_matmul_0_weight_shape=(4096, 16384), | |||||
| linear3_0_add_1_bias_shape=(16384,), | |||||
| linear3_1_matmul_0_weight_shape=(16384, 4096), | |||||
| linear3_1_add_1_bias_shape=(4096,)) | |||||
| self.layernorm1_1 = LayerNorm(mul_7_w_shape=(4096,), add_8_bias_shape=(4096,)) | |||||
| def construct(self, x, x0, x1): | |||||
| """construct function""" | |||||
| opt_expanddims_0 = self.expanddims_0(x, self.expanddims_0_axis) | |||||
| opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis) | |||||
| opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to) | |||||
| opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5) | |||||
| opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w) | |||||
| opt_gather_1_axis = self.gather_1_axis | |||||
| opt_gather_1 = self.gather_1(self.gather_1_input_weight, x0, opt_gather_1_axis) | |||||
| opt_gather_2_axis = self.gather_2_axis | |||||
| opt_gather_2 = self.gather_2(self.gather_2_input_weight, x1, opt_gather_2_axis) | |||||
| opt_add_4 = self.add_4(opt_gather_1, opt_gather_2) | |||||
| opt_add_6 = self.add_6(opt_add_4, self.add_6_bias) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_6) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| module34_0_opt = self.module34_0(linear3_0_opt, opt_mul_9) | |||||
| out = self.layernorm1_1(module34_0_opt) | |||||
| for _ in range(11): | |||||
| out = self.module34_0(out, opt_mul_9) | |||||
| out = self.layernorm1_1(out) | |||||
| return out | |||||
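Both deleted encoder files (this one and the re-ranker variant removed further below) implement the same ALBERT-style weight sharing: a single transformer block reused at every layer, which is why one shared `src/albert.py` can replace them. A rough sketch of that shared pattern, with hypothetical names:

```python
def albert_encode(embeddings, extended_attn_mask, transformer_layer, final_layer_norm, num_layers=12):
    """ALBERT-style stack: the same cell is applied at every layer, so weights are shared."""
    state = embeddings
    for _ in range(num_layers):
        state = transformer_layer(state, extended_attn_mask)  # one shared block
        state = final_layer_norm(state)
    return state
```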
| @@ -25,138 +25,114 @@ dst_type = mstype.float16 | |||||
| dst_type2 = mstype.float32 | dst_type2 = mstype.float32 | ||||
| class Module15(nn.Cell): | |||||
| class Linear(nn.Cell): | |||||
| """module of reader downstream""" | """module of reader downstream""" | ||||
| def __init__(self, matmul_0_weight_shape, add_1_bias_shape): | |||||
| def __init__(self, linear_weight_shape, linear_bias_shape): | |||||
| """init function""" | """init function""" | ||||
| super(Module15, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)), | |||||
| name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None) | |||||
| self.relu_2 = nn.ReLU() | |||||
| def construct(self, x): | |||||
| super(Linear, self).__init__() | |||||
| self.matmul = nn.MatMul() | |||||
| self.matmul_w = Parameter(Tensor(np.random.uniform(0, 1, linear_weight_shape).astype(np.float32)), | |||||
| name=None) | |||||
| self.add = P.Add() | |||||
| self.add_bias = Parameter(Tensor(np.random.uniform(0, 1, linear_bias_shape).astype(np.float32)), name=None) | |||||
| self.relu = nn.ReLU() | |||||
| def construct(self, hidden_state): | |||||
| """construct function""" | """construct function""" | ||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias) | |||||
| opt_relu_2 = self.relu_2(opt_add_1) | |||||
| return opt_relu_2 | |||||
| output = self.matmul(ops.Cast()(hidden_state, dst_type), ops.Cast()(self.matmul_w, dst_type)) | |||||
| output = self.add(ops.Cast()(output, dst_type2), self.add_bias) | |||||
| output = self.relu(output) | |||||
| return output | |||||
| class NormModule(nn.Cell): | |||||
| class BertLayerNorm(nn.Cell): | |||||
| """Normalization module of reader downstream""" | """Normalization module of reader downstream""" | ||||
| def __init__(self, mul_8_w_shape, add_9_bias_shape): | |||||
| def __init__(self, bert_layer_norm_weight_shape, bert_layer_norm_bias_shape, eps=1e-12): | |||||
| """init function""" | """init function""" | ||||
| super(NormModule, self).__init__() | |||||
| self.reducemean_0 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_1 = P.Sub() | |||||
| self.sub_2 = P.Sub() | |||||
| self.pow_3 = P.Pow() | |||||
| self.pow_3_input_weight = 2.0 | |||||
| self.reducemean_4 = P.ReduceMean(keep_dims=True) | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = 9.999999960041972e-13 | |||||
| self.sqrt_6 = P.Sqrt() | |||||
| self.div_7 = P.Div() | |||||
| self.mul_8 = P.Mul() | |||||
| self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, mul_8_w_shape).astype(np.float32)), name=None) | |||||
| self.add_9 = P.Add() | |||||
| self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, add_9_bias_shape).astype(np.float32)), name=None) | |||||
| super(BertLayerNorm, self).__init__() | |||||
| self.reducemean = P.ReduceMean(keep_dims=True) | |||||
| self.sub = P.Sub() | |||||
| self.pow = P.Pow() | |||||
| self.add = P.Add() | |||||
| self.sqrt = P.Sqrt() | |||||
| self.div = P.Div() | |||||
| self.mul = P.Mul() | |||||
| self.variance_epsilon = eps | |||||
| self.bert_layer_norm_weight = Parameter(Tensor(np.random.uniform(0, 1, bert_layer_norm_weight_shape) | |||||
| .astype(np.float32)), name=None) | |||||
| self.bert_layer_norm_bias = Parameter(Tensor(np.random.uniform(0, 1, bert_layer_norm_bias_shape) | |||||
| .astype(np.float32)), name=None) | |||||
| def construct(self, x): | def construct(self, x): | ||||
| """construct function""" | """construct function""" | ||||
| opt_reducemean_0 = self.reducemean_0(x, -1) | |||||
| opt_sub_1 = self.sub_1(x, opt_reducemean_0) | |||||
| opt_sub_2 = self.sub_2(x, opt_reducemean_0) | |||||
| opt_pow_3 = self.pow_3(opt_sub_1, self.pow_3_input_weight) | |||||
| opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1) | |||||
| opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias) | |||||
| opt_sqrt_6 = self.sqrt_6(opt_add_5) | |||||
| opt_div_7 = self.div_7(opt_sub_2, opt_sqrt_6) | |||||
| opt_mul_8 = self.mul_8(self.mul_8_w, opt_div_7) | |||||
| opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias) | |||||
| return opt_add_9 | |||||
| class Module16(nn.Cell): | |||||
| """module of reader downstream""" | |||||
| def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape, | |||||
| normmodule_0_add_9_bias_shape): | |||||
| """init function""" | |||||
| super(Module16, self).__init__() | |||||
| self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape, | |||||
| add_1_bias_shape=module15_0_add_1_bias_shape) | |||||
| self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape, | |||||
| add_9_bias_shape=normmodule_0_add_9_bias_shape) | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (8192, 1)).astype(np.float32)), name=None) | |||||
| u = self.reducemean(x, -1) | |||||
| s = self.reducemean(self.pow(self.sub(x, u), 2), -1) | |||||
| x = self.div(self.sub(x, u), self.sqrt(self.add(s, self.variance_epsilon))) | |||||
| output = self.mul(self.bert_layer_norm_weight, x) | |||||
| output = self.add(output, self.bert_layer_norm_bias) | |||||
| return output | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| module15_0_opt = self.module15_0(x) | |||||
| normmodule_0_opt = self.normmodule_0(module15_0_opt) | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| return ops.Cast()(opt_matmul_0, dst_type2) | |||||
| class Module17(nn.Cell): | |||||
| class SupportingOutputLayer(nn.Cell): | |||||
| """module of reader downstream""" | """module of reader downstream""" | ||||
| def __init__(self, module15_0_matmul_0_weight_shape, module15_0_add_1_bias_shape, normmodule_0_mul_8_w_shape, | |||||
| normmodule_0_add_9_bias_shape): | |||||
| def __init__(self, linear_1_weight_shape, linear_1_bias_shape, bert_layer_norm_weight_shape, | |||||
| bert_layer_norm_bias_shape): | |||||
| """init function""" | """init function""" | ||||
| super(Module17, self).__init__() | |||||
| self.module15_0 = Module15(matmul_0_weight_shape=module15_0_matmul_0_weight_shape, | |||||
| add_1_bias_shape=module15_0_add_1_bias_shape) | |||||
| self.normmodule_0 = NormModule(mul_8_w_shape=normmodule_0_mul_8_w_shape, | |||||
| add_9_bias_shape=normmodule_0_add_9_bias_shape) | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 1)).astype(np.float32)), name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| super(SupportingOutputLayer, self).__init__() | |||||
| self.linear_1 = Linear(linear_weight_shape=linear_1_weight_shape, | |||||
| linear_bias_shape=linear_1_bias_shape) | |||||
| self.bert_layer_norm = BertLayerNorm(bert_layer_norm_weight_shape=bert_layer_norm_weight_shape, | |||||
| bert_layer_norm_bias_shape=bert_layer_norm_bias_shape) | |||||
| self.matmul = nn.MatMul() | |||||
| self.matmul_w = Parameter(Tensor(np.random.uniform(0, 1, (8192, 1)).astype(np.float32)), name=None) | |||||
| def construct(self, x): | def construct(self, x): | ||||
| """construct function""" | """construct function""" | ||||
| module15_0_opt = self.module15_0(x) | |||||
| normmodule_0_opt = self.normmodule_0(module15_0_opt) | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(normmodule_0_opt, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias) | |||||
| return opt_add_1 | |||||
| output = self.linear_1(x) | |||||
| output = self.bert_layer_norm(output) | |||||
| output = self.matmul(ops.Cast()(output, dst_type), ops.Cast()(self.matmul_w, dst_type)) | |||||
| return ops.Cast()(output, dst_type2) | |||||
| class Module5(nn.Cell): | |||||
| class PosOutputLayer(nn.Cell): | |||||
| """module of reader downstream""" | """module of reader downstream""" | ||||
| def __init__(self): | |||||
| def __init__(self, linear_weight_shape, linear_bias_shape, bert_layer_norm_weight_shape, | |||||
| bert_layer_norm_bias_shape): | |||||
| """init function""" | """init function""" | ||||
| super(Module5, self).__init__() | |||||
| self.sub_0 = P.Sub() | |||||
| self.sub_0_bias = 1.0 | |||||
| self.mul_1 = P.Mul() | |||||
| self.mul_1_w = 1.0000000150474662e+30 | |||||
| def construct(self, x): | |||||
| super(PosOutputLayer, self).__init__() | |||||
| self.linear_1 = Linear(linear_weight_shape=linear_weight_shape, | |||||
| linear_bias_shape=linear_bias_shape) | |||||
| self.bert_layer_norm = BertLayerNorm(bert_layer_norm_weight_shape=bert_layer_norm_weight_shape, | |||||
| bert_layer_norm_bias_shape=bert_layer_norm_bias_shape) | |||||
| self.matmul = nn.MatMul() | |||||
| self.linear_2_weight = Parameter(Tensor(np.random.uniform(0, 1, (4096, 1)).astype(np.float32)), name=None) | |||||
| self.add = P.Add() | |||||
| self.linear_2_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| def construct(self, state): | |||||
| """construct function""" | """construct function""" | ||||
| opt_sub_0 = self.sub_0(self.sub_0_bias, x) | |||||
| opt_mul_1 = self.mul_1(opt_sub_0, self.mul_1_w) | |||||
| return opt_mul_1 | |||||
| output = self.linear_1(state) | |||||
| output = self.bert_layer_norm(output) | |||||
| output = self.matmul(ops.Cast()(output, dst_type), ops.Cast()(self.linear_2_weight, dst_type)) | |||||
| output = self.add(ops.Cast()(output, dst_type2), self.linear_2_bias) | |||||
| return output | |||||
| class Module10(nn.Cell): | |||||
| class MaskInvalidPos(nn.Cell): | |||||
| """module of reader downstream""" | """module of reader downstream""" | ||||
| def __init__(self): | def __init__(self): | ||||
| """init function""" | """init function""" | ||||
| super(Module10, self).__init__() | |||||
| self.squeeze_0 = P.Squeeze(2) | |||||
| self.module5_0 = Module5() | |||||
| self.sub_1 = P.Sub() | |||||
| super(MaskInvalidPos, self).__init__() | |||||
| self.squeeze = P.Squeeze(2) | |||||
| self.sub = P.Sub() | |||||
| self.mul = P.Mul() | |||||
| def construct(self, x, x0): | |||||
| def construct(self, pos_pred, context_mask): | |||||
| """construct function""" | """construct function""" | ||||
| opt_squeeze_0 = self.squeeze_0(x) | |||||
| module5_0_opt = self.module5_0(x0) | |||||
| opt_sub_1 = self.sub_1(opt_squeeze_0, module5_0_opt) | |||||
| return opt_sub_1 | |||||
| output = self.squeeze(pos_pred) | |||||
| invalid_pos_mask = self.mul(self.sub(1.0, context_mask), 1e30) | |||||
| output = self.sub(output, invalid_pos_mask) | |||||
| return output | |||||
| class Reader_Downstream(nn.Cell): | class Reader_Downstream(nn.Cell): | ||||
| @@ -164,50 +140,52 @@ class Reader_Downstream(nn.Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| """init function""" | """init function""" | ||||
| super(Reader_Downstream, self).__init__() | super(Reader_Downstream, self).__init__() | ||||
| self.module16_0 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192), | |||||
| module15_0_add_1_bias_shape=(8192,), | |||||
| normmodule_0_mul_8_w_shape=(8192,), | |||||
| normmodule_0_add_9_bias_shape=(8192,)) | |||||
| self.add_74 = P.Add() | |||||
| self.add_74_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| self.module16_1 = Module16(module15_0_matmul_0_weight_shape=(4096, 8192), | |||||
| module15_0_add_1_bias_shape=(8192,), | |||||
| normmodule_0_mul_8_w_shape=(8192,), | |||||
| normmodule_0_add_9_bias_shape=(8192,)) | |||||
| self.add_75 = P.Add() | |||||
| self.add_75_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| self.module17_0 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096), | |||||
| module15_0_add_1_bias_shape=(4096,), | |||||
| normmodule_0_mul_8_w_shape=(4096,), | |||||
| normmodule_0_add_9_bias_shape=(4096,)) | |||||
| self.module10_0 = Module10() | |||||
| self.module17_1 = Module17(module15_0_matmul_0_weight_shape=(4096, 4096), | |||||
| module15_0_add_1_bias_shape=(4096,), | |||||
| normmodule_0_mul_8_w_shape=(4096,), | |||||
| normmodule_0_add_9_bias_shape=(4096,)) | |||||
| self.module10_1 = Module10() | |||||
| self.gather_6_input_weight = Tensor(np.array(0)) | |||||
| self.gather_6_axis = 1 | |||||
| self.gather_6 = P.Gather() | |||||
| self.dense_13 = nn.Dense(in_channels=4096, out_channels=4096, has_bias=True) | |||||
| self.relu_18 = nn.ReLU() | |||||
| self.normmodule_0 = NormModule(mul_8_w_shape=(4096,), add_9_bias_shape=(4096,)) | |||||
| self.dense_73 = nn.Dense(in_channels=4096, out_channels=3, has_bias=True) | |||||
| def construct(self, x, x0, x1, x2): | |||||
| self.add = P.Add() | |||||
| self.para_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| self.para_output_layer = SupportingOutputLayer(linear_1_weight_shape=(4096, 8192), | |||||
| linear_1_bias_shape=(8192,), | |||||
| bert_layer_norm_weight_shape=(8192,), | |||||
| bert_layer_norm_bias_shape=(8192,)) | |||||
| self.sent_bias = Parameter(Tensor(np.random.uniform(0, 1, (1,)).astype(np.float32)), name=None) | |||||
| self.sent_output_layer = SupportingOutputLayer(linear_1_weight_shape=(4096, 8192), | |||||
| linear_1_bias_shape=(8192,), | |||||
| bert_layer_norm_weight_shape=(8192,), | |||||
| bert_layer_norm_bias_shape=(8192,)) | |||||
| self.start_output_layer = PosOutputLayer(linear_weight_shape=(4096, 4096), | |||||
| linear_bias_shape=(4096,), | |||||
| bert_layer_norm_weight_shape=(4096,), | |||||
| bert_layer_norm_bias_shape=(4096,)) | |||||
| self.end_output_layer = PosOutputLayer(linear_weight_shape=(4096, 4096), | |||||
| linear_bias_shape=(4096,), | |||||
| bert_layer_norm_weight_shape=(4096,), | |||||
| bert_layer_norm_bias_shape=(4096,)) | |||||
| self.mask_invalid_pos = MaskInvalidPos() | |||||
| self.gather_input_weight = Tensor(np.array(0)) | |||||
| self.gather = P.Gather() | |||||
| self.type_linear_1 = nn.Dense(in_channels=4096, out_channels=4096, has_bias=True) | |||||
| self.relu = nn.ReLU() | |||||
| self.bert_layer_norm = BertLayerNorm(bert_layer_norm_weight_shape=(4096,), bert_layer_norm_bias_shape=(4096,)) | |||||
| self.type_linear_2 = nn.Dense(in_channels=4096, out_channels=3, has_bias=True) | |||||
| def construct(self, para_state, sent_state, state, context_mask): | |||||
| """construct function""" | """construct function""" | ||||
| module16_0_opt = self.module16_0(x) | |||||
| opt_add_74 = self.add_74(module16_0_opt, self.add_74_bias) | |||||
| module16_1_opt = self.module16_1(x0) | |||||
| opt_add_75 = self.add_75(module16_1_opt, self.add_75_bias) | |||||
| module17_0_opt = self.module17_0(x1) | |||||
| opt_module10_0 = self.module10_0(module17_0_opt, x2) | |||||
| module17_1_opt = self.module17_1(x1) | |||||
| opt_module10_1 = self.module10_1(module17_1_opt, x2) | |||||
| opt_gather_6_axis = self.gather_6_axis | |||||
| opt_gather_6 = self.gather_6(x1, self.gather_6_input_weight, opt_gather_6_axis) | |||||
| opt_dense_13 = self.dense_13(opt_gather_6) | |||||
| opt_relu_18 = self.relu_18(opt_dense_13) | |||||
| normmodule_0_opt = self.normmodule_0(opt_relu_18) | |||||
| opt_dense_73 = self.dense_73(normmodule_0_opt) | |||||
| return opt_dense_73, opt_module10_0, opt_module10_1, opt_add_74, opt_add_75 | |||||
| para_logit = self.para_output_layer(para_state) | |||||
| para_logit = self.add(para_logit, self.para_bias) | |||||
| sent_logit = self.sent_output_layer(sent_state) | |||||
| sent_logit = self.add(sent_logit, self.sent_bias) | |||||
| start = self.start_output_layer(state) | |||||
| start = self.mask_invalid_pos(start, context_mask) | |||||
| end = self.end_output_layer(state) | |||||
| end = self.mask_invalid_pos(end, context_mask) | |||||
| cls_emb = self.gather(state, self.gather_input_weight, 1) | |||||
| q_type = self.type_linear_1(cls_emb) | |||||
| q_type = self.relu(q_type) | |||||
| q_type = self.bert_layer_norm(q_type) | |||||
| q_type = self.type_linear_2(q_type) | |||||
| return q_type, start, end, para_logit, sent_logit | |||||
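As a readability aid: the renamed downstream head now returns, in order, the 3-way answer-type logits computed from the [CLS] state, the masked start/end position logits, and the biased paragraph/sentence supporting logits. The position masking is numerically just a large negative offset at padded positions; a small NumPy equivalent of `MaskInvalidPos`, for illustration only:

```python
import numpy as np

def mask_invalid_pos(pos_pred, context_mask):
    """pos_pred: [B, L, 1] logits, context_mask: [B, L] with 1 = valid token."""
    logits = np.squeeze(pos_pred, axis=2)            # [B, L]
    return logits - (1.0 - context_mask) * 1e30      # padded positions become ~ -inf

logits = np.zeros((1, 4, 1), dtype=np.float32)
mask = np.array([[1.0, 1.0, 1.0, 0.0]], dtype=np.float32)
print(mask_invalid_pos(logits, mask))                # last position pushed to -1e30
```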
| @@ -33,11 +33,11 @@ from src.reader import Reader | |||||
| def read(args): | def read(args): | ||||
| """reader function""" | """reader function""" | ||||
| db_file = args.wiki_db_file | |||||
| db_file = args.wiki_db_path | |||||
| reader_feature_file = args.reader_feature_file | reader_feature_file = args.reader_feature_file | ||||
| reader_example_file = args.reader_example_file | reader_example_file = args.reader_example_file | ||||
| encoder_ck_file = args.reader_encoder_ck_file | |||||
| downstream_ck_file = args.reader_downstream_ck_file | |||||
| encoder_ck_file = args.reader_encoder_ck_path | |||||
| downstream_ck_file = args.reader_downstream_ck_path | |||||
| albert_model_path = args.albert_model_path | albert_model_path = args.albert_model_path | ||||
| reader_result_file = args.reader_result_file | reader_result_file = args.reader_result_file | ||||
| seed = args.seed | seed = args.seed | ||||
| @@ -1,276 +0,0 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """albert-xxlarge Model for reranker""" | |||||
| import numpy as np | |||||
| from mindspore import nn, ops | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import dtype as mstype | |||||
| dst_type = mstype.float16 | |||||
| dst_type2 = mstype.float32 | |||||
| class LayerNorm(nn.Cell): | |||||
| """LayerNorm layer""" | |||||
| def __init__(self, passthrough_w_0, passthrough_w_1): | |||||
| """init function""" | |||||
| super(LayerNorm, self).__init__() | |||||
| self.reducemean_0 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_1 = P.Sub() | |||||
| self.pow_2 = P.Pow() | |||||
| self.pow_2_input_weight = 2.0 | |||||
| self.reducemean_3 = P.ReduceMean(keep_dims=True) | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = 9.999999960041972e-13 | |||||
| self.sqrt_5 = P.Sqrt() | |||||
| self.div_6 = P.Div() | |||||
| self.mul_7 = P.Mul() | |||||
| self.mul_7_w = passthrough_w_0 | |||||
| self.add_8 = P.Add() | |||||
| self.add_8_bias = passthrough_w_1 | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_reducemean_0 = self.reducemean_0(x, -1) | |||||
| opt_sub_1 = self.sub_1(x, opt_reducemean_0) | |||||
| opt_pow_2 = self.pow_2(opt_sub_1, self.pow_2_input_weight) | |||||
| opt_reducemean_3 = self.reducemean_3(opt_pow_2, -1) | |||||
| opt_add_4 = self.add_4(opt_reducemean_3, self.add_4_bias) | |||||
| opt_sqrt_5 = self.sqrt_5(opt_add_4) | |||||
| opt_div_6 = self.div_6(opt_sub_1, opt_sqrt_5) | |||||
| opt_mul_7 = self.mul_7(opt_div_6, self.mul_7_w) | |||||
| opt_add_8 = self.add_8(opt_mul_7, self.add_8_bias) | |||||
| return opt_add_8 | |||||
| class Linear(nn.Cell): | |||||
| """Linear layer""" | |||||
| def __init__(self, matmul_0_w_shape, passthrough_w_0): | |||||
| """init function""" | |||||
| super(Linear, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_w_shape).astype(np.float32)), name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = passthrough_w_0 | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_add_1 = self.add_1(ops.Cast()(opt_matmul_0, dst_type2), self.add_1_bias) | |||||
| return opt_add_1 | |||||
| class MultiHeadAttn(nn.Cell): | |||||
| """Multi-head attention layer""" | |||||
| def __init__(self, batch_size, passthrough_w_0, passthrough_w_1, passthrough_w_2): | |||||
| """init function""" | |||||
| super(MultiHeadAttn, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.matmul_1 = nn.MatMul() | |||||
| self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.matmul_2 = nn.MatMul() | |||||
| self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (4096, 4096)).astype(np.float32)), name=None) | |||||
| self.add_3 = P.Add() | |||||
| self.add_3_bias = passthrough_w_0 | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = passthrough_w_1 | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = passthrough_w_2 | |||||
| self.reshape_6 = P.Reshape() | |||||
| self.reshape_6_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.reshape_7 = P.Reshape() | |||||
| self.reshape_7_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.reshape_8 = P.Reshape() | |||||
| self.reshape_8_shape = tuple([batch_size, 512, 64, 64]) | |||||
| self.transpose_9 = P.Transpose() | |||||
| self.transpose_10 = P.Transpose() | |||||
| self.transpose_11 = P.Transpose() | |||||
| self.matmul_12 = nn.MatMul() | |||||
| self.div_13 = P.Div() | |||||
| self.div_13_w = 8.0 | |||||
| self.add_14 = P.Add() | |||||
| self.softmax_15 = nn.Softmax(axis=3) | |||||
| self.matmul_16 = nn.MatMul() | |||||
| self.transpose_17 = P.Transpose() | |||||
| self.matmul_18 = P.MatMul() | |||||
| self.matmul_18_weight = Parameter(Tensor(np.random.uniform(0, 1, (64, 64, 4096)).astype(np.float32)), name=None) | |||||
| self.add_19 = P.Add() | |||||
| self.add_19_bias = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_0_w, dst_type)) | |||||
| opt_matmul_1 = self.matmul_1(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_1_w, dst_type)) | |||||
| opt_matmul_2 = self.matmul_2(ops.Cast()(x, dst_type), ops.Cast()(self.matmul_2_w, dst_type)) | |||||
| opt_add_3 = self.add_3(ops.Cast()(opt_matmul_0, dst_type2), self.add_3_bias) | |||||
| opt_add_4 = self.add_4(ops.Cast()(opt_matmul_1, dst_type2), self.add_4_bias) | |||||
| opt_add_5 = self.add_5(ops.Cast()(opt_matmul_2, dst_type2), self.add_5_bias) | |||||
| opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape) | |||||
| opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape) | |||||
| opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape) | |||||
| opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3)) | |||||
| opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1)) | |||||
| opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3)) | |||||
| opt_matmul_12 = self.matmul_12(ops.Cast()(opt_transpose_9, dst_type), ops.Cast()(opt_transpose_10, dst_type)) | |||||
| opt_div_13 = self.div_13(ops.Cast()(opt_matmul_12, dst_type2), ops.Cast()(self.div_13_w, dst_type2)) | |||||
| opt_add_14 = self.add_14(opt_div_13, x0) | |||||
| opt_softmax_15 = self.softmax_15(opt_add_14) | |||||
| opt_matmul_16 = self.matmul_16(ops.Cast()(opt_softmax_15, dst_type), ops.Cast()(opt_transpose_11, dst_type)) | |||||
| opt_transpose_17 = self.transpose_17(ops.Cast()(opt_matmul_16, dst_type2), (0, 2, 1, 3)) | |||||
| opt_matmul_18 = self.matmul_18(ops.Cast()(opt_transpose_17, dst_type).view(self.batch_size * 512, -1), | |||||
| ops.Cast()(self.matmul_18_weight, dst_type).view(-1, 4096))\ | |||||
| .view(self.batch_size, 512, 4096) | |||||
| opt_add_19 = self.add_19(ops.Cast()(opt_matmul_18, dst_type2), self.add_19_bias) | |||||
| return opt_add_19 | |||||
| class NewGeLU(nn.Cell): | |||||
| """Gelu layer""" | |||||
| def __init__(self): | |||||
| """init function""" | |||||
| super(NewGeLU, self).__init__() | |||||
| self.mul_0 = P.Mul() | |||||
| self.mul_0_w = 0.5 | |||||
| self.pow_1 = P.Pow() | |||||
| self.pow_1_input_weight = 3.0 | |||||
| self.mul_2 = P.Mul() | |||||
| self.mul_2_w = 0.044714998453855515 | |||||
| self.add_3 = P.Add() | |||||
| self.mul_4 = P.Mul() | |||||
| self.mul_4_w = 0.7978845834732056 | |||||
| self.tanh_5 = nn.Tanh() | |||||
| self.add_6 = P.Add() | |||||
| self.add_6_bias = 1.0 | |||||
| self.mul_7 = P.Mul() | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_mul_0 = self.mul_0(x, self.mul_0_w) | |||||
| opt_pow_1 = self.pow_1(x, self.pow_1_input_weight) | |||||
| opt_mul_2 = self.mul_2(opt_pow_1, self.mul_2_w) | |||||
| opt_add_3 = self.add_3(x, opt_mul_2) | |||||
| opt_mul_4 = self.mul_4(opt_add_3, self.mul_4_w) | |||||
| opt_tanh_5 = self.tanh_5(opt_mul_4) | |||||
| opt_add_6 = self.add_6(opt_tanh_5, self.add_6_bias) | |||||
| opt_mul_7 = self.mul_7(opt_mul_0, opt_add_6) | |||||
| return opt_mul_7 | |||||
| class TransformerLayerWithLayerNorm(nn.Cell): | |||||
| """Transformer layer with LayerNOrm""" | |||||
| def __init__(self, batch_size, linear3_0_matmul_0_w_shape, linear3_1_matmul_0_w_shape, passthrough_w_0, | |||||
| passthrough_w_1, passthrough_w_2, passthrough_w_3, passthrough_w_4, passthrough_w_5, passthrough_w_6): | |||||
| """init function""" | |||||
| super(TransformerLayerWithLayerNorm, self).__init__() | |||||
| self.multiheadattn_0 = MultiHeadAttn(batch_size=batch_size, | |||||
| passthrough_w_0=passthrough_w_0, | |||||
| passthrough_w_1=passthrough_w_1, | |||||
| passthrough_w_2=passthrough_w_2) | |||||
| self.add_0 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm(passthrough_w_0=passthrough_w_3, passthrough_w_1=passthrough_w_4) | |||||
| self.linear3_0 = Linear(matmul_0_w_shape=linear3_0_matmul_0_w_shape, passthrough_w_0=passthrough_w_5) | |||||
| self.newgelu2_0 = NewGeLU() | |||||
| self.linear3_1 = Linear(matmul_0_w_shape=linear3_1_matmul_0_w_shape, passthrough_w_0=passthrough_w_6) | |||||
| self.add_1 = P.Add() | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| multiheadattn_0_opt = self.multiheadattn_0(x, x0) | |||||
| opt_add_0 = self.add_0(x, multiheadattn_0_opt) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_0) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| newgelu2_0_opt = self.newgelu2_0(linear3_0_opt) | |||||
| linear3_1_opt = self.linear3_1(newgelu2_0_opt) | |||||
| opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt) | |||||
| return opt_add_1 | |||||
| class Rerank_Albert(nn.Cell): | |||||
| """Albert model for rerank""" | |||||
| def __init__(self, batch_size): | |||||
| """init function""" | |||||
| super(Rerank_Albert, self).__init__() | |||||
| self.passthrough_w_0 = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_1 = Parameter(Tensor(np.random.uniform(0, 1, (128,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_2 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_3 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_4 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_5 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_6 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_7 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_8 = Parameter(Tensor(np.random.uniform(0, 1, (16384,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_9 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_10 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.passthrough_w_11 = Parameter(Tensor(np.random.uniform(0, 1, (4096,)).astype(np.float32)), name=None) | |||||
| self.expanddims_0 = P.ExpandDims() | |||||
| self.expanddims_0_axis = 1 | |||||
| self.expanddims_3 = P.ExpandDims() | |||||
| self.expanddims_3_axis = 2 | |||||
| self.cast_5 = P.Cast() | |||||
| self.cast_5_to = mstype.float32 | |||||
| self.sub_7 = P.Sub() | |||||
| self.sub_7_bias = 1.0 | |||||
| self.mul_9 = P.Mul() | |||||
| self.mul_9_w = -10000.0 | |||||
| self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30005, 128)).astype(np.float32)), | |||||
| name=None) | |||||
| self.gather_1_axis = 0 | |||||
| self.gather_1 = P.Gather() | |||||
| self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 128)).astype(np.float32)), name=None) | |||||
| self.gather_2_axis = 0 | |||||
| self.gather_2 = P.Gather() | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 512, 128)).astype(np.float32)), name=None) | |||||
| self.add_6 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm(passthrough_w_0=self.passthrough_w_0, passthrough_w_1=self.passthrough_w_1) | |||||
| self.linear3_0 = Linear(matmul_0_w_shape=(128, 4096), passthrough_w_0=self.passthrough_w_2) | |||||
| self.module34_0 = TransformerLayerWithLayerNorm(batch_size=batch_size, | |||||
| linear3_0_matmul_0_w_shape=(4096, 16384), | |||||
| linear3_1_matmul_0_w_shape=(16384, 4096), | |||||
| passthrough_w_0=self.passthrough_w_3, | |||||
| passthrough_w_1=self.passthrough_w_4, | |||||
| passthrough_w_2=self.passthrough_w_5, | |||||
| passthrough_w_3=self.passthrough_w_6, | |||||
| passthrough_w_4=self.passthrough_w_7, | |||||
| passthrough_w_5=self.passthrough_w_8, | |||||
| passthrough_w_6=self.passthrough_w_9) | |||||
| self.layernorm1_1 = LayerNorm(passthrough_w_0=self.passthrough_w_10, passthrough_w_1=self.passthrough_w_11) | |||||
| def construct(self, input_ids, attention_mask, token_type_ids): | |||||
| """construct function""" | |||||
| opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis) | |||||
| opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis) | |||||
| opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to) | |||||
| opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5) | |||||
| opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w) | |||||
| opt_gather_1_axis = self.gather_1_axis | |||||
| opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis) | |||||
| opt_gather_2_axis = self.gather_2_axis | |||||
| opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis) | |||||
| opt_add_4 = self.add_4(opt_gather_1, self.add_4_bias) | |||||
| opt_add_6 = self.add_6(opt_add_4, opt_gather_2) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_6) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| opt = self.module34_0(linear3_0_opt, opt_mul_9) | |||||
| opt = self.layernorm1_1(opt) | |||||
| for _ in range(11): | |||||
| opt = self.module34_0(opt, opt_mul_9) | |||||
| opt = self.layernorm1_1(opt) | |||||
| return opt | |||||
| @@ -155,6 +155,14 @@ def get_parse(): | |||||
| parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
| # Environment | # Environment | ||||
| parser.add_argument('--data_path', | |||||
| type=str, | |||||
| default="", | |||||
| help='data path') | |||||
| parser.add_argument('--ckpt_path', | |||||
| type=str, | |||||
| default="", | |||||
| help='ckpt path') | |||||
| parser.add_argument('--seed', type=int, default=42, | parser.add_argument('--seed', type=int, default=42, | ||||
| help="random seed for initialization") | help="random seed for initialization") | ||||
| parser.add_argument('--seq_len', type=int, default=512, | parser.add_argument('--seq_len', type=int, default=512, | ||||
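The new `--data_path` and `--ckpt_path` arguments line up with the default changes below, where the `../` prefixes are dropped from data and checkpoint defaults, and with the `*_ck_path` / `wiki_db_path` attributes read in the reader and re-ranker hunks above. Presumably the relative names are joined onto these roots before use; a hedged sketch of that resolution step (this helper is an assumption, not shown in the diff):

```python
import os

def resolve_paths(args):
    """Assumed helper: turn relative file defaults into absolute *_path attributes."""
    args.wiki_db_path = os.path.join(args.data_path, args.wiki_db_file)
    args.rerank_encoder_ck_path = os.path.join(args.ckpt_path, args.rerank_encoder_ck_file)
    args.rerank_downstream_ck_path = os.path.join(args.ckpt_path, args.rerank_downstream_ck_file)
    args.reader_encoder_ck_path = os.path.join(args.ckpt_path, args.reader_encoder_ck_file)
    args.reader_downstream_ck_path = os.path.join(args.ckpt_path, args.reader_downstream_ck_file)
    return args
```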
| @@ -179,15 +187,15 @@ def get_parse(): | |||||
| help="Set this flag if you want to calculate reader metrics") | help="Set this flag if you want to calculate reader metrics") | ||||
| parser.add_argument('--dev_gold_file', | parser.add_argument('--dev_gold_file', | ||||
| type=str, | type=str, | ||||
| default="../hotpot_dev_fullwiki_v1.json", | |||||
| default="hotpot_dev_fullwiki_v1.json", | |||||
| help='file of dev ground truth') | help='file of dev ground truth') | ||||
| parser.add_argument('--wiki_db_file', | parser.add_argument('--wiki_db_file', | ||||
| type=str, | type=str, | ||||
| default="../enwiki_offset.db", | |||||
| default="enwiki_offset.db", | |||||
| help='wiki_database_file') | help='wiki_database_file') | ||||
| parser.add_argument('--albert_model_path', | |||||
| parser.add_argument('--albert_model', | |||||
| type=str, | type=str, | ||||
| default="../albert-xxlarge/", | |||||
| default="albert-xxlarge", | |||||
| help='model path of huggingface albert-xxlarge') | help='model path of huggingface albert-xxlarge') | ||||
| # Retriever | # Retriever | ||||
| @@ -213,11 +221,11 @@ def get_parse(): | |||||
| help='file of rerank result') | help='file of rerank result') | ||||
| parser.add_argument('--rerank_encoder_ck_file', | parser.add_argument('--rerank_encoder_ck_file', | ||||
| type=str, | type=str, | ||||
| default="../rerank_albert_12.ckpt", | |||||
| default="rerank_albert.ckpt", | |||||
| help='checkpoint of rerank albert-xxlarge') | help='checkpoint of rerank albert-xxlarge') | ||||
| parser.add_argument('--rerank_downstream_ck_file', | parser.add_argument('--rerank_downstream_ck_file', | ||||
| type=str, | type=str, | ||||
| default="../rerank_downstream.ckpt", | |||||
| default="rerank_downstream.ckpt", | |||||
| help='checkpoint of rerank downstream') | help='checkpoint of rerank downstream') | ||||
| # Reader | # Reader | ||||
| @@ -233,11 +241,11 @@ def get_parse(): | |||||
| help='file of reader example') | help='file of reader example') | ||||
| parser.add_argument('--reader_encoder_ck_file', | parser.add_argument('--reader_encoder_ck_file', | ||||
| type=str, | type=str, | ||||
| default="../albert_12_layer.ckpt", | |||||
| default="reader_albert.ckpt", | |||||
| help='checkpoint of reader albert-xxlarge') | help='checkpoint of reader albert-xxlarge') | ||||
| parser.add_argument('--reader_downstream_ck_file', | parser.add_argument('--reader_downstream_ck_file', | ||||
| type=str, | type=str, | ||||
| default="../reader_downstream.ckpt", | |||||
| default="reader_downstream.ckpt", | |||||
| help='checkpoint of reader downstream') | help='checkpoint of reader downstream') | ||||
| parser.add_argument('--reader_result_file', | parser.add_argument('--reader_result_file', | ||||
| type=str, | type=str, | ||||
| @@ -274,24 +282,19 @@ def select_reader_dev_data(args): | |||||
| for _, res in tqdm(rerank_result.items(), desc="get rerank unique ids"): | for _, res in tqdm(rerank_result.items(), desc="get rerank unique ids"): | ||||
| rerank_unique_ids[res[0]] = True | rerank_unique_ids[res[0]] = True | ||||
| print(f"rerank result num is {len(rerank_unique_ids)}") | |||||
| for feature in tqdm(dev_features, desc="select rerank top1 feature"): | for feature in tqdm(dev_features, desc="select rerank top1 feature"): | ||||
| if feature.unique_id in rerank_unique_ids: | if feature.unique_id in rerank_unique_ids: | ||||
| feature_unique_ids[feature.unique_id] = True | feature_unique_ids[feature.unique_id] = True | ||||
| new_dev_features.append(feature) | new_dev_features.append(feature) | ||||
| print(f"new feature num is {len(new_dev_features)}") | |||||
| for example in tqdm(dev_examples, desc="select rerank top1 example"): | for example in tqdm(dev_examples, desc="select rerank top1 example"): | ||||
| if example.unique_id in rerank_unique_ids and example.unique_id in feature_unique_ids: | if example.unique_id in rerank_unique_ids and example.unique_id in feature_unique_ids: | ||||
| new_dev_examples.append(example) | new_dev_examples.append(example) | ||||
| print(f"new examples num is {len(new_dev_examples)}") | |||||
| print("start save new examples ......") | |||||
| with gzip.open(reader_example_file, "wb") as f: | with gzip.open(reader_example_file, "wb") as f: | ||||
| pickle.dump(new_dev_examples, f) | pickle.dump(new_dev_examples, f) | ||||
| print("start save new features ......") | |||||
| with gzip.open(reader_feature_file, "wb") as f: | with gzip.open(reader_feature_file, "wb") as f: | ||||
| pickle.dump(new_dev_features, f) | pickle.dump(new_dev_features, f) | ||||
| print("finish selecting reader data !!!") | print("finish selecting reader data !!!") | ||||
| @@ -449,7 +452,6 @@ def cal_reranker_metrics(dev_gold_file, rerank_result_file): | |||||
| for item in tqdm(gt, desc="cal pem"): | for item in tqdm(gt, desc="cal pem"): | ||||
| _id = item["_id"] | _id = item["_id"] | ||||
| if _id in rerank_result: | if _id in rerank_result: | ||||
| pred = rerank_result[_id][1] | pred = rerank_result[_id][1] | ||||
| sps = item["supporting_facts"] | sps = item["supporting_facts"] | ||||
| @@ -20,42 +20,48 @@ from mindspore import Tensor, Parameter | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| class BertLayerNorm(nn.Cell): | |||||
| """Layer norm for Bert""" | |||||
| def __init__(self, bln_weight=None, bln_bias=None, eps=1e-12): | |||||
| """init function""" | |||||
| super(BertLayerNorm, self).__init__() | |||||
| self.weight = bln_weight | |||||
| self.bias = bln_bias | |||||
| self.variance_epsilon = eps | |||||
| self.reduce_mean = P.ReduceMean(keep_dims=True) | |||||
| self.sub = P.Sub() | |||||
| self.pow = P.Pow() | |||||
| self.sqrt = P.Sqrt() | |||||
| self.div = P.Div() | |||||
| self.add = P.Add() | |||||
| self.mul = P.Mul() | |||||
| def construct(self, x): | |||||
| u = self.reduce_mean(x, -1) | |||||
| s = self.reduce_mean(self.pow(self.sub(x, u), 2.0), -1) | |||||
| x = self.div(self.sub(x, u), self.sqrt(self.add(s, self.variance_epsilon))) | |||||
| output = self.mul(self.weight, x) | |||||
| output = self.add(output, self.bias) | |||||
| return output | |||||
| class Rerank_Downstream(nn.Cell): | class Rerank_Downstream(nn.Cell): | ||||
| """Downstream model for rerank""" | """Downstream model for rerank""" | ||||
| def __init__(self): | def __init__(self): | ||||
| """init function""" | """init function""" | ||||
| super(Rerank_Downstream, self).__init__() | super(Rerank_Downstream, self).__init__() | ||||
| self.dense_0 = nn.Dense(in_channels=4096, out_channels=8192, has_bias=True) | |||||
| self.relu_1 = nn.ReLU() | |||||
| self.reducemean_2 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_3 = P.Sub() | |||||
| self.sub_4 = P.Sub() | |||||
| self.pow_5 = P.Pow() | |||||
| self.pow_5_input_weight = 2.0 | |||||
| self.reducemean_6 = P.ReduceMean(keep_dims=True) | |||||
| self.add_7 = P.Add() | |||||
| self.add_7_bias = 9.999999960041972e-13 | |||||
| self.sqrt_8 = P.Sqrt() | |||||
| self.div_9 = P.Div() | |||||
| self.mul_10 = P.Mul() | |||||
| self.mul_10_w = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None) | |||||
| self.add_11 = P.Add() | |||||
| self.add_11_bias = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None) | |||||
| self.dense_12 = nn.Dense(in_channels=8192, out_channels=2, has_bias=True) | |||||
| self.relu = nn.ReLU() | |||||
| self.linear_1 = nn.Dense(in_channels=4096, out_channels=8192, has_bias=True) | |||||
| self.linear_2 = nn.Dense(in_channels=8192, out_channels=2, has_bias=True) | |||||
| self.bln_1_weight = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None) | |||||
| self.bln_1_bias = Parameter(Tensor(np.random.uniform(0, 1, (8192,)).astype(np.float32)), name=None) | |||||
| self.bln_1 = BertLayerNorm(bln_weight=self.bln_1_weight, bln_bias=self.bln_1_bias) | |||||
| def construct(self, x): | |||||
| def construct(self, cls_emb): | |||||
| """construct function""" | """construct function""" | ||||
| opt_dense_0 = self.dense_0(x) | |||||
| opt_relu_1 = self.relu_1(opt_dense_0) | |||||
| opt_reducemean_2 = self.reducemean_2(opt_relu_1, -1) | |||||
| opt_sub_3 = self.sub_3(opt_relu_1, opt_reducemean_2) | |||||
| opt_sub_4 = self.sub_4(opt_relu_1, opt_reducemean_2) | |||||
| opt_pow_5 = self.pow_5(opt_sub_3, self.pow_5_input_weight) | |||||
| opt_reducemean_6 = self.reducemean_6(opt_pow_5, -1) | |||||
| opt_add_7 = self.add_7(opt_reducemean_6, self.add_7_bias) | |||||
| opt_sqrt_8 = self.sqrt_8(opt_add_7) | |||||
| opt_div_9 = self.div_9(opt_sub_4, opt_sqrt_8) | |||||
| opt_mul_10 = self.mul_10(self.mul_10_w, opt_div_9) | |||||
| opt_add_11 = self.add_11(opt_mul_10, self.add_11_bias) | |||||
| opt_dense_12 = self.dense_12(opt_add_11) | |||||
| return opt_dense_12 | |||||
| output = self.linear_1(cls_emb) | |||||
| output = self.relu(output) | |||||
| output = self.bln_1(output) | |||||
| output = self.linear_2(output) | |||||
| return output | |||||
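The rewritten re-ranker downstream head is now an explicit Linear → ReLU → BertLayerNorm → Linear stack, and the `BertLayerNorm` added above follows the standard formula `y = weight * (x - mean) / sqrt(var + eps) + bias` over the last axis. A small NumPy reference of that normalization (illustration only, eps matching the 1e-12 default):

```python
import numpy as np

def bert_layer_norm(x, weight, bias, eps=1e-12):
    """Reference math for the BertLayerNorm cell above (normalizes over the last axis)."""
    u = x.mean(axis=-1, keepdims=True)
    s = ((x - u) ** 2).mean(axis=-1, keepdims=True)
    return weight * (x - u) / np.sqrt(s + eps) + bias

x = np.random.randn(2, 8192).astype(np.float32)
out = bert_layer_norm(x, np.ones(8192, np.float32), np.zeros(8192, np.float32))
print(out.shape)  # (2, 8192)
```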
| @@ -16,7 +16,7 @@ | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import load_checkpoint, load_param_into_net | from mindspore import load_checkpoint, load_param_into_net | ||||
| from src.rerank_albert_xxlarge import Rerank_Albert | |||||
| from src.albert import Albert | |||||
| from src.rerank_downstream import Rerank_Downstream | from src.rerank_downstream import Rerank_Downstream | ||||
| @@ -26,15 +26,15 @@ class Reranker(nn.Cell): | |||||
| """init function""" | """init function""" | ||||
| super(Reranker, self).__init__(auto_prefix=False) | super(Reranker, self).__init__(auto_prefix=False) | ||||
| self.encoder = Rerank_Albert(batch_size) | |||||
| self.encoder = Albert(batch_size) | |||||
| param_dict = load_checkpoint(encoder_ck_file) | param_dict = load_checkpoint(encoder_ck_file) | ||||
| not_load_params_1 = load_param_into_net(self.encoder, param_dict) | not_load_params_1 = load_param_into_net(self.encoder, param_dict) | ||||
| print(f"not loaded albert: {not_load_params_1}") | |||||
| print(f"re-ranker albert not loaded params: {not_load_params_1}") | |||||
| self.no_answer_mlp = Rerank_Downstream() | self.no_answer_mlp = Rerank_Downstream() | ||||
| param_dict = load_checkpoint(downstream_ck_file) | param_dict = load_checkpoint(downstream_ck_file) | ||||
| not_load_params_2 = load_param_into_net(self.no_answer_mlp, param_dict) | not_load_params_2 = load_param_into_net(self.no_answer_mlp, param_dict) | ||||
| print(f"not loaded downstream: {not_load_params_2}") | |||||
| print(f"re-ranker downstream not loaded params: {not_load_params_2}") | |||||
| def construct(self, input_ids, attn_mask, token_type_ids): | def construct(self, input_ids, attn_mask, token_type_ids): | ||||
| """construct function""" | """construct function""" | ||||
| @@ -32,8 +32,8 @@ def rerank(args): | |||||
| """rerank function""" | """rerank function""" | ||||
| rerank_feature_file = args.rerank_feature_file | rerank_feature_file = args.rerank_feature_file | ||||
| rerank_result_file = args.rerank_result_file | rerank_result_file = args.rerank_result_file | ||||
| encoder_ck_file = args.rerank_encoder_ck_file | |||||
| downstream_ck_file = args.rerank_downstream_ck_file | |||||
| encoder_ck_file = args.rerank_encoder_ck_path | |||||
| downstream_ck_file = args.rerank_downstream_ck_path | |||||
| seed = args.seed | seed = args.seed | ||||
| seq_len = args.seq_len | seq_len = args.seq_len | ||||
| batch_size = args.rerank_batch_size | batch_size = args.rerank_batch_size | ||||
| @@ -1,302 +0,0 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Two Hop BERT. | |||||
| """ | |||||
| import numpy as np | |||||
| from mindspore import nn | |||||
| from mindspore import Tensor, Parameter | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| BATCH_SIZE = -1 | |||||
| class LayerNorm(nn.Cell): | |||||
| """layer norm""" | |||||
| def __init__(self): | |||||
| super(LayerNorm, self).__init__() | |||||
| self.reducemean_0 = P.ReduceMean(keep_dims=True) | |||||
| self.sub_1 = P.Sub() | |||||
| self.cast_2 = P.Cast() | |||||
| self.cast_2_to = mstype.float32 | |||||
| self.pow_3 = P.Pow() | |||||
| self.pow_3_input_weight = 2.0 | |||||
| self.reducemean_4 = P.ReduceMean(keep_dims=True) | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = 9.999999960041972e-13 | |||||
| self.sqrt_6 = P.Sqrt() | |||||
| self.div_7 = P.Div() | |||||
| self.mul_8 = P.Mul() | |||||
| self.mul_8_w = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.add_9 = P.Add() | |||||
| self.add_9_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_reducemean_0 = self.reducemean_0(x, -1) | |||||
| opt_sub_1 = self.sub_1(x, opt_reducemean_0) | |||||
| opt_cast_2 = self.cast_2(opt_sub_1, self.cast_2_to) | |||||
| opt_pow_3 = self.pow_3(opt_cast_2, self.pow_3_input_weight) | |||||
| opt_reducemean_4 = self.reducemean_4(opt_pow_3, -1) | |||||
| opt_add_5 = self.add_5(opt_reducemean_4, self.add_5_bias) | |||||
| opt_sqrt_6 = self.sqrt_6(opt_add_5) | |||||
| opt_div_7 = self.div_7(opt_sub_1, opt_sqrt_6) | |||||
| opt_mul_8 = self.mul_8(opt_div_7, self.mul_8_w) | |||||
| opt_add_9 = self.add_9(opt_mul_8, self.add_9_bias) | |||||
| return opt_add_9 | |||||
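The op sequence above is plain layer normalization over the last axis followed by a learned scale (`mul_8_w`) and shift (`add_9_bias`). A NumPy reference for the same computation (a hypothetical helper, not part of the source):

```python
import numpy as np

def layer_norm_reference(x, gamma, beta, eps=9.999999960041972e-13):
    """Reference for LayerNorm.construct: normalize over the last axis,
    then apply the learned scale (mul_8_w) and shift (add_9_bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta
```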
| class MultiHeadAttn(nn.Cell): | |||||
| """multi head attention layer""" | |||||
| def __init__(self): | |||||
| super(MultiHeadAttn, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0.to_float(mstype.float16) | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.matmul_1 = nn.MatMul() | |||||
| self.matmul_1.to_float(mstype.float16) | |||||
| self.matmul_1_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.matmul_2 = nn.MatMul() | |||||
| self.matmul_2.to_float(mstype.float16) | |||||
| self.matmul_2_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.add_3 = P.Add() | |||||
| self.add_3_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.add_4 = P.Add() | |||||
| self.add_4_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.add_5 = P.Add() | |||||
| self.add_5_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| self.reshape_6 = P.Reshape() | |||||
| self.reshape_6_shape = tuple([BATCH_SIZE, 448, 12, 64]) | |||||
| self.reshape_7 = P.Reshape() | |||||
| self.reshape_7_shape = tuple([BATCH_SIZE, 448, 12, 64]) | |||||
| self.reshape_8 = P.Reshape() | |||||
| self.reshape_8_shape = tuple([BATCH_SIZE, 448, 12, 64]) | |||||
| self.transpose_9 = P.Transpose() | |||||
| self.transpose_10 = P.Transpose() | |||||
| self.transpose_11 = P.Transpose() | |||||
| self.matmul_12 = nn.MatMul() | |||||
| self.matmul_12.to_float(mstype.float16) | |||||
| self.div_13 = P.Div() | |||||
| self.div_13_w = 8.0 | |||||
| self.add_14 = P.Add() | |||||
| self.softmax_15 = nn.Softmax(axis=3) | |||||
| self.matmul_16 = nn.MatMul() | |||||
| self.matmul_16.to_float(mstype.float16) | |||||
| self.transpose_17 = P.Transpose() | |||||
| self.reshape_18 = P.Reshape() | |||||
| self.reshape_18_shape = tuple([BATCH_SIZE, 448, 768]) | |||||
| self.matmul_19 = nn.MatMul() | |||||
| self.matmul_19.to_float(mstype.float16) | |||||
| self.matmul_19_w = Parameter(Tensor(np.random.uniform(0, 1, (768, 768)).astype(np.float32)), name=None) | |||||
| self.add_20 = P.Add() | |||||
| self.add_20_bias = Parameter(Tensor(np.random.uniform(0, 1, (768,)).astype(np.float32)), name=None) | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(x, self.matmul_0_w) | |||||
| opt_matmul_1 = self.matmul_1(x, self.matmul_1_w) | |||||
| opt_matmul_2 = self.matmul_2(x, self.matmul_2_w) | |||||
| opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32) | |||||
| opt_matmul_1 = P.Cast()(opt_matmul_1, mstype.float32) | |||||
| opt_matmul_2 = P.Cast()(opt_matmul_2, mstype.float32) | |||||
| opt_add_3 = self.add_3(opt_matmul_0, self.add_3_bias) | |||||
| opt_add_4 = self.add_4(opt_matmul_1, self.add_4_bias) | |||||
| opt_add_5 = self.add_5(opt_matmul_2, self.add_5_bias) | |||||
| opt_reshape_6 = self.reshape_6(opt_add_3, self.reshape_6_shape) | |||||
| opt_reshape_7 = self.reshape_7(opt_add_4, self.reshape_7_shape) | |||||
| opt_reshape_8 = self.reshape_8(opt_add_5, self.reshape_8_shape) | |||||
| opt_transpose_9 = self.transpose_9(opt_reshape_6, (0, 2, 1, 3)) | |||||
| opt_transpose_10 = self.transpose_10(opt_reshape_7, (0, 2, 3, 1)) | |||||
| opt_transpose_11 = self.transpose_11(opt_reshape_8, (0, 2, 1, 3)) | |||||
| opt_matmul_12 = self.matmul_12(opt_transpose_9, opt_transpose_10) | |||||
| opt_matmul_12 = P.Cast()(opt_matmul_12, mstype.float32) | |||||
| opt_div_13 = self.div_13(opt_matmul_12, self.div_13_w) | |||||
| opt_add_14 = self.add_14(opt_div_13, x0) | |||||
| opt_add_14 = P.Cast()(opt_add_14, mstype.float32) | |||||
| opt_softmax_15 = self.softmax_15(opt_add_14) | |||||
| opt_matmul_16 = self.matmul_16(opt_softmax_15, opt_transpose_11) | |||||
| opt_matmul_16 = P.Cast()(opt_matmul_16, mstype.float32) | |||||
| opt_transpose_17 = self.transpose_17(opt_matmul_16, (0, 2, 1, 3)) | |||||
| opt_reshape_18 = self.reshape_18(opt_transpose_17, self.reshape_18_shape) | |||||
| opt_matmul_19 = self.matmul_19(opt_reshape_18, self.matmul_19_w) | |||||
| opt_matmul_19 = P.Cast()(opt_matmul_19, mstype.float32) | |||||
| opt_add_20 = self.add_20(opt_matmul_19, self.add_20_bias) | |||||
| return opt_add_20 | |||||
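The core of `construct` above is standard scaled dot-product attention with 12 heads of size 64 over a 448-token sequence: queries, keys and values are projected and reshaped to `(batch, 12, 448, 64)`, scores are divided by sqrt(64) = 8, the additive mask `x0` is added, softmax runs over the key axis, and the heads are concatenated and re-projected to 768 dimensions. A NumPy reference for the attention step (a hypothetical helper, not part of the source):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, additive_mask):
    """q, k, v: (batch, heads, seq, head_dim); additive_mask broadcasts onto
    the (seq, seq) score matrix, matching opt_add_14 above."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1]) + additive_mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```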
| class Linear(nn.Cell): | |||||
| """linear layer""" | |||||
| def __init__(self, matmul_0_weight_shape, add_1_bias_shape): | |||||
| super(Linear, self).__init__() | |||||
| self.matmul_0 = nn.MatMul() | |||||
| self.matmul_0.to_float(mstype.float16) | |||||
| self.matmul_0_w = Parameter(Tensor(np.random.uniform(0, 1, matmul_0_weight_shape).astype(np.float32)), | |||||
| name=None) | |||||
| self.add_1 = P.Add() | |||||
| self.add_1_bias = Parameter(Tensor(np.random.uniform(0, 1, add_1_bias_shape).astype(np.float32)), name=None) | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_matmul_0 = self.matmul_0(x, self.matmul_0_w) | |||||
| opt_matmul_0 = P.Cast()(opt_matmul_0, mstype.float32) | |||||
| opt_add_1 = self.add_1(opt_matmul_0, self.add_1_bias) | |||||
| return opt_add_1 | |||||
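`Linear` is a plain affine map, y = xW + b, with the matrix multiply executed in float16 (`to_float(mstype.float16)`) and cast back to float32 before the bias add. A NumPy reference (hypothetical helper, not part of the source):

```python
import numpy as np

def linear_reference(x, w, b):
    """Mirror of Linear.construct: float16 matmul, float32 bias add."""
    y = (x.astype(np.float16) @ w.astype(np.float16)).astype(np.float32)
    return y + b
```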
| class GeLU(nn.Cell): | |||||
| """gelu layer""" | |||||
| def __init__(self): | |||||
| super(GeLU, self).__init__() | |||||
| self.div_0 = P.Div() | |||||
| self.div_0_w = 1.4142135381698608 | |||||
| self.erf_1 = P.Erf() | |||||
| self.add_2 = P.Add() | |||||
| self.add_2_bias = 1.0 | |||||
| self.mul_3 = P.Mul() | |||||
| self.mul_4 = P.Mul() | |||||
| self.mul_4_w = 0.5 | |||||
| def construct(self, x): | |||||
| """construct function""" | |||||
| opt_div_0 = self.div_0(x, self.div_0_w) | |||||
| opt_erf_1 = self.erf_1(opt_div_0) | |||||
| opt_add_2 = self.add_2(opt_erf_1, self.add_2_bias) | |||||
| opt_mul_3 = self.mul_3(x, opt_add_2) | |||||
| opt_mul_4 = self.mul_4(opt_mul_3, self.mul_4_w) | |||||
| return opt_mul_4 | |||||
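The constants above (divisor 1.4142135381698608 = sqrt(2), bias 1.0, scale 0.5) encode the exact, erf-based GELU, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))). A scalar reference (hypothetical helper, not part of the source):

```python
import math

def gelu_reference(x: float) -> float:
    """Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), matching the op chain above."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))
```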
| class TransformerLayer(nn.Cell): | |||||
| """transformer layer""" | |||||
| def __init__(self, linear3_0_matmul_0_weight_shape, linear3_0_add_1_bias_shape, linear3_1_matmul_0_weight_shape, | |||||
| linear3_1_add_1_bias_shape): | |||||
| super(TransformerLayer, self).__init__() | |||||
| self.multiheadattn_0 = MultiHeadAttn() | |||||
| self.add_0 = P.Add() | |||||
| self.layernorm1_0 = LayerNorm() | |||||
| self.linear3_0 = Linear(matmul_0_weight_shape=linear3_0_matmul_0_weight_shape, | |||||
| add_1_bias_shape=linear3_0_add_1_bias_shape) | |||||
| self.gelu1_0 = GeLU() | |||||
| self.linear3_1 = Linear(matmul_0_weight_shape=linear3_1_matmul_0_weight_shape, | |||||
| add_1_bias_shape=linear3_1_add_1_bias_shape) | |||||
| self.add_1 = P.Add() | |||||
| self.layernorm1_1 = LayerNorm() | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| multiheadattn_0_opt = self.multiheadattn_0(x, x0) | |||||
| opt_add_0 = self.add_0(multiheadattn_0_opt, x) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_0) | |||||
| linear3_0_opt = self.linear3_0(layernorm1_0_opt) | |||||
| gelu1_0_opt = self.gelu1_0(linear3_0_opt) | |||||
| linear3_1_opt = self.linear3_1(gelu1_0_opt) | |||||
| opt_add_1 = self.add_1(linear3_1_opt, layernorm1_0_opt) | |||||
| layernorm1_1_opt = self.layernorm1_1(opt_add_1) | |||||
| return layernorm1_1_opt | |||||
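Each `TransformerLayer` therefore follows the standard post-LayerNorm BERT block: a residual connection around multi-head attention, LayerNorm, then a residual connection around the 768 -> 3072 -> 768 GELU feed-forward network, and a second LayerNorm. A compact sketch of the same data flow (hypothetical helper names, not part of the source):

```python
def transformer_block_reference(x, mask, attn, ln1, ffn_in, gelu, ffn_out, ln2):
    """Post-LayerNorm ordering used in TransformerLayer.construct."""
    h = ln1(attn(x, mask) + x)                # attention + residual + LayerNorm
    return ln2(ffn_out(gelu(ffn_in(h))) + h)  # feed-forward + residual + LayerNorm
```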
| class Encoder1_4(nn.Cell): | |||||
| """encoder layer""" | |||||
| def __init__(self): | |||||
| super(Encoder1_4, self).__init__() | |||||
| self.module46_0 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), | |||||
| linear3_0_add_1_bias_shape=(3072,), | |||||
| linear3_1_matmul_0_weight_shape=(3072, 768), | |||||
| linear3_1_add_1_bias_shape=(768,)) | |||||
| self.module46_1 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), | |||||
| linear3_0_add_1_bias_shape=(3072,), | |||||
| linear3_1_matmul_0_weight_shape=(3072, 768), | |||||
| linear3_1_add_1_bias_shape=(768,)) | |||||
| self.module46_2 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), | |||||
| linear3_0_add_1_bias_shape=(3072,), | |||||
| linear3_1_matmul_0_weight_shape=(3072, 768), | |||||
| linear3_1_add_1_bias_shape=(768,)) | |||||
| self.module46_3 = TransformerLayer(linear3_0_matmul_0_weight_shape=(768, 3072), | |||||
| linear3_0_add_1_bias_shape=(3072,), | |||||
| linear3_1_matmul_0_weight_shape=(3072, 768), | |||||
| linear3_1_add_1_bias_shape=(768,)) | |||||
| def construct(self, x, x0): | |||||
| """construct function""" | |||||
| module46_0_opt = self.module46_0(x, x0) | |||||
| module46_1_opt = self.module46_1(module46_0_opt, x0) | |||||
| module46_2_opt = self.module46_2(module46_1_opt, x0) | |||||
| module46_3_opt = self.module46_3(module46_2_opt, x0) | |||||
| return module46_3_opt | |||||
| class ModelTwoHop(nn.Cell): | |||||
| """two hop layer""" | |||||
| def __init__(self): | |||||
| super(ModelTwoHop, self).__init__() | |||||
| self.expanddims_0 = P.ExpandDims() | |||||
| self.expanddims_0_axis = 1 | |||||
| self.expanddims_3 = P.ExpandDims() | |||||
| self.expanddims_3_axis = 2 | |||||
| self.cast_5 = P.Cast() | |||||
| self.cast_5_to = mstype.float32 | |||||
| self.sub_7 = P.Sub() | |||||
| self.sub_7_bias = 1.0 | |||||
| self.mul_9 = P.Mul() | |||||
| self.mul_9_w = -10000.0 | |||||
| self.gather_1_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (30522, 768)).astype(np.float32)), | |||||
| name=None) | |||||
| self.gather_1_axis = 0 | |||||
| self.gather_1 = P.Gather() | |||||
| self.gather_2_input_weight = Parameter(Tensor(np.random.uniform(0, 1, (2, 768)).astype(np.float32)), name=None) | |||||
| self.gather_2_axis = 0 | |||||
| self.gather_2 = P.Gather() | |||||
| self.add_4 = P.Add() | |||||
| self.add_6 = P.Add() | |||||
| self.add_6_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 448, 768)).astype(np.float32)), name=None) | |||||
| self.layernorm1_0 = LayerNorm() | |||||
| self.module50_0 = Encoder1_4() | |||||
| self.module50_1 = Encoder1_4() | |||||
| self.module50_2 = Encoder1_4() | |||||
| self.gather_643_input_weight = Tensor(np.array(0)) | |||||
| self.gather_643_axis = 1 | |||||
| self.gather_643 = P.Gather() | |||||
| self.dense_644 = nn.Dense(in_channels=768, out_channels=768, has_bias=True) | |||||
| self.tanh_645 = nn.Tanh() | |||||
| def construct(self, input_ids, token_type_ids, attention_mask): | |||||
| """construct function""" | |||||
| input_ids = P.Cast()(input_ids, mstype.int32) | |||||
| token_type_ids = P.Cast()(token_type_ids, mstype.int32) | |||||
| attention_mask = P.Cast()(attention_mask, mstype.int32) | |||||
| opt_expanddims_0 = self.expanddims_0(attention_mask, self.expanddims_0_axis) | |||||
| opt_expanddims_3 = self.expanddims_3(opt_expanddims_0, self.expanddims_3_axis) | |||||
| opt_cast_5 = self.cast_5(opt_expanddims_3, self.cast_5_to) | |||||
| opt_sub_7 = self.sub_7(self.sub_7_bias, opt_cast_5) | |||||
| opt_mul_9 = self.mul_9(opt_sub_7, self.mul_9_w) | |||||
| opt_gather_1_axis = self.gather_1_axis | |||||
| opt_gather_1 = self.gather_1(self.gather_1_input_weight, input_ids, opt_gather_1_axis) | |||||
| opt_gather_2_axis = self.gather_2_axis | |||||
| opt_gather_2 = self.gather_2(self.gather_2_input_weight, token_type_ids, opt_gather_2_axis) | |||||
| opt_add_4 = self.add_4(opt_gather_1, opt_gather_2) | |||||
| opt_add_6 = self.add_6(opt_add_4, self.add_6_bias) | |||||
| layernorm1_0_opt = self.layernorm1_0(opt_add_6) | |||||
| module50_0_opt = self.module50_0(layernorm1_0_opt, opt_mul_9) | |||||
| module50_1_opt = self.module50_1(module50_0_opt, opt_mul_9) | |||||
| module50_2_opt = self.module50_2(module50_1_opt, opt_mul_9) | |||||
| opt_gather_643_axis = self.gather_643_axis | |||||
| opt_gather_643 = self.gather_643(module50_2_opt, self.gather_643_input_weight, opt_gather_643_axis) | |||||
| opt_dense_644 = self.dense_644(opt_gather_643) | |||||
| opt_tanh_645 = self.tanh_645(opt_dense_644) | |||||
| return opt_tanh_645 | |||||
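Putting the pieces together, `ModelTwoHop` is a 12-layer BERT-base style encoder: word (30522 x 768), token-type (2 x 768) and position (1 x 448 x 768) embeddings are summed and normalized, the attention mask is turned into an additive mask via (1 - mask) * -10000, three `Encoder1_4` stacks of four `TransformerLayer`s each are applied, and the first token is pooled through a Dense + Tanh head. A hedged smoke test with random inputs (shapes follow the constants in the file; this snippet is illustrative only and not part of the source):

```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

# Assumes ModelTwoHop defined above is importable in the current scope.
net = ModelTwoHop()
input_ids = Tensor(np.random.randint(0, 30522, (1, 448)), mstype.int32)
token_type_ids = Tensor(np.zeros((1, 448)), mstype.int32)
attention_mask = Tensor(np.ones((1, 448)), mstype.int32)
pooled = net(input_ids, token_type_ids, attention_mask)  # (1, 768) pooled [CLS] vector
print(pooled.shape)
```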