
!2955 Mass eval metric update.

Merge pull request !2955 from linqingke/mass
tags/v0.6.0-beta
mindspore-ci-bot (Gitee), 5 years ago
parent commit a4d7ad7f7d
13 changed files with 270 additions and 59 deletions
  1. model_zoo/faster_rcnn/eval.py (+1, -1)
  2. model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py (+1, -1)
  3. model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py (+1, -1)
  4. model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py (+1, -1)
  5. model_zoo/faster_rcnn/train.py (+1, -1)
  6. model_zoo/mass/eval.py (+13, -29)
  7. model_zoo/mass/scripts/run.sh (+10, -4)
  8. model_zoo/mass/src/transformer/__init__.py (+2, -1)
  9. model_zoo/mass/src/transformer/embedding.py (+1, -1)
  10. model_zoo/mass/src/transformer/infer_mass.py (+129, -0)
  11. model_zoo/mass/src/utils/__init__.py (+3, -1)
  12. model_zoo/mass/src/utils/eval_score.py (+92, -0)
  13. model_zoo/mass/src/utils/ppl_score.py (+15, -18)

model_zoo/faster_rcnn/eval.py (+1, -1)

@@ -40,7 +40,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi
 parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
 args_opt = parser.parse_args()

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

 def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
     """FasterRcnn evaluation."""


model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py (+1, -1)

@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.common.initializer import initializer

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 def bias_init_zeros(shape):
     """Bias init method."""


model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py (+1, -1)

@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore import context


-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


 class Proposal(nn.Cell):


model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py (+1, -1)

@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
 from mindspore import context
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 def weight_init_ones(shape):


model_zoo/faster_rcnn/train.py (+1, -1)

@@ -52,7 +52,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
 parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
 args_opt = parser.parse_args()

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:


model_zoo/mass/eval.py (+13, -29)

@@ -15,15 +15,13 @@
 """Evaluation api."""
 import argparse
 import pickle
-import numpy as np

 from mindspore.common import dtype as mstype

 from config import TransformerConfig
-from src.transformer import infer
-from src.utils import ngram_ppl
+from src.transformer import infer, infer_ppl
 from src.utils import Dictionary
-from src.utils import rouge
+from src.utils import get_score

 parser = argparse.ArgumentParser(description='Evaluation MASS.')
 parser.add_argument("--config", type=str, required=True,
@@ -32,6 +30,8 @@ parser.add_argument("--vocab", type=str, required=True,
                     help="Vocabulary to use.")
 parser.add_argument("--output", type=str, required=True,
                     help="Result file path.")
+parser.add_argument("--metric", type=str, default='rouge',
+                    help='Set eval method.')


 def get_config(config):
@@ -45,31 +45,15 @@ if __name__ == '__main__':
     args, _ = parser.parse_known_args()
     vocab = Dictionary.load_from_persisted_dict(args.vocab)
     _config = get_config(args.config)
-    result = infer(_config)
+    if args.metric == 'rouge':
+        result = infer(_config)
+    else:
+        result = infer_ppl(_config)

     with open(args.output, "wb") as f:
         pickle.dump(result, f, 1)

-    ppl_score = 0.
-    preds = []
-    tgts = []
-    _count = 0
-    for sample in result:
-        sentence_prob = np.array(sample['prediction_prob'], dtype=np.float32)
-        sentence_prob = sentence_prob[:, 1:]
-        _ppl = []
-        for path in sentence_prob:
-            _ppl.append(ngram_ppl(path, log_softmax=True))
-        ppl = np.min(_ppl)
-        preds.append(' '.join([vocab[t] for t in sample['prediction']]))
-        tgts.append(' '.join([vocab[t] for t in sample['target']]))
-        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
-        print(f" | target: {tgts[-1]}")
-        print(f" | prediction: {preds[-1]}")
-        print(f" | ppl: {ppl}.")
-        if np.isinf(ppl):
-            continue
-        ppl_score += ppl
-        _count += 1
-
-    print(f" | PPL={ppl_score / _count}.")
-    rouge(preds, tgts)
+    # get score by given metric
+    score = get_score(result, vocab, metric=args.metric)
+    print(score)
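
Since the inference result is still pickled to --output, it can also be re-scored offline with the new get_score helper. A minimal sketch (file names here are placeholders, not from the PR):

    import pickle

    from src.utils import Dictionary, get_score

    with open("output.bin", "rb") as f:                        # the --output file written by eval.py
        result = pickle.load(f)

    vocab = Dictionary.load_from_persisted_dict("vocab.bin")   # the --vocab file
    print(get_score(result, vocab, metric="rouge"))            # or metric="ppl" for a PPL result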

model_zoo/mass/scripts/run.sh (+10, -4)

@@ -18,7 +18,7 @@ export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1

-options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
+options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"`
 eval set -- "$options"
 echo $options

@@ -35,6 +35,7 @@ echo_help()
     echo " -c --config set the configuration file"
     echo " -o --output set the output file of inference"
     echo " -v --vocab set the vocabulary"
+    echo " -m --metric set the metric"
 }

 set_hccl_json()
@@ -43,8 +44,8 @@ set_hccl_json()
     do
         if [[ "$1" == "-j" || "$1" == "--hccl_json" ]]
         then
-            export MINDSPORE_HCCL_CONFIG_PATH=$2 #/data/wsc/hccl_2p_01.json
-            export RANK_TABLE_FILE=$2 #/data/wsc/hccl_2p_01.json
+            export MINDSPORE_HCCL_CONFIG_PATH=$2
+            export RANK_TABLE_FILE=$2
             break
         fi
         shift
@@ -119,6 +120,11 @@
            vocab=$2
            shift 2
            ;;
+        -m|--metric)
+            echo "metric";
+            metric=$2
+            shift 2
+            ;;
        --)
            shift
            break
@@ -163,7 +169,7 @@ do
        python train.py --config ${configurations##*/} >>log.log 2>&1 &
    elif [ "$task" == "infer" ]
    then
-       python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} >>log_infer.log 2>&1 &
+       python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 &
    fi
    cd ../
done
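
With the new -m/--metric option, a perplexity evaluation can be launched from the script with something like sh run.sh -t infer -c <config> -o <output> -v <vocab> -m ppl (the other arguments are unchanged; the exact invocation is a sketch). Omitting -m keeps eval.py's default rouge metric.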

model_zoo/mass/src/transformer/__init__.py (+2, -1)

@@ -19,10 +19,11 @@ from .decoder import TransformerDecoder
 from .beam_search import BeamSearchDecoder
 from .transformer_for_train import TransformerTraining, LabelSmoothedCrossEntropyCriterion, \
     TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
-from .infer_mass import infer
+from .infer_mass import infer, infer_ppl

 __all__ = [
     "infer",
+    "infer_ppl",
     "TransformerTraining",
     "LabelSmoothedCrossEntropyCriterion",
     "TransformerTrainOneStepWithLossScaleCell",


model_zoo/mass/src/transformer/embedding.py (+1, -1)

@@ -41,7 +41,7 @@ class EmbeddingLookup(nn.Cell):
         self.vocab_size = vocab_size
         self.use_one_hot_embeddings = use_one_hot_embeddings

-        init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim])
+        init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim]).astype(np.float32)
         # 0 is Padding index, thus init it as 0.
         init_weight[0, :] = 0
         self.embedding_table = Parameter(Tensor(init_weight),
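
The added cast matters because np.random.normal returns a float64 array by default, so without .astype(np.float32) the embedding table would be built from float64 data. A quick illustration (toy shape, not from the PR):

    import numpy as np

    w = np.random.normal(0, 512 ** -0.5, size=[4, 512])
    print(w.dtype)                       # float64
    print(w.astype(np.float32).dtype)    # float32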


model_zoo/mass/src/transformer/infer_mass.py (+129, -0)

@@ -17,13 +17,16 @@ import time

 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net

 from mindspore import context

 from src.dataset import load_dataset
 from .transformer_for_infer import TransformerInferModel
+from .transformer_for_train import TransformerTraining
 from ..utils.load_weights import load_infer_weights

 context.set_context(
@@ -156,3 +159,129 @@ def infer(config):
                                 shuffle=False) if config.test_dataset else None
     prediction = transformer_infer(config, eval_dataset)
     return prediction


class TransformerInferPPLCell(nn.Cell):
    """
    Encapsulation class of transformer network infer for PPL.

    Args:
        config(TransformerConfig): Config.

    Returns:
        Tuple[Tensor, Tensor], predicted log prob and label lengths.
    """
    def __init__(self, config):
        super(TransformerInferPPLCell, self).__init__()
        self.transformer = TransformerTraining(config, is_training=False, use_one_hot_embeddings=False)
        self.batch_size = config.batch_size
        self.vocab_size = config.vocab_size
        self.one_hot = P.OneHot()
        self.on_value = Tensor(float(1), mstype.float32)
        self.off_value = Tensor(float(0), mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.flat_shape = (config.batch_size * config.seq_length,)
        self.batch_shape = (config.batch_size, config.seq_length)
        self.last_idx = (-1,)

    def construct(self,
                  source_ids,
                  source_mask,
                  target_ids,
                  target_mask,
                  label_ids,
                  label_mask):
        """Defines the computation performed."""

        predicted_log_probs = self.transformer(source_ids, source_mask, target_ids, target_mask)
        label_ids = self.reshape(label_ids, self.flat_shape)
        label_mask = self.cast(label_mask, mstype.float32)
        one_hot_labels = self.one_hot(label_ids, self.vocab_size, self.on_value, self.off_value)

        label_log_probs = self.reduce_sum(predicted_log_probs * one_hot_labels, self.last_idx)
        label_log_probs = self.reshape(label_log_probs, self.batch_shape)
        log_probs = label_log_probs * label_mask
        lengths = self.reduce_sum(label_mask, self.last_idx)

        return log_probs, lengths


def transformer_infer_ppl(config, dataset):
    """
    Run infer with Transformer for PPL.

    Args:
        config (TransformerConfig): Config.
        dataset (Dataset): Dataset.

    Returns:
        List[Dict], prediction, each example has 4 keys, "source",
        "target", "log_prob" and "length".
    """
    tfm_infer = TransformerInferPPLCell(config=config)
    tfm_infer.init_parameters_data()

    parameter_dict = load_checkpoint(config.existed_ckpt)
    load_param_into_net(tfm_infer, parameter_dict)

    model = Model(tfm_infer)

    log_probs = []
    lengths = []
    source_sentences = []
    target_sentences = []
    for batch in dataset.create_dict_iterator():
        source_sentences.append(batch["source_eos_ids"])
        target_sentences.append(batch["target_eos_ids"])

        source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
        source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
        target_ids = Tensor(batch["target_sos_ids"], mstype.int32)
        target_mask = Tensor(batch["target_sos_mask"], mstype.int32)
        label_ids = Tensor(batch["target_eos_ids"], mstype.int32)
        label_mask = Tensor(batch["target_eos_mask"], mstype.int32)

        start_time = time.time()
        log_prob, length = model.predict(source_ids, source_mask, target_ids, target_mask, label_ids, label_mask)
        print(f" | Batch size: {config.batch_size}, "
              f"Time cost: {time.time() - start_time}.")

        log_probs.append(log_prob.asnumpy())
        lengths.append(length.asnumpy())

    output = []
    for inputs, ref, log_prob, length in zip(source_sentences,
                                             target_sentences,
                                             log_probs,
                                             lengths):
        for i in range(config.batch_size):
            example = {
                "source": inputs[i].tolist(),
                "target": ref[i].tolist(),
                "log_prob": log_prob[i].tolist(),
                "length": length[i]
            }
            output.append(example)

    return output


def infer_ppl(config):
    """
    Transformer infer PPL api.

    Args:
        config (TransformerConfig): Config.

    Returns:
        list, result with per-example log probabilities and lengths.
    """
    eval_dataset = load_dataset(data_files=config.test_dataset,
                                batch_size=config.batch_size,
                                epoch_count=1,
                                sink_mode=config.dataset_sink_mode,
                                shuffle=False) if config.test_dataset else None
    prediction = transformer_infer_ppl(config, eval_dataset)
    return prediction
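
For intuition, the one-hot/reduce-sum combination in TransformerInferPPLCell.construct is a graph-friendly way of picking each label's log-probability out of the prediction matrix. A numpy sketch of the same operation (toy shapes and assumed log-softmax outputs, not from the PR):

    import numpy as np

    batch_x_seq, vocab = 6, 5
    log_probs = np.log(np.random.dirichlet(np.ones(vocab), size=batch_x_seq))  # rows are valid distributions
    label_ids = np.random.randint(0, vocab, size=batch_x_seq)

    one_hot = np.eye(vocab)[label_ids]                     # analogue of P.OneHot
    gathered = (log_probs * one_hot).sum(axis=-1)          # analogue of P.ReduceSum over the last axis
    assert np.allclose(gathered, log_probs[np.arange(batch_x_seq), label_ids])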

model_zoo/mass/src/utils/__init__.py (+3, -1)

@@ -20,6 +20,7 @@ from .loss_monitor import LossCallBack
 from .byte_pair_encoding import bpe_encode
 from .initializer import zero_weight, one_weight, normal_weight, weight_variable
 from .rouge_score import rouge
+from .eval_score import get_score

 __all__ = [
     "Dictionary",
@@ -31,5 +32,6 @@ __all__ = [
     "one_weight",
     "zero_weight",
     "normal_weight",
-    "weight_variable"
+    "weight_variable",
+    "get_score"
 ]

model_zoo/mass/src/utils/eval_score.py (+92, -0)

@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Get score by given metric."""
from .ppl_score import ngram_ppl
from .rouge_score import rouge


def get_ppl_score(result):
    """
    Calculate Perplexity(PPL) score.

    Args:
        result (List[Dict]): Prediction; each example has 4 keys, "source",
            "target", "log_prob" and "length".

    Returns:
        Float, ppl score.
    """
    log_probs = []
    total_length = 0

    for sample in result:
        log_prob = sample['log_prob']
        length = sample['length']
        log_probs.extend(log_prob)
        total_length += length

        print(f" | log_prob:{log_prob}")
        print(f" | length:{length}")

    ppl = ngram_ppl(log_probs, total_length, log_softmax=True)
    print(f" | final PPL={ppl}.")
    return ppl


def get_rouge_score(result, vocab):
    """
    Calculate ROUGE score.

    Args:
        result (List[Dict]): Prediction; each example has 4 keys, "source",
            "target", "prediction" and "prediction_prob".
        vocab (Dictionary): Dictionary instance.

    Returns:
        Str, rouge score.
    """
    predictions = []
    targets = []
    for sample in result:
        predictions.append(' '.join([vocab[t] for t in sample['prediction']]))
        targets.append(' '.join([vocab[t] for t in sample['target']]))
        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
        print(f" | target: {targets[-1]}")

    return rouge(predictions, targets)


def get_score(result, vocab=None, metric='rouge'):
    """
    Get eval score.

    Args:
        result (List[Dict]): Prediction.
        vocab (Dictionary): Dictionary instance.
        metric (str): Metric function, default is rouge.

    Returns:
        Str, score.
    """
    score = None
    if metric == 'rouge':
        score = get_rouge_score(result, vocab)
    elif metric == 'ppl':
        score = get_ppl_score(result)
    else:
        print(f" | metric not in (rouge, ppl)")

    return score
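
Note that get_ppl_score pools the log-probabilities of every target token and divides by the summed lengths, i.e. it reports corpus-level perplexity. A toy run of the PPL path (a sketch; values are made up):

    from src.utils import get_score

    # Two toy "sentences" with per-token log-probs.
    result = [
        {"source": [1], "target": [1], "log_prob": [-0.5, -1.0], "length": 2.0},
        {"source": [2], "target": [2], "log_prob": [-1.5, -1.0], "length": 2.0},
    ]
    ppl = get_score(result, metric='ppl')
    # sum of log-probs = -4.0 over N = 4 tokens, so ppl = exp(1.0) ≈ 2.718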

model_zoo/mass/src/utils/ppl_score.py (+15, -18)

@@ -17,10 +17,7 @@ from typing import Union

 import numpy as np

-NINF = -1.0 * 1e9

-def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = np.e):
+def ngram_ppl(prob: Union[np.ndarray, list], length: int, log_softmax=False, index: float = np.e):
     """
     Calculate Perplexity(PPL) score under N-gram language model.

@@ -39,7 +36,8 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
     Returns:
         float, ppl score.
     """
-    eps = 1e-8
+    if not length:
+        return np.inf
     if not isinstance(prob, (np.ndarray, list)):
         raise TypeError("`prob` must be type of list or np.ndarray.")
     if not isinstance(prob, np.ndarray):
@@ -47,18 +45,17 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
     if prob.shape[0] == 0:
         raise ValueError("`prob` length must greater than 0.")

-    p = 1.0
-    sen_len = 0
-    for t in range(prob.shape[0]):
-        s = prob[t]
-        if s <= NINF:
-            break
-        if log_softmax:
-            s = np.power(index, s)
-        p *= (1 / (s + eps))
-        sen_len += 1
-
-    if sen_len == 0:
-        return np.inf
-
-    return pow(p, 1 / sen_len)
+    print(f'length:{length}, log_prob:{prob}')
+
+    if log_softmax:
+        prob = np.sum(prob) / length
+        ppl = 1. / np.power(index, prob)
+        print(f'avg log prob:{prob}')
+    else:
+        p = 1.
+        for i in range(prob.shape[0]):
+            p *= (1. / prob[i])
+        ppl = pow(p, 1 / length)
+
+    print(f'ppl val:{ppl}')
+    return ppl
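
A quick sanity check of the new signature (a sketch, assuming natural-log per-token probabilities for one sentence):

    from src.utils.ppl_score import ngram_ppl

    log_probs = [-0.5, -1.5, -1.0, -3.0]     # per-token log-probs
    ppl = ngram_ppl(log_probs, length=4, log_softmax=True)
    # average log-prob = -1.5, so ppl = e ** 1.5 ≈ 4.48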
