
add bert support for glue

tags/v0.6.0-beta
yoonlee666 committed 5 years ago
commit 7a8ee4725b

4 changed files with 201 additions and 40 deletions
  1. model_zoo/bert/README.md        +2    -2
  2. model_zoo/bert/evaluation.py    +123  -17
  3. model_zoo/bert/finetune.py      +22   -14
  4. model_zoo/bert/src/utils.py     +54   -7

model_zoo/bert/README.md  (+2 -2)

@@ -89,7 +89,7 @@ config.py:
     optimizer                   optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"

 finetune_config.py:
-    task                        task type: NER | SQUAD | OTHERS
+    task                        task type: SeqLabeling | Regression | Classification | COLA | SQUAD
     num_labels                  number of labels to do classification
     data_file                   dataset file to load: PATH, default is "/your/path/train.tfrecord"
     schema_file                 dataset schema file to load: PATH, default is "/your/path/schema.json"
@@ -101,7 +101,7 @@ finetune_config.py:
     optimizer                   optimizer used in fine-tune network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb"

 evaluation_config.py:
-    task                        task type: NER | SQUAD | OTHERS
+    task                        task type: SeqLabeling | Regression | Classification | COLA
     num_labels                  number of labels to do classsification
     data_file                   dataset file to load: PATH, default is "/your/path/evaluation.tfrecord"
     schema_file                 dataset schema file to load: PATH, default is "/your/path/schema.json"

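A side note for wiring up the new GLUE tasks: the task values documented above are consumed by finetune.py and evaluation.py in this commit. Below is a minimal sketch of what the relevant fields could look like for an STS-B style regression run; the field names follow the README, but the concrete values and the plain SimpleNamespace container are illustrative assumptions rather than the repo's actual config objects.

# Hypothetical illustration only: field names follow the README above, values
# and the SimpleNamespace container are assumptions, not the real config files.
from types import SimpleNamespace

cfg = SimpleNamespace(
    task="Regression",                        # SeqLabeling | Regression | Classification | COLA | SQUAD
    num_labels=1,                             # regression predicts a single score
    data_file="/your/path/train.tfrecord",    # TFRecord produced from the GLUE data
    schema_file="/your/path/schema.json",
    optimizer="Lamb",                         # Lamb | Momentum | AdamWeightDecayDynamicLR
)

print(cfg.task, cfg.num_labels)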

model_zoo/bert/evaluation.py  (+123 -17)

@@ -19,6 +19,7 @@ Bert evaluation script.

 import os
 import argparse
+import math
 import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore import context
@@ -29,19 +30,24 @@ import mindspore.dataset.transforms.c_transforms as C
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.evaluation_config import cfg, bert_net_cfg
-from src.utils import BertNER, BertCLS
+from src.utils import BertNER, BertCLS, BertReg
 from src.CRF import postprocess
 from src.cluener_evaluation import submit
 from src.finetune_config import tag_to_index


 class Accuracy():
-    '''
+    """
     calculate accuracy
-    '''
+    """
     def __init__(self):
         self.acc_num = 0
         self.total_num = 0

     def update(self, logits, labels):
+        """
+        Update accuracy
+        """
         labels = labels.asnumpy()
         labels = np.reshape(labels, -1)
         logits = logits.asnumpy()
@@ -50,18 +56,20 @@ class Accuracy():
         self.total_num += len(labels)
         print("=========================accuracy is ", self.acc_num / self.total_num)


 class F1():
-    '''
+    """
     calculate F1 score
-    '''
+    """
     def __init__(self):
         self.TP = 0
         self.FP = 0
         self.FN = 0

     def update(self, logits, labels):
-        '''
+        """
         update F1 score
-        '''
+        """
         labels = labels.asnumpy()
         labels = np.reshape(labels, -1)
         if cfg.use_crf:
@@ -80,10 +88,76 @@ class F1():
         self.FP += np.sum(pos_eva&(~pos_label))
         self.FN += np.sum((~pos_eva)&pos_label)


+class MCC():
+    """
+    Calculate Matthews Correlation Coefficient.
+    """
+    def __init__(self):
+        self.TP = 0
+        self.FP = 0
+        self.FN = 0
+        self.TN = 0
+
+    def update(self, logits, labels):
+        """
+        Update MCC score
+        """
+        labels = labels.asnumpy()
+        labels = np.reshape(labels, -1)
+        labels = labels.astype(np.bool)
+        logits = logits.asnumpy()
+        logit_id = np.argmax(logits, axis=-1)
+        logit_id = np.reshape(logit_id, -1)
+        logit_id = logit_id.astype(np.bool)
+        ornot = logit_id ^ labels
+
+        self.TP += (~ornot & labels).sum()
+        self.FP += (ornot & ~labels).sum()
+        self.FN += (ornot & labels).sum()
+        self.TN += (~ornot & ~labels).sum()
+
+
+class Spearman_Correlation():
+    """
+    calculate Spearman Correlation coefficient
+    """
+    def __init__(self):
+        self.label = []
+        self.logit = []
+
+    def update(self, logits, labels):
+        """
+        Update Spearman Correlation
+        """
+        labels = labels.asnumpy()
+        labels = np.reshape(labels, -1)
+        logits = logits.asnumpy()
+        logits = np.reshape(logits, -1)
+        self.label.append(labels)
+        self.logit.append(logits)
+
+    def cal(self):
+        """
+        Calculate Spearman Correlation
+        """
+        label = np.concatenate(self.label)
+        logit = np.concatenate(self.logit)
+        sort_label = label.argsort()[::-1]
+        sort_logit = logit.argsort()[::-1]
+        n = len(label)
+        d_acc = 0
+        for i in range(n):
+            d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0]
+            d_acc += d**2
+        ps = 1 - 6*d_acc/n/(n**2-1)
+        return ps
+
+
 def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
-    '''
+    """
     get dataset
-    '''
+    """
     _ = distribute_file

     ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
@@ -92,7 +166,11 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
     ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
     ds = ds.map(input_columns="input_mask", operations=type_cast_op)
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
-    ds = ds.map(input_columns="label_ids", operations=type_cast_op)
+    if cfg.task == "Regression":
+        type_cast_op_float = C.TypeCast(mstype.float32)
+        ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
+    else:
+        ds = ds.map(input_columns="label_ids", operations=type_cast_op)
     ds = ds.repeat(repeat_count)

     # apply shuffle operation
@@ -103,10 +181,11 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
     ds = ds.batch(batch_size, drop_remainder=True)
     return ds


 def bert_predict(Evaluation):
-    '''
+    """
     prediction function
-    '''
+    """
     target = args_opt.device_target
     if target == "Ascend":
         devid = int(os.getenv('DEVICE_ID'))
@@ -131,15 +210,33 @@
     return model, dataset


 def test_eval():
-    '''
+    """
     evaluation function
-    '''
-    task_type = BertNER if cfg.task == "NER" else BertCLS
+    """
+    if cfg.task == "SeqLabeling":
+        task_type = BertNER
+    elif cfg.task == "Regression":
+        task_type = BertReg
+    elif cfg.task == "Classification":
+        task_type = BertCLS
+    elif cfg.task == "COLA":
+        task_type = BertCLS
+    else:
+        raise ValueError("Task not supported.")
     model, dataset = bert_predict(task_type)

     if cfg.clue_benchmark:
         submit(model, cfg.data_file, bert_net_cfg.seq_length)
     else:
-        callback = F1() if cfg.task == "NER" else Accuracy()
+        if cfg.task == "SeqLabeling":
+            callback = F1()
+        elif cfg.task == "COLA":
+            callback = MCC()
+        elif cfg.task == "Regression":
+            callback = Spearman_Correlation()
+        else:
+            callback = Accuracy()
+
         columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
         for data in dataset.create_dict_iterator():
             input_data = []
@@ -149,10 +246,19 @@ def test_eval():
             logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
             callback.update(logits, label_ids)
         print("==============================================================")
-        if cfg.task == "NER":
+        if cfg.task == "SeqLabeling":
             print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
             print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
             print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FN)))
+        elif cfg.task == "COLA":
+            TP = callback.TP
+            TN = callback.TN
+            FP = callback.FP
+            FN = callback.FN
+            mcc = (TP*TN-FP*FN)/math.sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN))
+            print("MCC: {:.6f}".format(mcc))
+        elif cfg.task == "Regression":
+            print("Spearman Correlation is {:.6f}".format(callback.cal()[0]))
         else:
             print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                       callback.acc_num / callback.total_num))

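As an aside on the metrics added above: Spearman_Correlation.cal() implements the rank-difference formula rho = 1 - 6 * sum(d^2) / (n * (n^2 - 1)), which is exact only when there are no tied scores. The following self-contained sketch reproduces the same computation in plain NumPy and cross-checks it against scipy.stats.spearmanr; scipy is used only for the sanity check here and is not a dependency of the script above.

import numpy as np
from scipy.stats import spearmanr  # only to sanity-check the formula below

def spearman_by_rank_difference(labels, logits):
    """Rank-difference Spearman correlation, mirroring Spearman_Correlation.cal().
    Assumes no tied values; with ties this formula and the covariance-based
    definition that scipy computes can differ slightly."""
    labels = np.asarray(labels, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    n = len(labels)
    # argsort()[::-1] lists element indices from largest to smallest value, so
    # the position of index i in that list is the (0-based) rank of element i.
    sort_label = labels.argsort()[::-1]
    sort_logit = logits.argsort()[::-1]
    d_acc = 0.0
    for i in range(n):
        d = np.where(sort_label == i)[0][0] - np.where(sort_logit == i)[0][0]
        d_acc += d ** 2
    return 1 - 6 * d_acc / (n * (n ** 2 - 1))

rng = np.random.default_rng(0)
labels = rng.normal(size=64)                 # e.g. STS-B style similarity scores
logits = labels + 0.3 * rng.normal(size=64)  # noisy model predictions
rho, _ = spearmanr(labels, logits)
print("rank-difference formula:", spearman_by_rank_difference(labels, logits))
print("scipy.stats.spearmanr  :", rho)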

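Likewise, MCC.update() accumulates TP/FP/FN/TN counters and test_eval() reduces them with mcc = (TP*TN - FP*FN) / sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)). Below is a standalone sketch of that reduction on a toy batch, compared against sklearn.metrics.matthews_corrcoef; sklearn is assumed to be available and is used only for the comparison.

import math
import numpy as np
from sklearn.metrics import matthews_corrcoef  # comparison only; assumed installed

def mcc_from_predictions(labels, preds):
    """Accumulate TP/FP/FN/TN the way MCC.update() does, then apply the
    closed-form reduction used in test_eval() for the COLA task."""
    labels = np.asarray(labels, dtype=bool)
    preds = np.asarray(preds, dtype=bool)
    ornot = preds ^ labels                    # True where prediction and label disagree
    tp = int((~ornot & labels).sum())
    fp = int((ornot & ~labels).sum())
    fn = int((ornot & labels).sum())
    tn = int((~ornot & ~labels).sum())
    # Note: the formula is undefined if any factor in the denominator is zero.
    return (tp * tn - fp * fn) / math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

labels = [1, 1, 1, 0, 0, 0, 1, 0]
preds = [1, 0, 1, 0, 0, 1, 1, 0]
print(mcc_from_predictions(labels, preds))   # 0.5 for this toy batch
print(matthews_corrcoef(labels, preds))      # should agree
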
model_zoo/bert/finetune.py  (+22 -14)

@@ -13,13 +13,13 @@
 # limitations under the License.
 # ============================================================================

-'''
+"""
 Bert finetune script.
-'''
+"""

 import os
 import argparse
-from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell
+from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell, BertReg
 from src.finetune_config import cfg, bert_net_cfg, tag_to_index
 import mindspore.common.dtype as mstype
 from mindspore import context
@@ -34,14 +34,14 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.train.serialization import load_checkpoint, load_param_into_net


 class LossCallBack(Callback):
-    '''
+    """
     Monitor the loss in training.
     If the loss is NAN or INF, terminate training.
     Note:
         If per_print_times is 0, do not print loss.
     Args:
         per_print_times (int): Print loss every times. Default: 1.
-    '''
+    """
     def __init__(self, per_print_times=1):
         super(LossCallBack, self).__init__()
         if not isinstance(per_print_times, int) or per_print_times < 0:
@@ -56,16 +56,20 @@ class LossCallBack(Callback):
             f.write("\n")


 def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
-    '''
+    """
     get dataset
-    '''
+    """
     ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
                                                                             "segment_ids", "label_ids"])
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
     ds = ds.map(input_columns="input_mask", operations=type_cast_op)
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
-    ds = ds.map(input_columns="label_ids", operations=type_cast_op)
+    if cfg.task == "Regression":
+        type_cast_op_float = C.TypeCast(mstype.float32)
+        ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
+    else:
+        ds = ds.map(input_columns="label_ids", operations=type_cast_op)
     ds = ds.repeat(repeat_count)

     # apply shuffle operation
@@ -77,9 +81,9 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
     return ds


 def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
-    '''
+    """
     get SQuAD dataset
-    '''
+    """
     ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
                                                                             "start_positions", "end_positions",
                                                                             "unique_ids", "is_impossible"])
@@ -97,9 +101,9 @@ def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
     return ds


 def test_train():
-    '''
+    """
     finetune function
-    '''
+    """
     target = args_opt.device_target
     if target == "Ascend":
         devid = int(os.getenv('DEVICE_ID'))
@@ -113,7 +117,7 @@ def test_train():
         raise Exception("Target error, GPU or Ascend is supported.")
     #BertCLSTrain for classification
     #BertNERTrain for sequence labeling
-    if cfg.task == 'NER':
+    if cfg.task == 'SeqLabeling':
         if cfg.use_crf:
             netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True,
                                   tag_to_index=tag_to_index, dropout_prob=0.1)
@@ -121,8 +125,12 @@ def test_train():
         netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
     elif cfg.task == 'SQUAD':
         netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
-    else:
+    elif cfg.task == 'Regression':
+        netwithloss = BertReg(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
+    elif cfg.task == 'Classification':
         netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
+    else:
+        raise Exception("Target error, GPU or Ascend is supported.")
     if cfg.task == 'SQUAD':
         dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
     else:

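For orientation, the new task values line up with the GLUE benchmark roughly as follows: CoLA is scored with Matthews correlation, STS-B is a regression task scored with Spearman correlation, and the remaining single-sentence and sentence-pair tasks are plain classification scored here with accuracy. The mapping below is inferred from the code in this commit plus standard GLUE conventions; it is not a table shipped with the repo.

# Inferred mapping only: which GLUE dataset corresponds to which cfg.task value
# and which metric test_eval() would report for it. Not shipped with the repo.
GLUE_TASK_HINTS = {
    "CoLA":  {"task": "COLA",           "metric": "Matthews correlation (MCC)"},
    "STS-B": {"task": "Regression",     "metric": "Spearman correlation (Spearman_Correlation)"},
    "SST-2": {"task": "Classification", "metric": "accuracy (Accuracy)"},
    "MRPC":  {"task": "Classification", "metric": "accuracy (Accuracy)"},
    "QNLI":  {"task": "Classification", "metric": "accuracy (Accuracy)"},
    "RTE":   {"task": "Classification", "metric": "accuracy (Accuracy)"},
    "MNLI":  {"task": "Classification", "metric": "accuracy (Accuracy)"},
}

for name, hint in GLUE_TASK_HINTS.items():
    print(name, "-> cfg.task =", hint["task"], "| reported metric:", hint["metric"])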

model_zoo/bert/src/utils.py  (+54 -7)

@@ -13,9 +13,9 @@
 # limitations under the License.
 # ============================================================================

-'''
+"""
 Functional Cells used in Bert finetune and evaluation.
-'''
+"""

 import mindspore.nn as nn
 from mindspore.common.initializer import TruncatedNormal
@@ -245,6 +245,32 @@ class BertSquadCell(nn.Cell):
         ret = (loss, cond)
         return F.depend(ret, succ)


+class BertRegressionModel(nn.Cell):
+    """
+    Bert finetune model for regression task
+    """
+    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
+        super(BertRegressionModel, self).__init__()
+        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
+        self.cast = P.Cast()
+        self.weight_init = TruncatedNormal(config.initializer_range)
+        self.log_softmax = P.LogSoftmax(axis=-1)
+        self.dtype = config.dtype
+        self.num_labels = num_labels
+        self.dropout = nn.Dropout(1 - dropout_prob)
+        self.dense_1 = nn.Dense(config.hidden_size, 1, weight_init=self.weight_init,
+                                has_bias=True).to_float(mstype.float16)
+
+    def construct(self, input_ids, input_mask, token_type_id):
+        _, pooled_output, _ = self.bert(input_ids, token_type_id, input_mask)
+        cls = self.cast(pooled_output, self.dtype)
+        cls = self.dropout(cls)
+        logits = self.dense_1(cls)
+        logits = self.cast(logits, self.dtype)
+        return logits
+
+
 class BertCLSModel(nn.Cell):
     """
     This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
@@ -274,9 +300,9 @@ class BertCLSModel(nn.Cell):
         return log_probs


 class BertSquadModel(nn.Cell):
-    '''
-    This class is responsible for SQuAD
-    '''
+    """
+    Bert finetune model for SQuAD v1.1 task
+    """
     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
         super(BertSquadModel, self).__init__()
         self.bert = BertModel(config, is_training, use_one_hot_embeddings)
@@ -401,9 +427,9 @@ class BertNER(nn.Cell):
         return loss


 class BertSquad(nn.Cell):
-    '''
+    """
     Train interface for SQuAD finetuning task.
-    '''
+    """
     def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
         super(BertSquad, self).__init__()
         self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
@@ -432,3 +458,24 @@ class BertSquad(nn.Cell):
         end_logits = self.squeeze(logits[:, :, 1:2])
         total_loss = (unique_id, start_logits, end_logits)
         return total_loss
+
+
+class BertReg(nn.Cell):
+    """
+    Bert finetune model with loss for regression task
+    """
+    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
+        super(BertReg, self).__init__()
+        self.bert = BertRegressionModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
+        self.loss = nn.MSELoss()
+        self.is_training = is_training
+        self.sigmoid = P.Sigmoid()
+        self.cast = P.Cast()
+        self.mul = P.Mul()
+    def construct(self, input_ids, input_mask, token_type_id, labels):
+        logits = self.bert(input_ids, input_mask, token_type_id)
+        if self.is_training:
+            loss = self.loss(logits, labels)
+        else:
+            loss = logits
+        return loss

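One detail worth noting in BertRegressionModel and BertReg above: dense_1 projects the pooled [CLS] vector from hidden_size down to a single unit, so the logits fed to nn.MSELoss have shape (batch_size, 1), and the float32 label_ids produced by the Regression branch of get_dataset are compared directly against them. A tiny NumPy sketch of that shape-and-loss contract follows; the random weights stand in for the learned dense layer and BERT encoder, and nothing in it is MindSpore API.

import numpy as np

batch_size, hidden_size = 8, 768                      # hidden_size is 768 for BERT-base
pooled_output = np.random.randn(batch_size, hidden_size).astype(np.float32)

# stand-in for dense_1: hidden_size -> 1, matching nn.Dense(config.hidden_size, 1)
weight = (np.random.randn(hidden_size, 1) * 0.02).astype(np.float32)
bias = np.zeros((1,), dtype=np.float32)
logits = pooled_output @ weight + bias                # shape (batch_size, 1)

labels = np.random.uniform(0.0, 5.0, size=(batch_size, 1)).astype(np.float32)  # e.g. STS-B scores in [0, 5]
mse = np.mean((logits - labels) ** 2)                 # mean-squared error over the batch
print(logits.shape, labels.shape, float(mse))

This is presumably why both get_dataset functions gained the TypeCast(mstype.float32) branch: integer label_ids are fine for the classification and sequence-labeling losses, but the regression labels are kept as floats so they match the regression logits.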