
squadeval.py
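The MindSpore evaluation script for the SQuAD task: it reads the SQuAD v1.1 dev set, converts it to features, loads a fine-tuned BertSquad checkpoint, runs inference on an Ascend device, and writes the predicted answers to ./predictions.json.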

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation script for SQuAD task"""

import os
import collections
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src import tokenization
from src.evaluation_config import cfg, bert_net_cfg
from src.utils import BertSquad
from src.create_squad_data import read_squad_examples, convert_examples_to_features
from src.run_squad import write_predictions


def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
    """get SQuAD dataset from tfrecord"""
    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file,
                            columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"],
                            shuffle=False)
    # Cast all input columns to int32, the dtype the network expects.
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.repeat(repeat_count)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds


def test_eval():
    """Evaluation function for SQuAD task"""
    tokenizer = tokenization.FullTokenizer(vocab_file="./vocab.txt", do_lower_case=True)
    input_file = "dataset/v1.1/dev-v1.1.json"
    eval_examples = read_squad_examples(input_file, False)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        output_fn=None,
        verbose_logging=False)

    # DEVICE_ID must be set in the environment before running.
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
    dataset = get_squad_dataset(bert_net_cfg.batch_size, 1)
    net = BertSquad(bert_net_cfg, False, 2)
    net.set_train(False)
    param_dict = load_checkpoint(cfg.finetune_ckpt)
    load_param_into_net(net, param_dict)
    model = Model(net)

    output = []
    RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
    columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"]
    for data in dataset.create_dict_iterator():
        input_data = []
        for i in columns_list:
            input_data.append(Tensor(data[i]))
        input_ids, input_mask, segment_ids, unique_ids = input_data
        # Dummy label tensors: they are ignored at inference time but the
        # cell's construct signature still requires them.
        start_positions = Tensor([1], mstype.float32)
        end_positions = Tensor([1], mstype.float32)
        is_impossible = Tensor([1], mstype.float32)
        logits = model.predict(input_ids, input_mask, segment_ids, start_positions,
                               end_positions, unique_ids, is_impossible)
        ids = logits[0].asnumpy()
        start = logits[1].asnumpy()
        end = logits[2].asnumpy()
        for i in range(bert_net_cfg.batch_size):
            unique_id = int(ids[i])
            start_logits = [float(x) for x in start[i].flat]
            end_logits = [float(x) for x in end[i].flat]
            output.append(RawResult(
                unique_id=unique_id,
                start_logits=start_logits,
                end_logits=end_logits))

    # n_best_size=20, max_answer_length=30
    write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json",
                      None, None, False, False)


if __name__ == "__main__":
    test_eval()
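
For a quick sanity check after a run, here is a minimal sketch that inspects the generated ./predictions.json. It assumes the file follows the usual SQuAD prediction format, a flat JSON object mapping question IDs to answer strings, which is what the reference write_predictions emits; adjust if your version writes a different layout.

# Hypothetical helper, not part of squadeval.py: print a few predictions.
import json

with open("./predictions.json") as f:
    predictions = json.load(f)  # assumed layout: {question_id: answer_string}

print(f"{len(predictions)} answers written")
for qid, answer in list(predictions.items())[:5]:
    print(f"{qid}: {answer!r}")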