You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

run_squad.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. '''
  16. Bert finetune and evaluation script.
  17. '''
  18. import os
  19. import argparse
  20. import collections
  21. from src.bert_for_finetune import BertSquadCell, BertSquad
  22. from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
  23. from src.dataset import create_squad_dataset
  24. from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
  25. import mindspore.common.dtype as mstype
  26. from mindspore import context
  27. from mindspore import log as logger
  28. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  29. from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
  30. from mindspore.common.tensor import Tensor
  31. from mindspore.train.model import Model
  32. from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
  33. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  34. _cur_dir = os.getcwd()
  35. def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
  36. """ do train """
  37. if load_checkpoint_path == "":
  38. raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
  39. steps_per_epoch = dataset.get_dataset_size()
  40. # optimizer
  41. if optimizer_cfg.optimizer == 'AdamWeightDecay':
  42. lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
  43. end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
  44. warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
  45. decay_steps=steps_per_epoch * epoch_num,
  46. power=optimizer_cfg.AdamWeightDecay.power)
  47. params = network.trainable_params()
  48. decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
  49. other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
  50. group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
  51. {'params': other_params, 'weight_decay': 0.0}]
  52. optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
  53. elif optimizer_cfg.optimizer == 'Lamb':
  54. lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
  55. end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
  56. warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
  57. decay_steps=steps_per_epoch * epoch_num,
  58. power=optimizer_cfg.Lamb.power)
  59. optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
  60. elif optimizer_cfg.optimizer == 'Momentum':
  61. optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
  62. momentum=optimizer_cfg.Momentum.momentum)
  63. else:
  64. raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
  65. # load checkpoint into network
  66. ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
  67. ckpoint_cb = ModelCheckpoint(prefix="squad",
  68. directory=None if save_checkpoint_path == "" else save_checkpoint_path,
  69. config=ckpt_config)
  70. param_dict = load_checkpoint(load_checkpoint_path)
  71. load_param_into_net(network, param_dict)
  72. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
  73. netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell)
  74. model = Model(netwithgrads)
  75. callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb]
  76. model.train(epoch_num, dataset, callbacks=callbacks)
  77. def do_eval(dataset=None, load_checkpoint_path="", eval_batch_size=1):
  78. """ do eval """
  79. if load_checkpoint_path == "":
  80. raise ValueError("Finetune model missed, evaluation task must load finetune model!")
  81. net = BertSquad(bert_net_cfg, False, 2)
  82. net.set_train(False)
  83. param_dict = load_checkpoint(load_checkpoint_path)
  84. load_param_into_net(net, param_dict)
  85. model = Model(net)
  86. output = []
  87. RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
  88. columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"]
  89. for data in dataset.create_dict_iterator(num_epochs=1):
  90. input_data = []
  91. for i in columns_list:
  92. input_data.append(data[i])
  93. input_ids, input_mask, segment_ids, unique_ids = input_data
  94. start_positions = Tensor([1], mstype.float32)
  95. end_positions = Tensor([1], mstype.float32)
  96. is_impossible = Tensor([1], mstype.float32)
  97. logits = model.predict(input_ids, input_mask, segment_ids, start_positions,
  98. end_positions, unique_ids, is_impossible)
  99. ids = logits[0].asnumpy()
  100. start = logits[1].asnumpy()
  101. end = logits[2].asnumpy()
  102. for i in range(eval_batch_size):
  103. unique_id = int(ids[i])
  104. start_logits = [float(x) for x in start[i].flat]
  105. end_logits = [float(x) for x in end[i].flat]
  106. output.append(RawResult(
  107. unique_id=unique_id,
  108. start_logits=start_logits,
  109. end_logits=end_logits))
  110. return output
  111. def run_squad():
  112. """run squad task"""
  113. parser = argparse.ArgumentParser(description="run squad")
  114. parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"],
  115. help="Device type, default is Ascend")
  116. parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"],
  117. help="Eable train, default is false")
  118. parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"],
  119. help="Eable eval, default is false")
  120. parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
  121. parser.add_argument("--epoch_num", type=int, default=3, help="Epoch number, default is 1.")
  122. parser.add_argument("--num_class", type=int, default=2, help="The number of class, default is 2.")
  123. parser.add_argument("--train_data_shuffle", type=str, default="true", choices=["true", "false"],
  124. help="Enable train data shuffle, default is true")
  125. parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],
  126. help="Enable eval data shuffle, default is false")
  127. parser.add_argument("--train_batch_size", type=int, default=32, help="Train batch size, default is 32")
  128. parser.add_argument("--eval_batch_size", type=int, default=1, help="Eval batch size, default is 1")
  129. parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path")
  130. parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json")
  131. parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
  132. parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
  133. parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
  134. parser.add_argument("--train_data_file_path", type=str, default="",
  135. help="Data path, it is better to use absolute path")
  136. parser.add_argument("--schema_file_path", type=str, default="",
  137. help="Schema path, it is better to use absolute path")
  138. args_opt = parser.parse_args()
  139. epoch_num = args_opt.epoch_num
  140. load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
  141. save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
  142. load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path
  143. if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
  144. raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
  145. if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
  146. raise ValueError("'train_data_file_path' must be set when do finetune task")
  147. if args_opt.do_eval.lower() == "true":
  148. if args_opt.vocab_file_path == "":
  149. raise ValueError("'vocab_file_path' must be set when do evaluation task")
  150. if args_opt.eval_json_path == "":
  151. raise ValueError("'tokenization_file_path' must be set when do evaluation task")
  152. target = args_opt.device_target
  153. if target == "Ascend":
  154. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
  155. elif target == "GPU":
  156. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  157. if bert_net_cfg.compute_type != mstype.float32:
  158. logger.warning('GPU only support fp32 temporarily, run with fp32.')
  159. bert_net_cfg.compute_type = mstype.float32
  160. else:
  161. raise Exception("Target error, GPU or Ascend is supported.")
  162. netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
  163. if args_opt.do_train.lower() == "true":
  164. ds = create_squad_dataset(batch_size=args_opt.train_batch_size, repeat_count=1,
  165. data_file_path=args_opt.train_data_file_path,
  166. schema_file_path=args_opt.schema_file_path,
  167. do_shuffle=(args_opt.train_data_shuffle.lower() == "true"))
  168. do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
  169. if args_opt.do_eval.lower() == "true":
  170. if save_finetune_checkpoint_path == "":
  171. load_finetune_checkpoint_dir = _cur_dir
  172. else:
  173. load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path)
  174. load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
  175. ds.get_dataset_size(), epoch_num, "squad")
  176. if args_opt.do_eval.lower() == "true":
  177. from src import tokenization
  178. from src.create_squad_data import read_squad_examples, convert_examples_to_features
  179. from src.squad_get_predictions import write_predictions
  180. from src.squad_postprocess import SQuad_postprocess
  181. tokenizer = tokenization.FullTokenizer(vocab_file=args_opt.vocab_file_path, do_lower_case=True)
  182. eval_examples = read_squad_examples(args_opt.eval_json_path, False)
  183. eval_features = convert_examples_to_features(
  184. examples=eval_examples,
  185. tokenizer=tokenizer,
  186. max_seq_length=bert_net_cfg.seq_length,
  187. doc_stride=128,
  188. max_query_length=64,
  189. is_training=False,
  190. output_fn=None,
  191. vocab_file=args_opt.vocab_file_path)
  192. ds = create_squad_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1,
  193. data_file_path=eval_features,
  194. schema_file_path=args_opt.schema_file_path, is_training=False,
  195. do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"))
  196. outputs = do_eval(ds, load_finetune_checkpoint_path, args_opt.eval_batch_size)
  197. all_predictions = write_predictions(eval_examples, eval_features, outputs, 20, 30, True)
  198. SQuad_postprocess(args_opt.eval_json_path, all_predictions, output_metrics="output.json")
  199. if __name__ == "__main__":
  200. run_squad()