|
|
|
@@ -18,9 +18,11 @@ Bert evaluation script. |
|
|
|
""" |
|
|
|
|
|
|
|
import os |
|
|
|
import argparse |
|
|
|
import numpy as np |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore import context |
|
|
|
from mindspore import log as logger |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
import mindspore.dataset as de |
|
|
|
import mindspore.dataset.transforms.c_transforms as C |
|
|
|
@@ -105,8 +107,17 @@ def bert_predict(Evaluation): |
|
|
|
''' |
|
|
|
prediction function |
|
|
|
''' |
|
|
|
devid = int(os.getenv('DEVICE_ID')) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) |
|
|
|
target = args_opt.device_target |
|
|
|
if target == "Ascend": |
|
|
|
devid = int(os.getenv('DEVICE_ID')) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) |
|
|
|
elif target == "GPU": |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") |
|
|
|
if bert_net_cfg.compute_type != mstype.float32: |
|
|
|
logger.warning('GPU only support fp32 temporarily, run with fp32.') |
|
|
|
bert_net_cfg.compute_type = mstype.float32 |
|
|
|
else: |
|
|
|
raise Exception("Target error, GPU or Ascend is supported.") |
|
|
|
dataset = get_dataset(bert_net_cfg.batch_size, 1) |
|
|
|
if cfg.use_crf: |
|
|
|
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True, |
|
|
|
@@ -147,6 +158,9 @@ def test_eval(): |
|
|
|
callback.acc_num / callback.total_num)) |
|
|
|
print("==============================================================") |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Bert eval') |
|
|
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') |
|
|
|
args_opt = parser.parse_args() |
|
|
|
if __name__ == "__main__": |
|
|
|
num_labels = cfg.num_labels |
|
|
|
test_eval() |