# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Bert evaluation script.
"""

import os
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.evaluation_config import cfg, bert_net_cfg
from src.utils import BertNER, BertCLS
from src.CRF import postprocess
from src.cluener_evaluation import submit
from src.finetune_config import tag_to_index


def _safe_div(numerator, denominator):
    '''
    Return numerator / denominator, or 0.0 when the denominator is zero.

    Used for every reported ratio so that an empty dataset, or a run without
    any positive prediction/label, reports 0 instead of raising
    ZeroDivisionError or printing nan.
    '''
    return numerator / denominator if denominator else 0.0


class Accuracy():
    '''
    calculate accuracy

    Accumulates, over successive `update` calls, the number of labels that
    match the arg-max prediction (`acc_num`) and the number of labels seen
    (`total_num`).
    '''
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, labels):
        '''
        update accuracy counters with one batch and print the running accuracy

        Args:
            logits: object exposing `asnumpy()`; the class index is taken as
                the arg-max over its last axis.
            labels: object exposing `asnumpy()`; flattened before comparison.
        '''
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)
        print("=========================accuracy is ", _safe_div(self.acc_num, self.total_num))


class F1():
    '''
    calculate F1 score

    Accumulates true-positive, false-positive and false-negative counts over
    successive `update` calls. An id counts as "positive" when it lies in
    [1, cfg.num_labels); id 0 is treated as the negative class.
    '''
    def __init__(self):
        self.TP = 0
        self.FP = 0
        self.FN = 0

    def update(self, logits, labels):
        '''
        update F1 score

        Args:
            logits: when `cfg.use_crf` is set, a `(backpointers, best_tag_id)`
                pair that is decoded with `postprocess`; otherwise an object
                exposing `asnumpy()` whose last axis is arg-maxed.
            labels: object exposing `asnumpy()`; flattened before comparison.
        '''
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        if cfg.use_crf:
            backpointers, best_tag_id = logits
            best_path = postprocess(backpointers, best_tag_id)
            # Flatten the per-sequence paths into a single list of tag ids.
            logit_id = []
            for ele in best_path:
                logit_id.extend(ele)
        else:
            logits = logits.asnumpy()
            logit_id = np.argmax(logits, axis=-1)
            logit_id = np.reshape(logit_id, -1)
        # Build the set of positive ids once and reuse it for both masks.
        positive_ids = np.arange(1, cfg.num_labels)
        pos_eva = np.isin(logit_id, positive_ids)
        pos_label = np.isin(labels, positive_ids)
        self.TP += np.sum(pos_eva & pos_label)
        self.FP += np.sum(pos_eva & (~pos_label))
        self.FN += np.sum((~pos_eva) & pos_label)


def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
    '''
    get dataset

    Reads `cfg.data_file` (with `cfg.schema_file`) as a TFRecord dataset,
    casts the four input columns to int32, then repeats, shuffles and batches.

    Args:
        batch_size: number of rows per batch; incomplete batches are dropped.
        repeat_count: number of times the dataset is repeated.
        distribute_file: accepted for interface compatibility; unused.

    Returns:
        The batched dataset object.
    '''
    _ = distribute_file

    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
                                                                            "segment_ids", "label_ids"])
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    ds = ds.map(input_columns="label_ids", operations=type_cast_op)
    ds = ds.repeat(repeat_count)

    # apply shuffle operation
    # NOTE(review): shuffling combined with drop_remainder=True presumably
    # means the dropped tail can differ between runs -- confirm this is wanted
    # for evaluation.
    buffer_size = 960
    ds = ds.shuffle(buffer_size=buffer_size)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds


def bert_predict(Evaluation):
    '''
    prediction function

    Configures the MindSpore context for `args_opt.device_target`, builds the
    evaluation dataset and network, and loads `cfg.finetune_ckpt` into it.

    Args:
        Evaluation: network class to instantiate (called with `bert_net_cfg`
            and the label configuration).

    Returns:
        A `(model, dataset)` tuple.

    Raises:
        Exception: if the device target is neither "Ascend" nor "GPU".
    '''
    target = args_opt.device_target
    if target == "Ascend":
        # Fall back to device 0 when DEVICE_ID is not exported instead of
        # failing with an opaque `int(None)` TypeError.
        devid = int(os.getenv('DEVICE_ID', '0'))
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only support fp32 temporarily, run with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        raise Exception("Target error, GPU or Ascend is supported.")
    dataset = get_dataset(bert_net_cfg.batch_size, 1)
    if cfg.use_crf:
        net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
                                         tag_to_index=tag_to_index, dropout_prob=0.0)
    else:
        # Read the label count from the config rather than from a module
        # global that only exists when this file is run as a script.
        net_for_pretraining = Evaluation(bert_net_cfg, False, cfg.num_labels)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(cfg.finetune_ckpt)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)
    return model, dataset


def test_eval():
    '''
    evaluation function

    Runs prediction over the evaluation dataset and prints either
    precision/recall/F1 (NER task) or accuracy (any other task). When
    `cfg.clue_benchmark` is set, delegates to `submit` instead.
    '''
    task_type = BertNER if cfg.task == "NER" else BertCLS
    model, dataset = bert_predict(task_type)
    if cfg.clue_benchmark:
        submit(model, cfg.data_file, bert_net_cfg.seq_length)
    else:
        callback = F1() if cfg.task == "NER" else Accuracy()
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in dataset.create_dict_iterator():
            input_data = [Tensor(data[column]) for column in columns_list]
            input_ids, input_mask, token_type_id, label_ids = input_data
            logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
            callback.update(logits, label_ids)
        print("==============================================================")
        if cfg.task == "NER":
            # Guarded divisions: report 0 when there is nothing to divide by.
            print("Precision {:.6f} ".format(_safe_div(callback.TP, callback.TP + callback.FP)))
            print("Recall {:.6f} ".format(_safe_div(callback.TP, callback.TP + callback.FN)))
            print("F1 {:.6f} ".format(_safe_div(2 * callback.TP,
                                                2 * callback.TP + callback.FP + callback.FN)))
        else:
            print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                      _safe_div(callback.acc_num,
                                                                                callback.total_num)))
        print("==============================================================")


# Parsed at import time; `bert_predict` reads `args_opt.device_target`.
parser = argparse.ArgumentParser(description='Bert eval')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
if __name__ == "__main__":
    num_labels = cfg.num_labels
    test_eval()