Merge pull request !2955 from linqingke/mass
@@ -40,7 +40,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi
 parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
 args_opt = parser.parse_args()
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
 def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
     """FasterRcnn evaluation."""
@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.common.initializer import initializer
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 def bias_init_zeros(shape):
     """Bias init method."""
@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore import context
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 class Proposal(nn.Cell):
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
 from mindspore import context
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 def weight_init_ones(shape):
@@ -52,7 +52,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
 parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
 args_opt = parser.parse_args()
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
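These hunks only drop the `save_graphs=True` debug flag; the execution mode and device target are unchanged. If the IR graph dumps are still needed for local debugging, they can be re-enabled in a private build, for example (sketch only, not part of this PR; `save_graphs_path` is assumed to be supported by your MindSpore version):

# Debug-only: re-enable IR graph dumping locally instead of keeping it in the repo.
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    save_graphs=True,                # dump IR graphs while debugging
                    save_graphs_path="./ir_graphs")  # assumed option; omit if unsupported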
@@ -15,15 +15,13 @@
 """Evaluation api."""
 import argparse
 import pickle
-import numpy as np
 from mindspore.common import dtype as mstype
 from config import TransformerConfig
-from src.transformer import infer
-from src.utils import ngram_ppl
+from src.transformer import infer, infer_ppl
 from src.utils import Dictionary
-from src.utils import rouge
+from src.utils import get_score
 parser = argparse.ArgumentParser(description='Evaluation MASS.')
 parser.add_argument("--config", type=str, required=True,
@@ -32,6 +30,8 @@ parser.add_argument("--vocab", type=str, required=True,
                     help="Vocabulary to use.")
 parser.add_argument("--output", type=str, required=True,
                     help="Result file path.")
+parser.add_argument("--metric", type=str, default='rouge',
+                    help='Set eval method.')
 def get_config(config):
@@ -45,31 +45,15 @@ if __name__ == '__main__':
     args, _ = parser.parse_known_args()
     vocab = Dictionary.load_from_persisted_dict(args.vocab)
     _config = get_config(args.config)
-    result = infer(_config)
+    if args.metric == 'rouge':
+        result = infer(_config)
+    else:
+        result = infer_ppl(_config)
     with open(args.output, "wb") as f:
         pickle.dump(result, f, 1)
-    ppl_score = 0.
-    preds = []
-    tgts = []
-    _count = 0
-    for sample in result:
-        sentence_prob = np.array(sample['prediction_prob'], dtype=np.float32)
-        sentence_prob = sentence_prob[:, 1:]
-        _ppl = []
-        for path in sentence_prob:
-            _ppl.append(ngram_ppl(path, log_softmax=True))
-        ppl = np.min(_ppl)
-        preds.append(' '.join([vocab[t] for t in sample['prediction']]))
-        tgts.append(' '.join([vocab[t] for t in sample['target']]))
-        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
-        print(f" | target: {tgts[-1]}")
-        print(f" | prediction: {preds[-1]}")
-        print(f" | ppl: {ppl}.")
-        if np.isinf(ppl):
-            continue
-        ppl_score += ppl
-        _count += 1
-    print(f" | PPL={ppl_score / _count}.")
-    rouge(preds, tgts)
+    # get score by given metric
+    score = get_score(result, vocab, metric=args.metric)
+    print(score)
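With this change, eval.py no longer computes ROUGE and PPL inline; it runs the matching inference path, pickles the raw result, and delegates scoring to get_score. A minimal sketch (not part of this PR; the file names are placeholders) of re-scoring a saved result offline:

# Re-score a pickled inference result without re-running the model (sketch only).
import pickle
from src.utils import Dictionary, get_score

with open("output.bin", "rb") as f:               # file written via --output
    result = pickle.load(f)

vocab = Dictionary.load_from_persisted_dict("all.dict.bin")  # placeholder vocab path
print(get_score(result, vocab, metric='rouge'))   # or metric='ppl' for an infer_ppl result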
@@ -18,7 +18,7 @@ export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1
-options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
+options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"`
 eval set -- "$options"
 echo $options
@@ -35,6 +35,7 @@ echo_help()
     echo " -c --config set the configuration file"
     echo " -o --output set the output file of inference"
     echo " -v --vocab set the vocabulary"
+    echo " -m --metric set the metric"
 }
 set_hccl_json()
@@ -43,8 +44,8 @@ set_hccl_json()
     do
         if [[ "$1" == "-j" || "$1" == "--hccl_json" ]]
         then
-            export MINDSPORE_HCCL_CONFIG_PATH=$2 #/data/wsc/hccl_2p_01.json
-            export RANK_TABLE_FILE=$2 #/data/wsc/hccl_2p_01.json
+            export MINDSPORE_HCCL_CONFIG_PATH=$2
+            export RANK_TABLE_FILE=$2
             break
         fi
         shift
@@ -119,6 +120,11 @@ do
         vocab=$2
         shift 2
         ;;
+    -m|--metric)
+        echo "metric";
+        metric=$2
+        shift 2
+        ;;
     --)
         shift
         break
@@ -163,7 +169,7 @@ do
         python train.py --config ${configurations##*/} >>log.log 2>&1 &
     elif [ "$task" == "infer" ]
     then
-        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} >>log_infer.log 2>&1 &
+        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 &
     fi
     cd ../
 done
@@ -19,10 +19,11 @@ from .decoder import TransformerDecoder
 from .beam_search import BeamSearchDecoder
 from .transformer_for_train import TransformerTraining, LabelSmoothedCrossEntropyCriterion, \
     TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
-from .infer_mass import infer
+from .infer_mass import infer, infer_ppl
 __all__ = [
     "infer",
+    "infer_ppl",
     "TransformerTraining",
     "LabelSmoothedCrossEntropyCriterion",
     "TransformerTrainOneStepWithLossScaleCell",
@@ -41,7 +41,7 @@ class EmbeddingLookup(nn.Cell):
         self.vocab_size = vocab_size
         self.use_one_hot_embeddings = use_one_hot_embeddings
-        init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim])
+        init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim]).astype(np.float32)
         # 0 is Padding index, thus init it as 0.
         init_weight[0, :] = 0
         self.embedding_table = Parameter(Tensor(init_weight),
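The added .astype(np.float32) matters because np.random.normal returns a float64 array by default, so without the cast the embedding table would be built from a float64 tensor. A one-line check of that default:

# np.random.normal defaults to float64, hence the explicit cast above.
import numpy as np
assert np.random.normal(0, 1, size=[2, 3]).dtype == np.float64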
@@ -17,13 +17,16 @@ import time
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore import context
 from src.dataset import load_dataset
 from .transformer_for_infer import TransformerInferModel
+from .transformer_for_train import TransformerTraining
 from ..utils.load_weights import load_infer_weights
 context.set_context(
@@ -156,3 +159,129 @@ def infer(config):
                                 shuffle=False) if config.test_dataset else None
     prediction = transformer_infer(config, eval_dataset)
     return prediction
+
+
+class TransformerInferPPLCell(nn.Cell):
+    """
+    Encapsulation class of the transformer network for PPL inference.
+
+    Args:
+        config (TransformerConfig): Config.
+
+    Returns:
+        Tuple[Tensor, Tensor], predicted log probs and label lengths.
+    """
+    def __init__(self, config):
+        super(TransformerInferPPLCell, self).__init__()
+        self.transformer = TransformerTraining(config, is_training=False, use_one_hot_embeddings=False)
+        self.batch_size = config.batch_size
+        self.vocab_size = config.vocab_size
+        self.one_hot = P.OneHot()
+        self.on_value = Tensor(float(1), mstype.float32)
+        self.off_value = Tensor(float(0), mstype.float32)
+        self.reduce_sum = P.ReduceSum()
+        self.reshape = P.Reshape()
+        self.cast = P.Cast()
+        self.flat_shape = (config.batch_size * config.seq_length,)
+        self.batch_shape = (config.batch_size, config.seq_length)
+        self.last_idx = (-1,)
+
+    def construct(self,
+                  source_ids,
+                  source_mask,
+                  target_ids,
+                  target_mask,
+                  label_ids,
+                  label_mask):
+        """Defines the computation performed."""
+        predicted_log_probs = self.transformer(source_ids, source_mask, target_ids, target_mask)
+
+        label_ids = self.reshape(label_ids, self.flat_shape)
+        label_mask = self.cast(label_mask, mstype.float32)
+        one_hot_labels = self.one_hot(label_ids, self.vocab_size, self.on_value, self.off_value)
+
+        label_log_probs = self.reduce_sum(predicted_log_probs * one_hot_labels, self.last_idx)
+        label_log_probs = self.reshape(label_log_probs, self.batch_shape)
+        log_probs = label_log_probs * label_mask
+        lengths = self.reduce_sum(label_mask, self.last_idx)
+
+        return log_probs, lengths
+
+
+def transformer_infer_ppl(config, dataset):
+    """
+    Run PPL inference with the Transformer.
+
+    Args:
+        config (TransformerConfig): Config.
+        dataset (Dataset): Dataset.
+
+    Returns:
+        List[Dict], prediction; each example has 4 keys: "source",
+        "target", "log_prob" and "length".
+    """
+    tfm_infer = TransformerInferPPLCell(config=config)
+    tfm_infer.init_parameters_data()
+
+    parameter_dict = load_checkpoint(config.existed_ckpt)
+    load_param_into_net(tfm_infer, parameter_dict)
+
+    model = Model(tfm_infer)
+
+    log_probs = []
+    lengths = []
+    source_sentences = []
+    target_sentences = []
+    for batch in dataset.create_dict_iterator():
+        source_sentences.append(batch["source_eos_ids"])
+        target_sentences.append(batch["target_eos_ids"])
+
+        source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
+        source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
+        target_ids = Tensor(batch["target_sos_ids"], mstype.int32)
+        target_mask = Tensor(batch["target_sos_mask"], mstype.int32)
+        label_ids = Tensor(batch["target_eos_ids"], mstype.int32)
+        label_mask = Tensor(batch["target_eos_mask"], mstype.int32)
+
+        start_time = time.time()
+        log_prob, length = model.predict(source_ids, source_mask, target_ids, target_mask, label_ids, label_mask)
+        print(f" | Batch size: {config.batch_size}, "
+              f"Time cost: {time.time() - start_time}.")
+
+        log_probs.append(log_prob.asnumpy())
+        lengths.append(length.asnumpy())
+
+    output = []
+    for inputs, ref, log_prob, length in zip(source_sentences,
+                                             target_sentences,
+                                             log_probs,
+                                             lengths):
+        for i in range(config.batch_size):
+            example = {
+                "source": inputs[i].tolist(),
+                "target": ref[i].tolist(),
+                "log_prob": log_prob[i].tolist(),
+                "length": length[i]
+            }
+            output.append(example)
+
+    return output
+
+
+def infer_ppl(config):
+    """
+    Transformer PPL inference api.
+
+    Args:
+        config (TransformerConfig): Config.
+
+    Returns:
+        list, inference results with log probs and lengths for each example.
+    """
+    eval_dataset = load_dataset(data_files=config.test_dataset,
+                                batch_size=config.batch_size,
+                                epoch_count=1,
+                                sink_mode=config.dataset_sink_mode,
+                                shuffle=False) if config.test_dataset else None
+    prediction = transformer_infer_ppl(config, eval_dataset)
+
+    return prediction
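For readers following the new construct method, an equivalent numpy sketch of the masked one-hot reduction it performs (illustration only; the real cell runs these as MindSpore ops on device):

# Numpy illustration of what TransformerInferPPLCell.construct computes.
import numpy as np

def label_log_probs_np(predicted_log_probs, label_ids, label_mask, vocab_size):
    """predicted_log_probs: (batch*seq_len, vocab_size) log-softmax scores."""
    batch, seq_len = label_mask.shape
    one_hot = np.eye(vocab_size, dtype=np.float32)[label_ids.reshape(-1)]  # OneHot
    picked = np.sum(predicted_log_probs * one_hot, axis=-1)                # ReduceSum over vocab
    picked = picked.reshape(batch, seq_len)                                # Reshape back to (batch, seq)
    log_probs = picked * label_mask                                        # zero out padding positions
    lengths = label_mask.sum(axis=-1)                                      # true label lengths
    return log_probs, lengths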
@@ -20,6 +20,7 @@ from .loss_monitor import LossCallBack
 from .byte_pair_encoding import bpe_encode
 from .initializer import zero_weight, one_weight, normal_weight, weight_variable
 from .rouge_score import rouge
+from .eval_score import get_score
 __all__ = [
     "Dictionary",
@@ -31,5 +32,6 @@ __all__ = [
     "one_weight",
     "zero_weight",
     "normal_weight",
-    "weight_variable"
+    "weight_variable",
+    "get_score"
 ]
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Get score by given metric."""
+from .ppl_score import ngram_ppl
+from .rouge_score import rouge
+
+
+def get_ppl_score(result):
+    """
+    Calculate Perplexity (PPL) score.
+
+    Args:
+        result (List[Dict]): Inference output; each example has 4 keys:
+            "source", "target", "log_prob" and "length".
+
+    Returns:
+        Float, ppl score.
+    """
+    log_probs = []
+    total_length = 0
+    for sample in result:
+        log_prob = sample['log_prob']
+        length = sample['length']
+        log_probs.extend(log_prob)
+        total_length += length
+
+        print(f" | log_prob: {log_prob}")
+        print(f" | length: {length}")
+
+    ppl = ngram_ppl(log_probs, total_length, log_softmax=True)
+    print(f" | final PPL={ppl}.")
+    return ppl
+
+
+def get_rouge_score(result, vocab):
+    """
+    Calculate ROUGE score.
+
+    Args:
+        result (List[Dict]): Inference output; each example has 4 keys:
+            "source", "target", "prediction" and "prediction_prob".
+        vocab (Dictionary): Vocabulary instance.
+
+    Returns:
+        Str, rouge score.
+    """
+    predictions = []
+    targets = []
+    for sample in result:
+        predictions.append(' '.join([vocab[t] for t in sample['prediction']]))
+        targets.append(' '.join([vocab[t] for t in sample['target']]))
+        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
+        print(f" | target: {targets[-1]}")
+
+    return rouge(predictions, targets)
+
+
+def get_score(result, vocab=None, metric='rouge'):
+    """
+    Get eval score by the given metric.
+
+    Args:
+        result (List[Dict]): Inference output.
+        vocab (Dictionary): Vocabulary instance, only needed for 'rouge'.
+        metric (str): Metric to use, 'rouge' or 'ppl'. Default: 'rouge'.
+
+    Returns:
+        Str, score.
+    """
+    score = None
+    if metric == 'rouge':
+        score = get_rouge_score(result, vocab)
+    elif metric == 'ppl':
+        score = get_ppl_score(result)
+    else:
+        print(" | metric not in (rouge, ppl).")
+
+    return score
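A small, self-contained call of the new helper for the 'ppl' branch (the token ids and log probs below are toy values, not MASS data):

# Toy invocation of get_score; vocab is only required for metric='rouge'.
from src.utils import get_score

result = [
    {"source": [5, 6], "target": [5, 6, 2], "log_prob": [-1.2, -0.7, -0.1], "length": 3},
    {"source": [7], "target": [7, 2], "log_prob": [-0.9, -0.3], "length": 2},
]
print(get_score(result, metric='ppl'))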
@@ -17,10 +17,7 @@ from typing import Union
 import numpy as np
-NINF = -1.0 * 1e9
-def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = np.e):
+def ngram_ppl(prob: Union[np.ndarray, list], length: int, log_softmax=False, index: float = np.e):
     """
     Calculate Perplexity(PPL) score under N-gram language model.
@@ -39,7 +36,8 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
     Returns:
         float, ppl score.
     """
-    eps = 1e-8
+    if not length:
+        return np.inf
     if not isinstance(prob, (np.ndarray, list)):
         raise TypeError("`prob` must be type of list or np.ndarray.")
     if not isinstance(prob, np.ndarray):
@@ -47,18 +45,17 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
     if prob.shape[0] == 0:
         raise ValueError("`prob` length must greater than 0.")
-    p = 1.0
-    sen_len = 0
-    for t in range(prob.shape[0]):
-        s = prob[t]
-        if s <= NINF:
-            break
-        if log_softmax:
-            s = np.power(index, s)
-        p *= (1 / (s + eps))
-        sen_len += 1
-    if sen_len == 0:
-        return np.inf
-    return pow(p, 1 / sen_len)
+    print(f'length:{length}, log_prob:{prob}')
+    if log_softmax:
+        prob = np.sum(prob) / length
+        ppl = 1. / np.power(index, prob)
+        print(f'avg log prob:{prob}')
+    else:
+        p = 1.
+        for i in range(prob.shape[0]):
+            p *= (1. / prob[i])
+        ppl = pow(p, 1 / length)
+    print(f'ppl val:{ppl}')
+    return ppl
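With log_softmax=True, the reworked ngram_ppl now computes a corpus-level perplexity from the summed token log probabilities, PPL = index ** (-(1/length) * sum(log_prob)), instead of the per-sentence PPLs that eval.py used to average. A quick numeric sanity check (toy values; assumes the module path src/utils/ppl_score.py implied by the imports above):

# Sanity check of the corpus-level PPL formula with arbitrary log probs.
import numpy as np
from src.utils.ppl_score import ngram_ppl

log_probs = np.array([-1.0, -2.0, -3.0], dtype=np.float32)  # per-token natural-log probabilities
ppl = ngram_ppl(log_probs, 3, log_softmax=True)
assert np.isclose(ppl, np.exp(2.0))  # exp(-mean log prob) = e**2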