You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

cloud_eval.py 3.9 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import argparse
  16. import os
  17. import sys
  18. from time import time
  19. from mindspore import context
  20. from mindspore.train.serialization import load_checkpoint
  21. from src.config import eval_cfg, server_net_cfg
  22. from src.dataset import load_datasets
  23. from src.utils import restore_params
  24. from src.model import AlbertModelCLS
  25. from src.tokenization import CustomizedTextTokenizer
  26. from src.assessment_method import Accuracy
  27. def parse_args():
  28. """
  29. parse args
  30. """
  31. parser = argparse.ArgumentParser(description='server eval task')
  32. parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
  33. parser.add_argument('--device_id', type=str, default='0')
  34. parser.add_argument('--tokenizer_dir', type=str, default='../model_save/init/')
  35. parser.add_argument('--eval_data_dir', type=str, default='../datasets/eval/')
  36. parser.add_argument('--model_path', type=str, default='../model_save/train_server/0.ckpt')
  37. parser.add_argument('--vocab_map_ids_path', type=str, default='../model_save/init/vocab_map_ids.txt')
  38. return parser.parse_args()
  39. def server_eval(args):
  40. start = time()
  41. # some parameters
  42. os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
  43. tokenizer_dir = args.tokenizer_dir
  44. eval_data_dir = args.eval_data_dir
  45. model_path = args.model_path
  46. vocab_map_ids_path = args.vocab_map_ids_path
  47. # mindspore context
  48. context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
  49. print('Context setting is done! Time cost: {}'.format(time() - start))
  50. sys.stdout.flush()
  51. start = time()
  52. # data process
  53. tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
  54. datasets_list, _ = load_datasets(
  55. eval_data_dir, server_net_cfg.seq_length, tokenizer, eval_cfg.batch_size,
  56. label_list=None,
  57. do_shuffle=False,
  58. drop_remainder=False,
  59. output_dir=None)
  60. print('Data process is done! Time cost: {}'.format(time() - start))
  61. sys.stdout.flush()
  62. start = time()
  63. # main model
  64. albert_model_cls = AlbertModelCLS(server_net_cfg)
  65. albert_model_cls.set_train(False)
  66. param_dict = load_checkpoint(model_path)
  67. restore_params(albert_model_cls, param_dict)
  68. print('Model construction is done! Time cost: {}'.format(time() - start))
  69. sys.stdout.flush()
  70. start = time()
  71. # eval
  72. callback = Accuracy()
  73. global_step = 0
  74. for datasets in datasets_list:
  75. for batch in datasets.create_tuple_iterator():
  76. input_ids, attention_mask, token_type_ids, label_ids, _ = batch
  77. logits = albert_model_cls(input_ids, attention_mask, token_type_ids)
  78. callback.update(logits, label_ids)
  79. print('eval step: {}, {}: {}'.format(global_step, callback.name, callback.get_metrics()))
  80. sys.stdout.flush()
  81. global_step += 1
  82. metrics = callback.get_metrics()
  83. print('Final {}: {}'.format(callback.name, metrics))
  84. sys.stdout.flush()
  85. print('Evaluating process is done! Time cost: {}'.format(time() - start))
  86. sys.stdout.flush()
  87. if __name__ == '__main__':
  88. args_opt = parse_args()
  89. server_eval(args_opt)