
run_ReadComprehension.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for the Reading Comprehension task.
"""
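# Illustrative invocations only; the paths and file names below are placeholders
# (assumptions), not taken from the original repository.
#
# Zero-shot evaluation:
#   python run_ReadComprehension.py --device_target="Ascend" --device_id=0 \
#       --do_train="false" --do_eval="true" --eval_type="zero-shot" \
#       --eval_data_file_path=/path/to/coqa_eval.mindrecord \
#       --load_finetune_ckpt_path=/path/to/gpt2_pretrained.ckpt \
#       --tokenizer_file_path=/path/to/tokenizer_dir/
#
# Finetuning:
#   python run_ReadComprehension.py --device_target="Ascend" --device_id=0 \
#       --do_train="true" --do_eval="false" --epoch_num=1 \
#       --train_data_file_path=/path/to/coqa_train.mindrecord \
#       --load_pretrain_ckpt_path=/path/to/gpt2_pretrained.ckpt \
#       --save_finetune_ckpt_path=/path/to/save_dir/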
  18. import argparse
  19. import time
  20. from mindspore import context
  21. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  22. from mindspore.nn import AdamWeightDecay, Lamb, Momentum
  23. from mindspore.train.model import Model
  24. from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
  25. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  26. from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CoQA
  27. from src.GPT2ForReadComprehension import GPT2CoQAModel
  28. from src.utils.metric_method import F1
  29. from src.finetune_eval_config import cfg, gpt2_net_cfg
  30. from src.dataset import create_language_model_dataset
  31. from src.utils.lr_schedule import GPT2LearningRate
  32. from src.utils.tokenization import Tokenizer
  33. from src.GPT2_generation import GenerateForReadComprehension
  34. from src.utils.get_config_setting import get_train_setting, get_model_setting
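# The script expects src.finetune_eval_config.cfg / gpt2_net_cfg to expose at least the
# fields read below. A minimal sketch of that structure (field list inferred from the
# attribute accesses in this file; no concrete values are implied):
#
#     cfg.optimizer        # one of 'AdamWeightDecay', 'Lamb', 'Momentum'
#     cfg.gpt2_network     # network tag used in the checkpoint prefix
#     cfg.AdamWeightDecay: learning_rate, end_learning_rate, power, weight_decay, eps, decay_filter
#     cfg.Lamb:            learning_rate, end_learning_rate, power
#     cfg.Momentum:        learning_rate, momentum
#     gpt2_net_cfg.batch_size   # also used by the evaluation loop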


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train.

    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path of the saved pretrained model checkpoint.
        save_checkpoint_path: the file path where the finetuned model checkpoint will be saved.
        epoch_num: the number of epochs.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # checkpoint saving
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_rc_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)

    # load the pretrained checkpoint into the finetune network: remap the parameter names
    # and reuse the embedding table as the weight of the output projection layer
    param_dict = load_checkpoint(load_checkpoint_path)
    final_param_dict = {}
    for name, _ in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
    final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
    load_param_into_net(network, final_param_dict)
    print("Load the pretrained parameter successfully! \n")

    # wrap the network with dynamic loss scaling
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)
    loss_cb = LossMonitor(per_print_times=1)

    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]
    print("=================== Starting Training For Reading Comprehension Task ====================")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("=================== Reading Comprehension Training Success ====================")


def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="",
            generate_length=1, top_k=1, top_p=1.0, temperature=1.0):
    """
    Do evaluation on the Reading Comprehension task.

    Args:
        dataset: the eval dataset.
        network: the network with loss.
        metric: the evaluation method.
        load_checkpoint_path: the file path of the saved model checkpoint (finetuned, or pretrained for zero-shot).
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetune model missed, evaluation task must load finetune model!")

    if metric.lower() == "f1":
        print("Prepare to calculate the F1 score ...")
        gpt2_rc = network(config=gpt2_net_cfg,
                          is_training=False,
                          use_one_hot_embeddings=False)
        gpt2_rc.set_train(False)
        param_dict = load_checkpoint(load_checkpoint_path)

        if eval_type == "zero-shot":
            # zero-shot: remap the pretrained parameter names onto the evaluation network
            final_param_dict = {}
            for name, _ in param_dict.items():
                final_param_dict['gpt2.' + name] = param_dict[name]
            final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
            load_param_into_net(gpt2_rc, final_param_dict)
            print("load pretrained parameter successfully!\n")
        elif eval_type == "finetuned":
            load_param_into_net(gpt2_rc, param_dict)
            print("load finetuned parameter successfully!\n")
        else:
            raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]")

        model = Model(gpt2_rc)
        tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json',
                              merge_file=tokenizer_file_path + 'gpt2-merges.txt')
        callback = F1()
        rc_generator = GenerateForReadComprehension(decoder=model,
                                                    config=gpt2_net_cfg,
                                                    tokenizer=tokenizer,
                                                    generate_length=generate_length,
                                                    topk_num=top_k,
                                                    topp_prob=float(top_p),
                                                    temperature=float(temperature))
        columns_list = ["input_ids", "input_mask", "label_ids"]
        print("==================== [F1] Testing ====================")
        num_data = 0
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, _, label_ids = input_data

            print("input_ids shape: {}".format(input_ids.shape))
            print("label_ids shape: {}".format(label_ids.shape))

            # generate an answer for every sample in the batch and accumulate the F1 statistics
            passage, pred_answer, gold_answer = rc_generator.generate_for_read_comprehension(input_ids)
            for batch_id in range(gpt2_net_cfg.batch_size):
                print("============== [F1] {} ================".format(num_data + 1))
                print(" | Passage:{}".format(passage[batch_id]))
                print(" | Gold_answer:{}".format(gold_answer[batch_id]))
                print(" | Pred_answer:{}".format(pred_answer[batch_id]))

                pred = callback.get_normalize_answer_token(pred_answer[batch_id])
                gold = callback.get_normalize_answer_token(gold_answer[batch_id])
                callback.update(pred, gold)
                num_data += 1

            average_f1_score = callback.f1_score / num_data
            print("============== Evaluation =================")
            print("| Avg F1 Score:{:.8f}".format(average_f1_score))
            print("=============================================\n\n")
        print("********************** Testing Finished **********************")
    else:
        raise ValueError("metric method not supported in Reading Comprehension task, support: [F1]")


def run_Readcomprehension():
    """
    Run the Reading Comprehension task.
    """
    parser = argparse.ArgumentParser(description="Finetune and evaluate GPT-2 on the Reading Comprehension task")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=0,
                        help="ID of the target device.")
    parser.add_argument("--metric_method", type=str, default="F1",
                        help="The eval method including [F1]. Default: F1.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="zero-shot",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.")
    parser.add_argument("--epoch_num", type=int, default=1,
                        help="Epoch number. Default: 1.")
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle. Default: true.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Path to save the finetuned checkpoint.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Path of the pretrained checkpoint to load.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Path of the finetuned checkpoint to load.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Train data path; an absolute path is recommended.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Eval data path; an absolute path is recommended.")
    parser.add_argument("--tokenizer_file_path", type=str, default="",
                        help="Path of the pretrained vocab and merge files.")
    parser.add_argument("--generate_length", type=int, default=55,
                        help="The generation length of the predicted answer.")
    parser.add_argument("--top_k", type=int, default=1,
                        help="Parameter for Top-K sampling.")
    parser.add_argument("--top_p", type=str, default="1.0",
                        help="Parameter for Top-P (nucleus) sampling.")
    parser.add_argument("--temperature", type=str, default="1.0",
                        help="Sampling temperature; larger values make generation more diverse.")
    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when doing the finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when doing the evaluation task")

    device_target = args_opt.device_target
    if device_target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=device_target,
                            device_id=args_opt.device_id,
                            max_call_depth=3000)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id))
    else:
        raise Exception("Device target error, only Ascend is supported.")

    gpt2_loss = GPT2CoQA(config=gpt2_net_cfg,
                         is_training=True,
                         use_one_hot_embeddings=False)

    if args_opt.do_train.lower() == "true":
        get_train_setting(cfg)
        get_model_setting(cfg, gpt2_net_cfg)
        print("============== Start Loading Reading Comprehension Train Dataset ==============")
        print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        get_model_setting(cfg, gpt2_net_cfg)
        print("============ Start Loading Reading Comprehension Evaluation Dataset ============")
        print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
        eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                     dataset_path=args_opt.eval_data_file_path)
        do_eval(eval_dataset, GPT2CoQAModel, metric, load_finetune_ckpt_path, args_opt.eval_type,
                args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p,
                args_opt.temperature)


if __name__ == "__main__":
    print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_Readcomprehension()
    print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))