Merge pull request !2955 from linqingke/mass
@@ -40,7 +40,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi
 parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
 args_opt = parser.parse_args()
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
 def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
     """FasterRcnn evaluation."""
@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.common.initializer import initializer
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 def bias_init_zeros(shape):
     """Bias init method."""
@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore import context
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 class Proposal(nn.Cell):
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
 from mindspore import context
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 def weight_init_ones(shape):
@@ -52,7 +52,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
 parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
 args_opt = parser.parse_args()
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
@@ -15,15 +15,13 @@
 """Evaluation api."""
 import argparse
 import pickle
-import numpy as np
 from mindspore.common import dtype as mstype
 from config import TransformerConfig
-from src.transformer import infer
-from src.utils import ngram_ppl
+from src.transformer import infer, infer_ppl
 from src.utils import Dictionary
-from src.utils import rouge
+from src.utils import get_score
 parser = argparse.ArgumentParser(description='Evaluation MASS.')
 parser.add_argument("--config", type=str, required=True,
@@ -32,6 +30,8 @@ parser.add_argument("--vocab", type=str, required=True,
                     help="Vocabulary to use.")
 parser.add_argument("--output", type=str, required=True,
                     help="Result file path.")
+parser.add_argument("--metric", type=str, default='rouge',
+                    help='Set eval method.')
 def get_config(config):
@@ -45,31 +45,15 @@ if __name__ == '__main__':
     args, _ = parser.parse_known_args()
     vocab = Dictionary.load_from_persisted_dict(args.vocab)
     _config = get_config(args.config)
-    result = infer(_config)
+    if args.metric == 'rouge':
+        result = infer(_config)
+    else:
+        result = infer_ppl(_config)
     with open(args.output, "wb") as f:
         pickle.dump(result, f, 1)
-    ppl_score = 0.
-    preds = []
-    tgts = []
-    _count = 0
-    for sample in result:
-        sentence_prob = np.array(sample['prediction_prob'], dtype=np.float32)
-        sentence_prob = sentence_prob[:, 1:]
-        _ppl = []
-        for path in sentence_prob:
-            _ppl.append(ngram_ppl(path, log_softmax=True))
-        ppl = np.min(_ppl)
-        preds.append(' '.join([vocab[t] for t in sample['prediction']]))
-        tgts.append(' '.join([vocab[t] for t in sample['target']]))
-        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
-        print(f" | target: {tgts[-1]}")
-        print(f" | prediction: {preds[-1]}")
-        print(f" | ppl: {ppl}.")
-        if np.isinf(ppl):
-            continue
-        ppl_score += ppl
-        _count += 1
-    print(f" | PPL={ppl_score / _count}.")
-    rouge(preds, tgts)
+    # get score by given metric
+    score = get_score(result, vocab, metric=args.metric)
+    print(score)
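With this change eval.py just pickles whatever infer/infer_ppl returns and hands it to get_score, so the same result can be re-scored offline. A minimal sketch, not part of the diff; the file names output.pkl and vocab.dict are assumed placeholders for whatever you passed to --output and --vocab:

import pickle

from src.utils import Dictionary
from src.utils import get_score

# Hypothetical paths; reuse the files produced by the eval.py run above.
vocab = Dictionary.load_from_persisted_dict("vocab.dict")
with open("output.pkl", "rb") as f:
    result = pickle.load(f)

# The metric must match the one used at inference time: 'rouge' results carry
# "prediction"/"prediction_prob", 'ppl' results carry "log_prob"/"length".
print(get_score(result, vocab, metric='rouge'))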
@@ -18,7 +18,7 @@ export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1
-options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
+options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"`
 eval set -- "$options"
 echo $options
@@ -35,6 +35,7 @@ echo_help()
     echo " -c --config set the configuration file"
     echo " -o --output set the output file of inference"
     echo " -v --vocab set the vocabulary"
+    echo " -m --metric set the metric"
 }
 set_hccl_json()
@@ -43,8 +44,8 @@ set_hccl_json()
     do
         if [[ "$1" == "-j" || "$1" == "--hccl_json" ]]
         then
-            export MINDSPORE_HCCL_CONFIG_PATH=$2 #/data/wsc/hccl_2p_01.json
-            export RANK_TABLE_FILE=$2 #/data/wsc/hccl_2p_01.json
+            export MINDSPORE_HCCL_CONFIG_PATH=$2
+            export RANK_TABLE_FILE=$2
             break
         fi
         shift
@@ -119,6 +120,11 @@ do
             vocab=$2
             shift 2
             ;;
+        -m|--metric)
+            echo "metric";
+            metric=$2
+            shift 2
+            ;;
         --)
             shift
             break
@@ -163,7 +169,7 @@ do
         python train.py --config ${configurations##*/} >>log.log 2>&1 &
     elif [ "$task" == "infer" ]
     then
-        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} >>log_infer.log 2>&1 &
+        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 &
     fi
     cd ../
 done
@@ -19,10 +19,11 @@ from .decoder import TransformerDecoder
 from .beam_search import BeamSearchDecoder
 from .transformer_for_train import TransformerTraining, LabelSmoothedCrossEntropyCriterion, \
     TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
-from .infer_mass import infer
+from .infer_mass import infer, infer_ppl
 __all__ = [
     "infer",
+    "infer_ppl",
     "TransformerTraining",
     "LabelSmoothedCrossEntropyCriterion",
    "TransformerTrainOneStepWithLossScaleCell",
@@ -41,7 +41,7 @@ class EmbeddingLookup(nn.Cell):
         self.vocab_size = vocab_size
         self.use_one_hot_embeddings = use_one_hot_embeddings
-        init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim])
+        init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim]).astype(np.float32)
         # 0 is Padding index, thus init it as 0.
         init_weight[0, :] = 0
         self.embedding_table = Parameter(Tensor(init_weight),
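The only change here is an explicit cast of the embedding table to float32: np.random.normal returns float64 by default, while the rest of the network runs in float32. A small numpy illustration of the dtype difference, with assumed toy sizes:

import numpy as np

embed_dim, vocab_size = 512, 8   # assumed toy sizes, not from the PR
w = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim])
print(w.dtype)                      # float64 by default
print(w.astype(np.float32).dtype)   # float32, matching the network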
@@ -17,13 +17,16 @@ import time
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore import context
 from src.dataset import load_dataset
 from .transformer_for_infer import TransformerInferModel
+from .transformer_for_train import TransformerTraining
 from ..utils.load_weights import load_infer_weights
 context.set_context(
@@ -156,3 +159,129 @@ def infer(config):
                                 shuffle=False) if config.test_dataset else None
     prediction = transformer_infer(config, eval_dataset)
     return prediction
+
+
+class TransformerInferPPLCell(nn.Cell):
+    """
+    Encapsulation class of transformer network infer for PPL.
+
+    Args:
+        config(TransformerConfig): Config.
+
+    Returns:
+        Tuple[Tensor, Tensor], predicted log prob and label lengths.
+    """
+    def __init__(self, config):
+        super(TransformerInferPPLCell, self).__init__()
+        self.transformer = TransformerTraining(config, is_training=False, use_one_hot_embeddings=False)
+        self.batch_size = config.batch_size
+        self.vocab_size = config.vocab_size
+        self.one_hot = P.OneHot()
+        self.on_value = Tensor(float(1), mstype.float32)
+        self.off_value = Tensor(float(0), mstype.float32)
+        self.reduce_sum = P.ReduceSum()
+        self.reshape = P.Reshape()
+        self.cast = P.Cast()
+        self.flat_shape = (config.batch_size * config.seq_length,)
+        self.batch_shape = (config.batch_size, config.seq_length)
+        self.last_idx = (-1,)
+
+    def construct(self,
+                  source_ids,
+                  source_mask,
+                  target_ids,
+                  target_mask,
+                  label_ids,
+                  label_mask):
+        """Defines the computation performed."""
+        predicted_log_probs = self.transformer(source_ids, source_mask, target_ids, target_mask)
+        label_ids = self.reshape(label_ids, self.flat_shape)
+        label_mask = self.cast(label_mask, mstype.float32)
+        one_hot_labels = self.one_hot(label_ids, self.vocab_size, self.on_value, self.off_value)
+        label_log_probs = self.reduce_sum(predicted_log_probs * one_hot_labels, self.last_idx)
+        label_log_probs = self.reshape(label_log_probs, self.batch_shape)
+        log_probs = label_log_probs * label_mask
+        lengths = self.reduce_sum(label_mask, self.last_idx)
+
+        return log_probs, lengths
+
+
+def transformer_infer_ppl(config, dataset):
+    """
+    Run infer with Transformer for PPL.
+
+    Args:
+        config (TransformerConfig): Config.
+        dataset (Dataset): Dataset.
+
+    Returns:
+        List[Dict], prediction, each example has 4 keys, "source",
+        "target", "log_prob" and "length".
+    """
+    tfm_infer = TransformerInferPPLCell(config=config)
+    tfm_infer.init_parameters_data()
+
+    parameter_dict = load_checkpoint(config.existed_ckpt)
+    load_param_into_net(tfm_infer, parameter_dict)
+    model = Model(tfm_infer)
+
+    log_probs = []
+    lengths = []
+    source_sentences = []
+    target_sentences = []
+    for batch in dataset.create_dict_iterator():
+        source_sentences.append(batch["source_eos_ids"])
+        target_sentences.append(batch["target_eos_ids"])
+
+        source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
+        source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
+        target_ids = Tensor(batch["target_sos_ids"], mstype.int32)
+        target_mask = Tensor(batch["target_sos_mask"], mstype.int32)
+        label_ids = Tensor(batch["target_eos_ids"], mstype.int32)
+        label_mask = Tensor(batch["target_eos_mask"], mstype.int32)
+
+        start_time = time.time()
+        log_prob, length = model.predict(source_ids, source_mask, target_ids, target_mask, label_ids, label_mask)
+        print(f" | Batch size: {config.batch_size}, "
+              f"Time cost: {time.time() - start_time}.")
+
+        log_probs.append(log_prob.asnumpy())
+        lengths.append(length.asnumpy())
+
+    output = []
+    for inputs, ref, log_prob, length in zip(source_sentences,
+                                             target_sentences,
+                                             log_probs,
+                                             lengths):
+        for i in range(config.batch_size):
+            example = {
+                "source": inputs[i].tolist(),
+                "target": ref[i].tolist(),
+                "log_prob": log_prob[i].tolist(),
+                "length": length[i]
+            }
+            output.append(example)
+
+    return output
+
+
+def infer_ppl(config):
+    """
+    Transformer infer PPL api.
+
+    Args:
+        config (TransformerConfig): Config.
+
+    Returns:
+        list, results with per-token log probabilities and label lengths.
+    """
+    eval_dataset = load_dataset(data_files=config.test_dataset,
+                                batch_size=config.batch_size,
+                                epoch_count=1,
+                                sink_mode=config.dataset_sink_mode,
+                                shuffle=False) if config.test_dataset else None
+    prediction = transformer_infer_ppl(config, eval_dataset)
+    return prediction
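The heart of TransformerInferPPLCell.construct is the masking step: padded label positions are zeroed out so they contribute nothing to the summed log probability, and the mask sum gives the real label length. A toy numpy illustration of the same arithmetic, with assumed values, not part of the diff:

import numpy as np

# One example of seq_length 4 where the last position is padding.
label_log_probs = np.array([[-2.1, -0.7, -1.3, -9.0]], dtype=np.float32)
label_mask = np.array([[1.0, 1.0, 1.0, 0.0]], dtype=np.float32)

log_probs = label_log_probs * label_mask   # padding now contributes 0.0
lengths = label_mask.sum(axis=-1)          # [3.0] real tokens in the example
print(log_probs, lengths)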
@@ -20,6 +20,7 @@ from .loss_monitor import LossCallBack
 from .byte_pair_encoding import bpe_encode
 from .initializer import zero_weight, one_weight, normal_weight, weight_variable
 from .rouge_score import rouge
+from .eval_score import get_score
 __all__ = [
     "Dictionary",
@@ -31,5 +32,6 @@ __all__ = [
     "one_weight",
     "zero_weight",
     "normal_weight",
-    "weight_variable"
+    "weight_variable",
+    "get_score"
 ]
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Get score by given metric."""
+from .ppl_score import ngram_ppl
+from .rouge_score import rouge
+
+
+def get_ppl_score(result):
+    """
+    Calculate Perplexity(PPL) score.
+
+    Args:
+        result (List[Dict]): Prediction; each example has 4 keys, "source",
+            "target", "log_prob" and "length".
+
+    Returns:
+        Float, ppl score.
+    """
+    log_probs = []
+    total_length = 0
+    for sample in result:
+        log_prob = sample['log_prob']
+        length = sample['length']
+        log_probs.extend(log_prob)
+        total_length += length
+
+        print(f" | log_prob:{log_prob}")
+        print(f" | length:{length}")
+
+    ppl = ngram_ppl(log_probs, total_length, log_softmax=True)
+    print(f" | final PPL={ppl}.")
+    return ppl
+
+
+def get_rouge_score(result, vocab):
+    """
+    Calculate ROUGE score.
+
+    Args:
+        result (List[Dict]): Prediction; each example has 4 keys, "source",
+            "target", "prediction" and "prediction_prob".
+        vocab (Dictionary): Dictionary instance.
+
+    Returns:
+        Str, rouge score.
+    """
+    predictions = []
+    targets = []
+    for sample in result:
+        predictions.append(' '.join([vocab[t] for t in sample['prediction']]))
+        targets.append(' '.join([vocab[t] for t in sample['target']]))
+        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
+        print(f" | target: {targets[-1]}")
+
+    return rouge(predictions, targets)
+
+
+def get_score(result, vocab=None, metric='rouge'):
+    """
+    Get eval score.
+
+    Args:
+        result (List[Dict]): Prediction.
+        vocab (Dictionary): Dictionary instance.
+        metric (str): Metric function, default is rouge.
+
+    Returns:
+        Str, score.
+    """
+    score = None
+    if metric == 'rouge':
+        score = get_rouge_score(result, vocab)
+    elif metric == 'ppl':
+        score = get_ppl_score(result)
+    else:
+        print(" | metric not in (rouge, ppl).")
+
+    return score
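A minimal usage sketch for the PPL branch of get_score, not from the PR and with made-up numbers; the 'ppl' path never touches the vocabulary, so vocab can stay None:

from src.utils import get_score

# Structure mirrors what transformer_infer_ppl produces; values are toy data.
toy_result = [
    {"source": [5, 6, 7], "target": [5, 6, 7, 2],
     "log_prob": [-2.1, -0.7, -1.3, 0.0], "length": 3.0},
]
print(get_score(toy_result, vocab=None, metric='ppl'))   # roughly e ** (4.1 / 3)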
@@ -17,10 +17,7 @@ from typing import Union
 import numpy as np
 NINF = -1.0 * 1e9
-def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = np.e):
+def ngram_ppl(prob: Union[np.ndarray, list], length: int, log_softmax=False, index: float = np.e):
     """
     Calculate Perplexity(PPL) score under N-gram language model.
@@ -39,7 +36,8 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
     Returns:
         float, ppl score.
     """
-    eps = 1e-8
+    if not length:
+        return np.inf
     if not isinstance(prob, (np.ndarray, list)):
         raise TypeError("`prob` must be type of list or np.ndarray.")
     if not isinstance(prob, np.ndarray):
@@ -47,18 +45,17 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
     if prob.shape[0] == 0:
         raise ValueError("`prob` length must greater than 0.")
-    p = 1.0
-    sen_len = 0
-    for t in range(prob.shape[0]):
-        s = prob[t]
-        if s <= NINF:
-            break
-        if log_softmax:
-            s = np.power(index, s)
-        p *= (1 / (s + eps))
-        sen_len += 1
+    print(f'length:{length}, log_prob:{prob}')
-    if sen_len == 0:
-        return np.inf
+    if log_softmax:
+        prob = np.sum(prob) / length
+        ppl = 1. / np.power(index, prob)
+        print(f'avg log prob:{prob}')
+    else:
+        p = 1.
+        for i in range(prob.shape[0]):
+            p *= (1. / prob[i])
+        ppl = pow(p, 1 / length)
-    return pow(p, 1 / sen_len)
+    print(f'ppl val:{ppl}')
+    return ppl
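With the new signature, callers pass the real token count explicitly, and on the log_softmax path ngram_ppl reduces to PPL = index ** (-(sum of log probs) / length). A quick worked example with assumed numbers, not part of the diff:

import numpy as np

from src.utils import ngram_ppl   # exported from the model's utils package

log_probs = [-1.0, -2.0, 0.0]   # last entry is a masked padding position
length = 2                      # number of real tokens

# 1 / e ** ((-1.0 - 2.0 + 0.0) / 2) = e ** 1.5 ≈ 4.48
print(ngram_ppl(log_probs, length, log_softmax=True))
print(np.exp(1.5))              # same value, for comparison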