bert ner for adaptation of MSRA dataset

tags/v1.2.0-rc1
shibeiji 4 years ago
commit cc18b206c9
8 changed files with 116 additions and 132 deletions
  1. model_zoo/official/nlp/bert/README.md (+3 -5)
  2. model_zoo/official/nlp/bert/README_CN.md (+3 -5)
  3. model_zoo/official/nlp/bert/run_ner.py (+20 -8)
  4. model_zoo/official/nlp/bert/scripts/run_ner.sh (+2 -2)
  5. model_zoo/official/nlp/bert/src/assessment_method.py (+24 -71)
  6. model_zoo/official/nlp/bert/src/dataset.py (+2 -2)
  7. model_zoo/official/nlp/bert/src/finetune_data_preprocess.py (+59 -37)
  8. model_zoo/official/nlp/bert/src/utils.py (+3 -2)

model_zoo/official/nlp/bert/README.md (+3 -5)

@@ -551,7 +551,7 @@ F1 0.920507
 For preprocess, you can first convert the original txt format of MSRA dataset into mindrecord by run the command as below:

 ```python
-python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len
+python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=1.0
 ```

 For finetune and evaluation, just do
@@ -562,12 +562,10 @@ bash scripts/ner.sh
 ```

 The command above will run in the background, you can view training logs in ner_log.txt.

-If you choose SpanF1 as assessment method and mode use_crf is set to be "true", the result will be as follows if evaluation is done after finetuning 10 epoches:
+If you choose MF1 (F1 score with multi-labels) as the assessment method, the result will be as follows if evaluation is done after finetuning for 10 epochs:

 ```text
-Precision 0.953826
-Recall 0.957749
-F1 0.955784
+F1 0.931243
 ```

 #### evaluation on squad v1.1 dataset when running on Ascend
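
The added --split_begin and --split_end fractions let one source file yield separate training and evaluation subsets. A hypothetical pair of invocations (the file paths and the 90/10 split are placeholders, not part of this commit) could look like:

```text
# first 90% of the sentences for finetuning
python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_train.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=0.9

# remaining 10% for evaluation
python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_eval.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --split_begin=0.9 --split_end=1.0
```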


model_zoo/official/nlp/bert/README_CN.md (+3 -5)

@@ -515,7 +515,7 @@ F1 0.920507
 To improve performance, you can first convert the original format of the MSRA dataset to mindrecord in the preprocessing flow, as follows:

 ```python
-python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len
+python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=1.0
 ```

 After that, you can run the finetuning and inference flow,
@@ -525,12 +525,10 @@ bash scripts/ner.sh
 ```

 The command above runs in the background; you can view the training log in ner_log.txt.
-If you choose SpanF1 as the assessment method and CRF mode is configured in the model, inference after finetuning for 10 epochs gives the following result:
+If you choose MF1 (F1 score with multiple labels) as the assessment method, inference after finetuning for 10 epochs gives the following result:

 ```text
-Precision 0.953826
-Recall 0.957749
-F1 0.955784
+F1 0.931243
 ```

 #### Evaluating the squad v1.1 dataset after running on Ascend


model_zoo/official/nlp/bert/run_ner.py (+20 -8)

@@ -19,11 +19,12 @@ Bert finetune and evaluation script.

 import os
 import argparse
+import time
 from src.bert_for_finetune import BertFinetuneCell, BertNER
 from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
 from src.dataset import create_ner_dataset
 from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index
-from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation, SpanF1
+from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
 import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore import log as logger
@@ -79,17 +80,22 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
     netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
     model = Model(netwithgrads)
     callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
+    train_begin = time.time()
     model.train(epoch_num, dataset, callbacks=callbacks)
+    train_end = time.time()
+    print("latency: {:.6f} s".format(train_end - train_begin))


 def eval_result_print(assessment_method="accuracy", callback=None):
     """print eval result"""
     if assessment_method == "accuracy":
         print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                   callback.acc_num / callback.total_num))
-    elif assessment_method in ("f1", "spanf1"):
+    elif assessment_method == "bf1":
         print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
         print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
         print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
+    elif assessment_method == "mf1":
+        print("F1 {:.6f} ".format(callback.eval()[0]))
     elif assessment_method == "mcc":
         print("MCC {:.6f} ".format(callback.cal()))
     elif assessment_method == "spearman_correlation":
@@ -116,10 +122,10 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=41, assessment_met
     else:
         if assessment_method == "accuracy":
             callback = Accuracy()
-        elif assessment_method == "f1":
+        elif assessment_method == "bf1":
             callback = F1((use_crf.lower() == "true"), num_class)
-        elif assessment_method == "spanf1":
-            callback = SpanF1((use_crf.lower() == "true"), tag_to_index)
+        elif assessment_method == "mf1":
+            callback = F1((use_crf.lower() == "true"), num_labels=num_class, mode="MultiLabel")
         elif assessment_method == "mcc":
             callback = MCC()
         elif assessment_method == "spearman_correlation":
@@ -145,8 +151,8 @@ def parse_args():
     parser = argparse.ArgumentParser(description="run ner")
     parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
                         help="Device type, default is Ascend")
-    parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark", "SpanF1"],
-                        help="assessment_method include: [F1, clue_benchmark, SpanF1], default is F1")
+    parser.add_argument("--assessment_method", type=str, default="BF1", choices=["BF1", "clue_benchmark", "MF1"],
+                        help="assessment_method include: [BF1, clue_benchmark, MF1], default is BF1")
     parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"],
                         help="Eable train, default is false")
     parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"],
@@ -231,6 +237,12 @@ def run_ner():
                                 assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
                                 schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format,
                                 do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
+        print("==============================================================")
+        print("processor_name: {}".format(args_opt.device_target))
+        print("test_name: BERT Finetune Training")
+        print("model_name: {}".format("BERT+MLP+CRF" if args_opt.use_crf.lower() == "true" else "BERT + MLP"))
+        print("batch_size: {}".format(args_opt.train_batch_size))
+
         do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)

     if args_opt.do_eval.lower() == "true":
@@ -245,7 +257,7 @@ def run_ner():
         ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
                                 assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
                                 schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format,
-                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
+                                do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), drop_remainder=False)
         do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
                 args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path,
                 args_opt.label_file_path, tag_to_index, args_opt.eval_batch_size)
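
Below is a minimal sketch of how the renamed assessment methods map onto the reworked F1 class, assuming this repository is on the PYTHONPATH and the MindSpore/NumPy versions this commit targets; the tensor shapes and random values are illustrative only.

```python
import numpy as np
from mindspore import Tensor

from src.assessment_method import F1  # the class reworked in this commit

NUM_CLASS = 41
# "bf1": token-level binary F1 over non-"O" positions (the pre-existing behaviour)
bf1 = F1(use_crf=False, num_labels=NUM_CLASS)
# "mf1": the same class in its new multi-label mode, backed by ConfusionMatrixMetric
mf1 = F1(use_crf=False, num_labels=NUM_CLASS, mode="MultiLabel")

logits = Tensor(np.random.rand(2, 8, NUM_CLASS).astype(np.float32))  # [batch, seq, class]
labels = Tensor(np.random.randint(0, NUM_CLASS, (2, 8)).astype(np.int32))
bf1.update(logits, labels)
mf1.update(logits, labels)

# eval_result_print reads BF1 from the accumulated TP/FP/FN counters ...
print("F1 {:.6f}".format(2 * bf1.TP / (2 * bf1.TP + bf1.FP + bf1.FN)))
# ... and MF1 through the new eval() method
print("F1 {:.6f}".format(mf1.eval()[0]))
```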


model_zoo/official/nlp/bert/scripts/run_ner.sh (+2 -2)

@@ -18,7 +18,7 @@ echo "=========================================================================="
 echo "Please run the script as: "
 echo "bash scripts/run_ner.sh"
 echo "for example: bash scripts/run_ner.sh"
-echo "assessment_method include: [F1, SpanF1, clue_benchmark]"
+echo "assessment_method include: [BF1, MF1, clue_benchmark]"
 echo "=============================================================================================================="

 mkdir -p ms_log
@@ -30,7 +30,7 @@ python ${PROJECT_DIR}/../run_ner.py \
     --device_target="Ascend" \
     --do_train="true" \
     --do_eval="false" \
-    --assessment_method="F1" \
+    --assessment_method="BF1" \
     --use_crf="false" \
     --device_id=0 \
     --epoch_num=5 \


model_zoo/official/nlp/bert/src/assessment_method.py (+24 -71)

@@ -18,6 +18,7 @@ Bert evaluation assessment method script.
 '''
 import math
 import numpy as np
+from mindspore.nn.metrics import ConfusionMatrixMetric
 from .CRF import postprocess


 class Accuracy():
@@ -39,12 +40,18 @@ class F1():
     '''
     calculate F1 score
     '''
-    def __init__(self, use_crf=False, num_labels=2):
+    def __init__(self, use_crf=False, num_labels=2, mode="Binary"):
         self.TP = 0
         self.FP = 0
         self.FN = 0
         self.use_crf = use_crf
         self.num_labels = num_labels
+        self.mode = mode
+        if self.mode.lower() not in ("binary", "multilabel"):
+            raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]")
+        if self.mode.lower() != "binary":
+            self.metric = ConfusionMatrixMetric(skip_channel=False, metric_name=("f1 score"),
+                                                calculation_method=False, decrease="mean")

     def update(self, logits, labels):
         '''
@@ -62,78 +69,24 @@ class F1():
             logits = logits.asnumpy()
             logit_id = np.argmax(logits, axis=-1)
             logit_id = np.reshape(logit_id, -1)
-        pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)])
-        pos_label = np.isin(labels, [i for i in range(1, self.num_labels)])
-        self.TP += np.sum(pos_eva&pos_label)
-        self.FP += np.sum(pos_eva&(~pos_label))
-        self.FN += np.sum((~pos_eva)&pos_label)
-
-
-class SpanF1():
-    '''
-    calculate F1, precision and recall score in span manner for NER
-    '''
-    def __init__(self, use_crf=False, label2id=None):
-        self.TP = 0
-        self.FP = 0
-        self.FN = 0
-        self.use_crf = use_crf
-        self.label2id = label2id
-        if label2id is None:
-            raise ValueError("label2id info should not be empty")
-        self.id2label = {}
-        for key, value in label2id.items():
-            self.id2label[value] = key
-
-    def tag2span(self, ids):
-        '''
-        convert ids list to span mode
-        '''
-        labels = np.array([self.id2label[id] for id in ids])
-        spans = []
-        prev_label = None
-        for idx, tag in enumerate(labels):
-            tag = tag.lower()
-            cur_label, label = tag[:1], tag[2:]
-            if cur_label in ('b', 's'):
-                spans.append((label, [idx, idx]))
-            elif cur_label in ('m', 'e') and prev_label in ('b', 'm') and label == spans[-1][0]:
-                spans[-1][1][1] = idx
-            elif cur_label == 'o':
-                pass
-            else:
-                spans.append((label, [idx, idx]))
-            prev_label = cur_label
-        return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans]
-
-    def update(self, logits, labels):
-        '''
-        update span F1 score
-        '''
-        labels = labels.asnumpy()
-        labels = np.reshape(labels, -1)
-        if self.use_crf:
-            backpointers, best_tag_id = logits
-            best_path = postprocess(backpointers, best_tag_id)
-            logit_id = []
-            for ele in best_path:
-                logit_id.extend(ele)
+        if self.mode.lower() == "binary":
+            pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)])
+            pos_label = np.isin(labels, [i for i in range(1, self.num_labels)])
+            self.TP += np.sum(pos_eva&pos_label)
+            self.FP += np.sum(pos_eva&(~pos_label))
+            self.FN += np.sum((~pos_eva)&pos_label)
         else:
-            logits = logits.asnumpy()
-            logit_id = np.argmax(logits, axis=-1)
-            logit_id = np.reshape(logit_id, -1)
-
-        label_spans = self.tag2span(labels)
-        pred_spans = self.tag2span(logit_id)
-        for span in pred_spans:
-            if span in label_spans:
-                self.TP += 1
-                label_spans.remove(span)
-            else:
-                self.FP += 1
-        for span in label_spans:
-            self.FN += 1
+            target = np.zeros((len(labels), self.num_labels), dtype=np.int)
+            pred = np.zeros((len(logit_id), self.num_labels), dtype=np.int)
+            for i, label in enumerate(labels):
+                target[i][label] = 1
+            for i, label in enumerate(logit_id):
+                pred[i][label] = 1
+            self.metric.update(pred, target)
+
+    def eval(self):
+        return self.metric.eval()


 class MCC():
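
For intuition, the new multi-label branch scatters the flattened label and prediction ids into one-hot matrices and hands them to ConfusionMatrixMetric, whose decrease="mean" setting averages the per-class F1 scores. The numpy-only sketch below mirrors that computation for illustration; it is not the ConfusionMatrixMetric implementation itself.

```python
import numpy as np

def macro_f1(pred_ids, label_ids, num_labels):
    """Macro-averaged F1 over one-hot encoded token labels."""
    # scatter ids into one-hot matrices, as the new else-branch does
    pred = np.zeros((len(pred_ids), num_labels), dtype=bool)
    target = np.zeros((len(label_ids), num_labels), dtype=bool)
    pred[np.arange(len(pred_ids)), pred_ids] = True
    target[np.arange(len(label_ids)), label_ids] = True
    # per-class confusion counts over the token axis
    tp = (pred & target).sum(axis=0)
    fp = (pred & ~target).sum(axis=0)
    fn = (~pred & target).sum(axis=0)
    f1 = 2 * tp / np.maximum(2 * tp + fp + fn, 1)  # guard classes that never occur
    return f1.mean()  # decrease="mean": average over classes

print(macro_f1(np.array([1, 2, 0, 2]), np.array([1, 0, 0, 2]), num_labels=3))  # ~0.777778
```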


model_zoo/official/nlp/bert/src/dataset.py (+2 -2)

@@ -53,7 +53,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,


 def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", data_file_path=None,
-                       dataset_format="mindrecord", schema_file_path=None, do_shuffle=True):
+                       dataset_format="mindrecord", schema_file_path=None, do_shuffle=True, drop_remainder=True):
     """create finetune or evaluation dataset"""
     type_cast_op = C.TypeCast(mstype.int32)
     if dataset_format == "mindrecord":
@@ -74,7 +74,7 @@ def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy
     dataset = dataset.map(operations=type_cast_op, input_columns="input_ids")
     dataset = dataset.repeat(repeat_count)
     # apply batch operations
-    dataset = dataset.batch(batch_size, drop_remainder=True)
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
     return dataset
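
The drop_remainder plumbing matters because a batched dataset with drop_remainder=True silently discards the final partial batch, which skews evaluation. A quick arithmetic sketch (the sample count is illustrative):

```python
num_samples, batch_size = 4365, 16        # hypothetical eval split
full_batches = num_samples // batch_size  # 272 full batches
print(full_batches * batch_size)          # 4352: 13 samples never scored if dropped
print(full_batches + 1)                   # 273 batches when the tail batch is kept
```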






model_zoo/official/nlp/bert/src/finetune_data_preprocess.py (+59 -37)

@@ -20,6 +20,7 @@ sample script of processing CLUE classification dataset using mindspore.dataset.
 import os
 import argparse
 import numpy as np
+from lxml import etree
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
@@ -139,46 +140,60 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage
     return dataset


-def process_cluener_msra(data_file):
-    """process MSRA dataset for CLUE"""
-    content = []
-    labels = []
-    for line in open(data_file):
-        line = line.strip()
-        if line:
-            word = line.split("\t")[0]
-            if len(line.split("\t")) == 1:
-                label = "O"
-            else:
-                label = line.split("\t")[1].split("\n")[0]
-            if label[0] != "O":
-                label = label[0] + "_" + label[2:]
-            if label[0] == "I":
-                label = "M" + label[1:]
-            content.append(word)
-            labels.append(label)
-        else:
-            for i in range(1, len(labels) - 1):
-                if labels[i][0] == "B" and labels[i+1][0] != "M":
-                    labels[i] = "S" + labels[i][1:]
-                elif labels[i][0] == "M" and labels[i+1][0] != labels[i][0]:
-                    labels[i] = "E" + labels[i][1:]
-            last = len(labels) - 1
-            if labels[last][0] == "B":
-                labels[last] = "S" + labels[last][1:]
-            elif labels[last][0] == "M":
-                labels[last] = "E" + labels[last][1:]
-            yield (np.array("".join(content)), np.array(list(labels)))
-            content.clear()
-            labels.clear()
-            continue
+def process_cluener_msra(data_file, class_filter=None, split_begin=None, split_end=None):
+    """
+    Data pre-process for the MSRA dataset.
+    Args:
+        data_file (path): The original dataset file path.
+        class_filter (list of str): Only tags within class_filter will be counted unless the list is None.
+        split_begin (float): Only data after the split_begin fraction will be counted. Used to split the dataset
+            into training and evaluation subsets if needed.
+        split_end (float): Only data before the split_end fraction will be counted. Used to split the dataset
+            into training and evaluation subsets if needed.
+    """
+    tree = etree.parse(data_file)
+    root = tree.getroot()
+    print("original dataset length: ", len(root))
+    dataset_size = len(root)
+    beg = 0 if split_begin is None or not 0 <= split_begin <= 1.0 else int(dataset_size * split_begin)
+    end = dataset_size if split_end is None or not 0 <= split_end <= 1.0 else int(dataset_size * split_end)
+    print("preprocessed dataset_size: ", end - beg)
+    for i in range(beg, end):
+        sentence = root[i]
+        tags = []
+        content = ""
+        for phrases in sentence:
+            labeled_words = [word for word in phrases]
+            if labeled_words:
+                for words in phrases:
+                    name = words.tag
+                    label = words.get("TYPE")
+                    words = words.text
+                    if not words:
+                        continue
+                    content += words
+                    if class_filter and name not in class_filter:
+                        tags += ["O" for _ in words]
+                    else:
+                        length = len(words)
+                        labels = ["S_"] if length == 1 else ["B_"] + ["M_" for i in range(length - 2)] + ["E_"]
+                        tags += [ele + label for ele in labels]
+            else:
+                phrases = phrases.text
+                if phrases:
+                    content += phrases
+                    tags += ["O" for ele in phrases]
+        if len(content) != len(tags):
+            raise ValueError("Mismatched length of content: ", len(content), " and label: ", len(tags))
+        yield (np.array("".join(content)), np.array(list(tags)))


-def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128):
+def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128, class_filter=None,
+                              split_begin=None, split_end=None):
     """Process MSRA dataset"""
     ### Loading MSRA from CLUEDataset
-    dataset = ds.GeneratorDataset(process_cluener_msra(data_dir), column_names=['text', 'label'])
+    dataset = ds.GeneratorDataset(process_cluener_msra(data_dir, class_filter, split_begin, split_end),
+                                  column_names=['text', 'label'])

     ### Processing label
     label_vocab = text.Vocab.from_list(label_list)
@@ -213,10 +228,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="create mindrecord")
     parser.add_argument("--data_dir", type=str, default="", help="dataset path")
     parser.add_argument("--vocab_file", type=str, default="", help="Vocab file path")
-    parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length")
     parser.add_argument("--save_path", type=str, default="./my.mindrecord", help="Path to save mindrecord")
     parser.add_argument("--label2id", type=str, default="",
                         help="Label2id file path, must be set for cluener2020 task")
+    parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length")
+    parser.add_argument("--class_filter", nargs='*', help="Specified classes will be counted; if empty, all are counted")
+    parser.add_argument("--split_begin", type=float, default=None, help="Specified subset of data will be counted; "
+                        "if not None, data will be counted starting from split_begin")
+    parser.add_argument("--split_end", type=float, default=None, help="Specified subset of data will be counted; "
+                        "if not None, data will be counted up to split_end")
     args_opt = parser.parse_args()
     if args_opt.label2id == "":
         raise ValueError("label2id should not be empty")
@@ -225,5 +246,6 @@ if __name__ == "__main__":
         for tag in f:
             labels_list.append(tag.strip())
     tag_to_index = list(convert_labels_to_index(labels_list).keys())
-    ds = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file, args_opt.max_seq_len)
+    ds = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file, args_opt.max_seq_len,
+                                   args_opt.class_filter, args_opt.split_begin, args_opt.split_end)
     ds.save(args_opt.save_path)
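
The rewritten generator walks the MSRA SGML tree and emits character-level BMES tags: a single-character phrase becomes "S_<TYPE>", longer phrases become "B_/M_/E_<TYPE>", and unlabeled text becomes "O". The sketch below applies the same tagging rule to a toy document; the XML only mimics the MSRA layout, and the helper bmes() is introduced here for illustration.

```python
from lxml import etree

def bmes(word, label):
    """Tag one labeled phrase character by character, following the diff's rule."""
    if len(word) == 1:
        return ["S_" + label]
    return ["B_" + label] + ["M_" + label] * (len(word) - 2) + ["E_" + label]

sample = '<SENTENCE><NAMEX TYPE="PERSON">Tom</NAMEX><w> lives.</w></SENTENCE>'
sentence = etree.fromstring(sample)
content, tags = "", []
for node in sentence:
    txt = node.text or ""
    content += txt
    if node.tag == "NAMEX":        # a labeled phrase
        tags += bmes(txt, node.get("TYPE"))
    else:                          # plain text is tagged "O" per character
        tags += ["O"] * len(txt)

print(list(zip(content, tags)))
# [('T', 'B_PERSON'), ('o', 'M_PERSON'), ('m', 'E_PERSON'), (' ', 'O'), ...]
```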

model_zoo/official/nlp/bert/src/utils.py (+3 -2)

@@ -106,10 +106,11 @@ class LossCallBack(Callback):
                 percent = 1
                 epoch_num -= 1
             print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
-                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
+                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
+                  flush=True)
         else:
             print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
-                                                               str(cb_params.net_outputs)))
+                                                               str(cb_params.net_outputs)), flush=True)


 def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
     """

