You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

eval.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """eval standalone script"""
  16. import os
  17. import re
  18. import argparse
  19. from mindspore import context
  20. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  21. from src.dataset import create_dataset
  22. from src.config import eval_cfg, student_net_cfg, task_cfg
  23. from src.tinybert_model import BertModelCLS
# File name of the evaluation dataset expected inside each task's data directory.
DATA_NAME = 'eval.tf_record'
  25. def parse_args():
  26. """
  27. parse args
  28. """
  29. parser = argparse.ArgumentParser(description='ternarybert evaluation')
  30. parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
  31. help='Device where the code will be implemented. (Default: GPU)')
  32. parser.add_argument('--device_id', type=int, default=0, help='Device id. (Default: 0)')
  33. parser.add_argument('--model_dir', type=str, default='', help='The checkpoint directory of model.')
  34. parser.add_argument('--data_dir', type=str, default='', help='Data directory.')
  35. parser.add_argument('--task_name', type=str, default='sts-b', choices=['sts-b', 'qnli', 'mnli'],
  36. help='The name of the task to train. (Default: sts-b)')
  37. parser.add_argument('--dataset_type', type=str, default='tfrecord', choices=['tfrecord', 'mindrecord'],
  38. help='The name of the task to train. (Default: tfrecord)')
  39. parser.add_argument('--batch_size', type=int, default=32, help='Batch size for evaluating')
  40. return parser.parse_args()
  41. def get_ckpt(ckpt_file):
  42. lists = os.listdir(ckpt_file)
  43. lists.sort(key=lambda fn: os.path.getmtime(ckpt_file + '/' + fn))
  44. return os.path.join(ckpt_file, lists[-1])
  45. def do_eval_standalone(args_opt):
  46. """
  47. do eval standalone
  48. """
  49. ckpt_file = os.path.join(args_opt.model_dir, args_opt.task_name)
  50. ckpt_file = get_ckpt(ckpt_file)
  51. print('ckpt file:', ckpt_file)
  52. task = task_cfg[args_opt.task_name]
  53. student_net_cfg.seq_length = task.seq_length
  54. eval_cfg.batch_size = args_opt.batch_size
  55. eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, DATA_NAME)
  56. context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args.device_id)
  57. eval_dataset = create_dataset(batch_size=eval_cfg.batch_size,
  58. device_num=1,
  59. rank=0,
  60. do_shuffle='false',
  61. data_dir=eval_data_dir,
  62. data_type=args_opt.dataset_type,
  63. seq_length=task.seq_length,
  64. task_type=task.task_type,
  65. drop_remainder=False)
  66. print('eval dataset size:', eval_dataset.get_dataset_size())
  67. print('eval dataset batch size:', eval_dataset.get_batch_size())
  68. eval_model = BertModelCLS(student_net_cfg, False, task.num_labels, 0.0, phase_type='student')
  69. param_dict = load_checkpoint(ckpt_file)
  70. new_param_dict = {}
  71. for key, value in param_dict.items():
  72. new_key = re.sub('tinybert_', 'bert_', key)
  73. new_key = re.sub('^bert.', '', new_key)
  74. new_param_dict[new_key] = value
  75. load_param_into_net(eval_model, new_param_dict)
  76. eval_model.set_train(False)
  77. columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
  78. callback = task.metrics()
  79. for step, data in enumerate(eval_dataset.create_dict_iterator()):
  80. input_data = []
  81. for i in columns_list:
  82. input_data.append(data[i])
  83. input_ids, input_mask, token_type_id, label_ids = input_data
  84. _, _, logits, _ = eval_model(input_ids, token_type_id, input_mask)
  85. callback.update(logits, label_ids)
  86. print('eval step: {}, {}: {}'.format(step, callback.name, callback.get_metrics()))
  87. metrics = callback.get_metrics()
  88. print('The best {}: {}'.format(callback.name, metrics))
# Script entry point: parse CLI options and run standalone evaluation.
# NOTE(review): keep the module-level name `args` — do_eval_standalone
# currently references it directly (see device_id usage) — confirm before renaming.
if __name__ == '__main__':
    args = parse_args()
    do_eval_standalone(args)