
!2955 Mass eval metric update.

Merge pull request !2955 from linqingke/mass
tags/v0.6.0-beta
mindspore-ci-bot committed 5 years ago
commit a4d7ad7f7d
13 changed files with 270 additions and 59 deletions
1. model_zoo/faster_rcnn/eval.py (+1, -1)
2. model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py (+1, -1)
3. model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py (+1, -1)
4. model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py (+1, -1)
5. model_zoo/faster_rcnn/train.py (+1, -1)
6. model_zoo/mass/eval.py (+13, -29)
7. model_zoo/mass/scripts/run.sh (+10, -4)
8. model_zoo/mass/src/transformer/__init__.py (+2, -1)
9. model_zoo/mass/src/transformer/embedding.py (+1, -1)
10. model_zoo/mass/src/transformer/infer_mass.py (+129, -0)
11. model_zoo/mass/src/utils/__init__.py (+3, -1)
12. model_zoo/mass/src/utils/eval_score.py (+92, -0)
13. model_zoo/mass/src/utils/ppl_score.py (+15, -18)

model_zoo/faster_rcnn/eval.py (+1, -1)

@@ -40,7 +40,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
"""FasterRcnn evaluation."""


model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py (+1, -1)

@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

def bias_init_zeros(shape):
"""Bias init method."""


model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py (+1, -1)

@@ -22,7 +22,7 @@ from mindspore import Tensor
from mindspore import context


-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Proposal(nn.Cell):


model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py (+1, -1)

@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
from mindspore import context
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def weight_init_ones(shape):


model_zoo/faster_rcnn/train.py (+1, -1)

@@ -52,7 +52,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
args_opt = parser.parse_args()

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute:


model_zoo/mass/eval.py (+13, -29)

@@ -15,15 +15,13 @@
"""Evaluation api."""
import argparse
import pickle
-import numpy as np

from mindspore.common import dtype as mstype

from config import TransformerConfig
-from src.transformer import infer
-from src.utils import ngram_ppl
+from src.transformer import infer, infer_ppl
from src.utils import Dictionary
-from src.utils import rouge
+from src.utils import get_score

parser = argparse.ArgumentParser(description='Evaluation MASS.')
parser.add_argument("--config", type=str, required=True,
@@ -32,6 +30,8 @@ parser.add_argument("--vocab", type=str, required=True,
help="Vocabulary to use.")
parser.add_argument("--output", type=str, required=True,
help="Result file path.")
parser.add_argument("--metric", type=str, default='rouge',
help='Set eval method.')


def get_config(config):
@@ -45,31 +45,15 @@ if __name__ == '__main__':
args, _ = parser.parse_known_args()
vocab = Dictionary.load_from_persisted_dict(args.vocab)
_config = get_config(args.config)
-result = infer(_config)

+if args.metric == 'rouge':
+result = infer(_config)
+else:
+result = infer_ppl(_config)

with open(args.output, "wb") as f:
pickle.dump(result, f, 1)

-ppl_score = 0.
-preds = []
-tgts = []
-_count = 0
-for sample in result:
-sentence_prob = np.array(sample['prediction_prob'], dtype=np.float32)
-sentence_prob = sentence_prob[:, 1:]
-_ppl = []
-for path in sentence_prob:
-_ppl.append(ngram_ppl(path, log_softmax=True))
-ppl = np.min(_ppl)
-preds.append(' '.join([vocab[t] for t in sample['prediction']]))
-tgts.append(' '.join([vocab[t] for t in sample['target']]))
-print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
-print(f" | target: {tgts[-1]}")
-print(f" | prediction: {preds[-1]}")
-print(f" | ppl: {ppl}.")
-if np.isinf(ppl):
-continue
-ppl_score += ppl
-_count += 1

-print(f" | PPL={ppl_score / _count}.")
-rouge(preds, tgts)
+# get score by given metric
+score = get_score(result, vocab, metric=args.metric)
+print(score)
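
With this change, the metric is selected on the command line. For example (file names here are placeholders), python eval.py --config config.json --vocab vocab.dict --output result.bin --metric ppl runs perplexity scoring, while leaving --metric at its default of rouge keeps the ROUGE evaluation.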

model_zoo/mass/scripts/run.sh (+10, -4)

@@ -18,7 +18,7 @@ export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

-options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
+options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"`
eval set -- "$options"
echo $options

@@ -35,6 +35,7 @@ echo_help()
echo " -c --config set the configuration file"
echo " -o --output set the output file of inference"
echo " -v --vocab set the vocabulary"
echo " -m --metric set the metric"
}

set_hccl_json()
@@ -43,8 +44,8 @@ set_hccl_json()
do
if [[ "$1" == "-j" || "$1" == "--hccl_json" ]]
then
-export MINDSPORE_HCCL_CONFIG_PATH=$2 #/data/wsc/hccl_2p_01.json
-export RANK_TABLE_FILE=$2 #/data/wsc/hccl_2p_01.json
+export MINDSPORE_HCCL_CONFIG_PATH=$2
+export RANK_TABLE_FILE=$2
break
fi
shift
@@ -119,6 +120,11 @@ do
vocab=$2
shift 2
;;
+-m|--metric)
+echo "metric";
+metric=$2
+shift 2
+;;
--)
shift
break
@@ -163,7 +169,7 @@ do
python train.py --config ${configurations##*/} >>log.log 2>&1 &
elif [ "$task" == "infer" ]
then
-python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} >>log_infer.log 2>&1 &
+python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 &
fi
cd ../
done
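
For reference, the wrapper script now forwards the metric flag as well; a typical single-device inference launch might look like sh run.sh -t infer -i 0 -c config.json -o output.bin -v vocab.dict -m ppl (file names here are placeholders; see the option list above for the full set of flags).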

model_zoo/mass/src/transformer/__init__.py (+2, -1)

@@ -19,10 +19,11 @@ from .decoder import TransformerDecoder
from .beam_search import BeamSearchDecoder
from .transformer_for_train import TransformerTraining, LabelSmoothedCrossEntropyCriterion, \
TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
-from .infer_mass import infer
+from .infer_mass import infer, infer_ppl

__all__ = [
"infer",
"infer_ppl",
"TransformerTraining",
"LabelSmoothedCrossEntropyCriterion",
"TransformerTrainOneStepWithLossScaleCell",


model_zoo/mass/src/transformer/embedding.py (+1, -1)

@@ -41,7 +41,7 @@ class EmbeddingLookup(nn.Cell):
self.vocab_size = vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings

-init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim])
+init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim]).astype(np.float32)
# 0 is Padding index, thus init it as 0.
init_weight[0, :] = 0
self.embedding_table = Parameter(Tensor(init_weight),
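
A brief note on the .astype(np.float32) cast above: np.random.normal returns float64 by default, so without the cast the embedding table would be created in double precision. A minimal sketch of the dtype difference (toy shape, plain NumPy, independent of MindSpore):

```python
import numpy as np

embed_dim = 64
w = np.random.normal(0, embed_dim ** -0.5, size=[10, embed_dim])
print(w.dtype)                      # float64: NumPy's default for random sampling
print(w.astype(np.float32).dtype)   # float32 after the explicit cast
```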


model_zoo/mass/src/transformer/infer_mass.py (+129, -0)

@@ -17,13 +17,16 @@ import time

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindspore import context

from src.dataset import load_dataset
from .transformer_for_infer import TransformerInferModel
from .transformer_for_train import TransformerTraining
from ..utils.load_weights import load_infer_weights

context.set_context(
@@ -156,3 +159,129 @@ def infer(config):
shuffle=False) if config.test_dataset else None
prediction = transformer_infer(config, eval_dataset)
return prediction


class TransformerInferPPLCell(nn.Cell):
"""
Encapsulation class of transformer network infer for PPL.

Args:
config(TransformerConfig): Config.

Returns:
Tuple[Tensor, Tensor], predicted log prob and label lengths.
"""
def __init__(self, config):
super(TransformerInferPPLCell, self).__init__()
self.transformer = TransformerTraining(config, is_training=False, use_one_hot_embeddings=False)
self.batch_size = config.batch_size
self.vocab_size = config.vocab_size
self.one_hot = P.OneHot()
self.on_value = Tensor(float(1), mstype.float32)
self.off_value = Tensor(float(0), mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.flat_shape = (config.batch_size * config.seq_length,)
self.batch_shape = (config.batch_size, config.seq_length)
self.last_idx = (-1,)

def construct(self,
source_ids,
source_mask,
target_ids,
target_mask,
label_ids,
label_mask):
"""Defines the computation performed."""

predicted_log_probs = self.transformer(source_ids, source_mask, target_ids, target_mask)
label_ids = self.reshape(label_ids, self.flat_shape)
label_mask = self.cast(label_mask, mstype.float32)
one_hot_labels = self.one_hot(label_ids, self.vocab_size, self.on_value, self.off_value)

label_log_probs = self.reduce_sum(predicted_log_probs * one_hot_labels, self.last_idx)
label_log_probs = self.reshape(label_log_probs, self.batch_shape)
log_probs = label_log_probs * label_mask
lengths = self.reduce_sum(label_mask, self.last_idx)

return log_probs, lengths


def transformer_infer_ppl(config, dataset):
"""
Run infer with Transformer for PPL.

Args:
config (TransformerConfig): Config.
dataset (Dataset): Dataset.

Returns:
List[Dict], prediction, each example has 4 keys, "source",
"target", "log_prob" and "length".
"""
tfm_infer = TransformerInferPPLCell(config=config)
tfm_infer.init_parameters_data()

parameter_dict = load_checkpoint(config.existed_ckpt)
load_param_into_net(tfm_infer, parameter_dict)

model = Model(tfm_infer)

log_probs = []
lengths = []
source_sentences = []
target_sentences = []
for batch in dataset.create_dict_iterator():
source_sentences.append(batch["source_eos_ids"])
target_sentences.append(batch["target_eos_ids"])

source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
target_ids = Tensor(batch["target_sos_ids"], mstype.int32)
target_mask = Tensor(batch["target_sos_mask"], mstype.int32)
label_ids = Tensor(batch["target_eos_ids"], mstype.int32)
label_mask = Tensor(batch["target_eos_mask"], mstype.int32)

start_time = time.time()
log_prob, length = model.predict(source_ids, source_mask, target_ids, target_mask, label_ids, label_mask)
print(f" | Batch size: {config.batch_size}, "
f"Time cost: {time.time() - start_time}.")

log_probs.append(log_prob.asnumpy())
lengths.append(length.asnumpy())

output = []
for inputs, ref, log_prob, length in zip(source_sentences,
target_sentences,
log_probs,
lengths):
for i in range(config.batch_size):
example = {
"source": inputs[i].tolist(),
"target": ref[i].tolist(),
"log_prob": log_prob[i].tolist(),
"length": length[i]
}
output.append(example)

return output


def infer_ppl(config):
"""
Transformer infer PPL api.

Args:
config (TransformerConfig): Config.

Returns:
list, result with "source", "target", "log_prob" and "length" for each example.
"""
eval_dataset = load_dataset(data_files=config.test_dataset,
batch_size=config.batch_size,
epoch_count=1,
sink_mode=config.dataset_sink_mode,
shuffle=False) if config.test_dataset else None
prediction = transformer_infer_ppl(config, eval_dataset)
return prediction
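
As a clarification of TransformerInferPPLCell.construct above: the OneHot plus ReduceSum combination is simply a gather of each label token's log-probability. A minimal NumPy analogue (toy shapes, not MindSpore code) showing the equivalence:

```python
import numpy as np

vocab_size = 5
seq_len = 3
# Uniform toy log-probabilities, shape (seq_len, vocab_size).
log_probs = np.log(np.full((seq_len, vocab_size), 1.0 / vocab_size))
label_ids = np.array([2, 0, 4])

one_hot = np.eye(vocab_size)[label_ids]          # like P.OneHot()
gathered = np.sum(log_probs * one_hot, axis=-1)  # like P.ReduceSum over the last axis
assert np.allclose(gathered, log_probs[np.arange(seq_len), label_ids])
```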

model_zoo/mass/src/utils/__init__.py (+3, -1)

@@ -20,6 +20,7 @@ from .loss_monitor import LossCallBack
from .byte_pair_encoding import bpe_encode
from .initializer import zero_weight, one_weight, normal_weight, weight_variable
from .rouge_score import rouge
+from .eval_score import get_score

__all__ = [
"Dictionary",
@@ -31,5 +32,6 @@ __all__ = [
"one_weight",
"zero_weight",
"normal_weight",
"weight_variable"
"weight_variable",
"get_score"
]

model_zoo/mass/src/utils/eval_score.py (+92, -0)

@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Get score by given metric."""
from .ppl_score import ngram_ppl
from .rouge_score import rouge


def get_ppl_score(result):
"""
Calculate Perplexity(PPL) score.

Args:
result (List[Dict]): Prediction, each example has 4 keys, "source",
"target", "log_prob" and "length".

Returns:
Float, ppl score.
"""
log_probs = []
total_length = 0

for sample in result:
log_prob = sample['log_prob']
length = sample['length']
log_probs.extend(log_prob)
total_length += length

print(f" | log_prob:{log_prob}")
print(f" | length:{length}")

ppl = ngram_ppl(log_probs, total_length, log_softmax=True)
print(f" | final PPL={ppl}.")
return ppl


def get_rouge_score(result, vocab):
"""
Calculate ROUGE score.

Args:
result (List[Dict]): Prediction, each example has 4 keys, "source",
"target", "prediction" and "prediction_prob".
vocab (Dictionary): Dict instance.

Returns:
Str, rouge score.
"""

predictions = []
targets = []
for sample in result:
predictions.append(' '.join([vocab[t] for t in sample['prediction']]))
targets.append(' '.join([vocab[t] for t in sample['target']]))
print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
print(f" | target: {targets[-1]}")

return rouge(predictions, targets)


def get_score(result, vocab=None, metric='rouge'):
"""
Get eval score.

Args:
result (List[Dict]): Prediction.
vocab (Dictionary): Dict instance.
metric (str): Metric to use, default is 'rouge'.

Returns:
Str, Score.
"""
score = None
if metric == 'rouge':
score = get_rouge_score(result, vocab)
elif metric == 'ppl':
score = get_ppl_score(result)
else:
print(f" |metric not in (rouge, ppl)")

return score
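
A minimal usage sketch for get_score (toy values; assumes model_zoo/mass is on the Python path so that src.utils resolves):

```python
# Toy example only: real result lists come from infer() / infer_ppl().
from src.utils import get_score

# For metric='ppl', each sample carries per-token log probabilities and a token count.
ppl_result = [
    {"source": [4, 7], "target": [4, 7], "log_prob": [-0.2, -1.3], "length": 2},
]
print(get_score(ppl_result, metric='ppl'))  # corpus-level perplexity

# For metric='rouge', samples carry token ids that are mapped through a Dictionary:
# get_score(rouge_result, vocab=vocab, metric='rouge')
```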

model_zoo/mass/src/utils/ppl_score.py (+15, -18)

@@ -17,10 +17,7 @@ from typing import Union

import numpy as np

-NINF = -1.0 * 1e9
-
-
-def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = np.e):
+def ngram_ppl(prob: Union[np.ndarray, list], length: int, log_softmax=False, index: float = np.e):
"""
Calculate Perplexity(PPL) score under N-gram language model.

@@ -39,7 +36,8 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
Returns:
float, ppl score.
"""
-eps = 1e-8
+if not length:
+return np.inf
if not isinstance(prob, (np.ndarray, list)):
raise TypeError("`prob` must be type of list or np.ndarray.")
if not isinstance(prob, np.ndarray):
@@ -47,18 +45,17 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
if prob.shape[0] == 0:
raise ValueError("`prob` length must greater than 0.")

-p = 1.0
-sen_len = 0
-for t in range(prob.shape[0]):
-s = prob[t]
-if s <= NINF:
-break
-if log_softmax:
-s = np.power(index, s)
-p *= (1 / (s + eps))
-sen_len += 1
+print(f'length:{length}, log_prob:{prob}')

-if sen_len == 0:
-return np.inf
+if log_softmax:
+prob = np.sum(prob) / length
+ppl = 1. / np.power(index, prob)
+print(f'avg log prob:{prob}')
+else:
+p = 1.
+for i in range(prob.shape[0]):
+p *= (1. / prob[i])
+ppl = pow(p, 1 / length)

-return pow(p, 1 / sen_len)
+print(f'ppl val:{ppl}')
+return ppl
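
With the refactor, the log_softmax branch computes the standard corpus perplexity exp(-average log-probability). A quick numeric check (toy natural-log probabilities, matching the default index=np.e):

```python
import numpy as np

log_p = [-0.2, -1.3, -0.5]            # toy per-token log probabilities
length = len(log_p)

avg = np.sum(log_p) / length          # average log-probability, as in the new code
ppl = 1. / np.power(np.e, avg)        # 1 / e**avg
assert np.isclose(ppl, np.exp(-avg))  # i.e. PPL = exp(-avg log prob)
```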
