From c4d8c8aec09b60594c739148c6ea53b5bf22bc67 Mon Sep 17 00:00:00 2001 From: linqingke Date: Sun, 28 Jun 2020 15:47:49 +0800 Subject: [PATCH] Fix bugs in MASS text summarization. --- model_zoo/mass/scripts/run.sh | 4 ++-- model_zoo/mass/src/dataset/load_dataset.py | 7 ------- model_zoo/mass/src/utils/rouge_score.py | 2 +- model_zoo/mass/train.py | 1 + 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/model_zoo/mass/scripts/run.sh b/model_zoo/mass/scripts/run.sh index fc9606fcbd..91bed510ea 100644 --- a/model_zoo/mass/scripts/run.sh +++ b/model_zoo/mass/scripts/run.sh @@ -18,7 +18,7 @@ export DEVICE_ID=0 export RANK_ID=0 export RANK_SIZE=1 -options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab -- "$@"` +options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"` eval set -- "$options" echo $options @@ -129,6 +129,7 @@ do esac done +file_path=$(cd "$(dirname $0)" || exit; pwd) for((i=0; i < $RANK_SIZE; i++)) do if [ $RANK_SIZE -gt 1 ] @@ -139,7 +140,6 @@ do fi echo "Working on device $i" - file_path=$(cd "$(dirname $0)" || exit; pwd) cd $file_path || exit cd ../ || exit diff --git a/model_zoo/mass/src/dataset/load_dataset.py b/model_zoo/mass/src/dataset/load_dataset.py index 9d9d558cb6..53ad5c7491 100644 --- a/model_zoo/mass/src/dataset/load_dataset.py +++ b/model_zoo/mass/src/dataset/load_dataset.py @@ -13,7 +13,6 @@ # limitations under the License. 
# ============================================================================ """Dataset loader to feed into model.""" -import os import mindspore.common.dtype as mstype import mindspore.dataset.engine as de import mindspore.dataset.transforms.c_transforms as deC @@ -40,12 +39,6 @@ def _load_dataset(input_files, batch_size, epoch_count=1, if not input_files: raise FileNotFoundError("Require at least one dataset.") - if not (schema_file and - os.path.exists(schema_file) - and os.path.isfile(schema_file) - and os.path.basename(schema_file).endswith(".json")): - raise FileNotFoundError("`dataset_schema` must be a existed json file.") - if not isinstance(sink_mode, bool): raise ValueError("`sink` must be type of bool.") diff --git a/model_zoo/mass/src/utils/rouge_score.py b/model_zoo/mass/src/utils/rouge_score.py index f453b5d2e1..665cda3433 100644 --- a/model_zoo/mass/src/utils/rouge_score.py +++ b/model_zoo/mass/src/utils/rouge_score.py @@ -47,7 +47,7 @@ def rouge(hypothesis: List[str], target: List[str]): edited_ref.append(r + "\n") _rouge = Rouge() - scores = _rouge.get_scores(edited_hyp, target, avg=True) + scores = _rouge.get_scores(edited_hyp, edited_ref, avg=True) print(" | ROUGE Score:") print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}") print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}") diff --git a/model_zoo/mass/train.py b/model_zoo/mass/train.py index 05b96ddae3..b58075ba4e 100644 --- a/model_zoo/mass/train.py +++ b/model_zoo/mass/train.py @@ -120,6 +120,7 @@ def _build_training_pipeline(config: TransformerConfig, test_dataset (Dataset): Test dataset. """ net_with_loss = TransformerNetworkWithLoss(config, is_training=True) + net_with_loss.init_parameters_data() if config.existed_ckpt: if config.existed_ckpt.endswith(".npz"):