
run_summarization.py

# -*- coding: utf-8 -*-
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  16. """
  17. GPT-2 finetune and evaluation script for Summarization task.
  18. """
import time
import argparse

from mindspore import context
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.GPT2ForSummarization import GPT2SummarizationModel
from src.gpt2_for_finetune import GPT2Summarization, GPT2FinetuneCell
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import Rouge
from src.dataset import create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.tokenization import Tokenizer
from src.utils.task_utils import clean_hypo, modify_paramdict
from src.GPT2_generation import GenerateForSummarization
from src.utils.get_config_setting import get_train_setting, get_model_setting

def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train.

    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path of the saved pretrained model checkpoint.
        save_checkpoint_path: the file path where the finetuned model checkpoint will be saved.
        epoch_num: the number of epochs.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrained checkpoint is missing; the finetune task must load a pretrained model!")
    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. Supported optimizers: [AdamWeightDecay, Lamb, Momentum].")
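    # Note: with the settings above, the learning rate presumably warms up over the first
    # 10% of the total training steps (steps_per_epoch * epoch_num) and then decays
    # polynomially towards end_learning_rate; the exact curve is defined by
    # src.utils.lr_schedule.GPT2LearningRate. Weight decay is applied only to the
    # parameters selected by cfg.AdamWeightDecay.decay_filter (biases and LayerNorm
    # weights are typically excluded).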
    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_summarization_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    final_param_dict = {}
    # remap the pretrained parameter names to match the finetune network, and tie the
    # LM head weight to the token embedding table
    for name, _ in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
    final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
    load_param_into_net(network, final_param_dict)
    print("Pretrained parameters loaded successfully!\n")
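    # Dynamic loss scaling for mixed-precision training: the loss scale starts at 2**32,
    # is reduced by scale_factor when an overflow is detected, and is increased again
    # after scale_window consecutive overflow-free steps (standard behaviour of
    # MindSpore's DynamicLossScaleUpdateCell).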
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)
    loss_cb = LossMonitor(per_print_times=1)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("============== Starting Finetuning ==============")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("============== Finetuning Success ==============")

def eval_result_print(metric="Rouge", callback=None):
    """
    Print the evaluation result.
    """
    if metric == "Rouge":
        print("Rouge-1 {:.8f}, Rouge-2 {:.8f}, Rouge-L {:.8f}, Rouge-AVG {:.8f}".
              format(callback.Rouge1 / callback.total_num,
                     callback.Rouge2 / callback.total_num,
                     callback.RougeL / callback.total_num,
                     (callback.Rouge1 + callback.Rouge2 + callback.RougeL) / (3.0 * callback.total_num)))
    else:
        raise ValueError("metric method '{}' not supported, support: [Rouge].".format(str(metric)))

def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file="",
            top_k=None, top_p=None, temperature=None, generate_length=None):
    """
    Do evaluation on the summarization task.
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetuned checkpoint is missing; the evaluation task must load a finetuned model!")
    if metric.lower() == "rouge":
        print("Prepare to calculate the Rouge score ...")
        callback = Rouge()

        gpt2_loss = network(config=gpt2_net_cfg,
                            is_training=False,
                            use_one_hot_embeddings=False)
        gpt2_loss.set_train(False)
        param_dict = load_checkpoint(load_checkpoint_path)
        reorganized_param_dict = modify_paramdict(param_dict, mode=eval_type, model_prefix="gpt2.")
        load_param_into_net(gpt2_loss, reorganized_param_dict)

        # load the nn.Cell into Model and initialize the tokenizer and sampler
        model = Model(gpt2_loss)
        tokenizer = Tokenizer(vocab_file=tokenizer_file + 'gpt2-vocab.json',
                              merge_file=tokenizer_file + 'gpt2-merges.txt')

        # load data and run text generation
        columns_list = ["input_ids", "input_mask", "label_ids"]
        summarization_generator = GenerateForSummarization(model,
                                                           config=gpt2_net_cfg,
                                                           tokenizer=tokenizer,
                                                           select_sentence=3,
                                                           eval_type=eval_type,
                                                           topk=top_k,
                                                           topp=float(top_p),
                                                           temperature=float(temperature),
                                                           generate_length=generate_length)
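        # Decoding settings (the usual sampling controls, as presumably implemented by
        # GenerateForSummarization): top_k keeps only the k most probable tokens at each
        # step, top_p keeps the smallest token set whose cumulative probability exceeds p
        # (nucleus sampling), and temperature rescales the logits before sampling
        # (values < 1.0 sharpen the distribution, values > 1.0 flatten it).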
        num_data = 1
        print("==================== [Summarization] Testing ====================")
        for data in dataset.create_dict_iterator():
            input_data = []
            for value in columns_list:
                input_data.append(data[value])
            input_ids, _, label_ids = input_data

            print(" | [ROUGE] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
            print("input_ids shape: {}".format(input_ids.shape))
            print("label_ids shape: {}".format(label_ids.shape))

            hypothesis, ref = summarization_generator.generate_for_summarization(input_ids)
            if ref[0] == '' or ref[0] is None:
                print("Reference is empty, skipping this sample!")
                continue
            print("REF str:\n ", ref, "\nHYPO str:\n", hypothesis, "\n")
            for batch_idx in range(gpt2_net_cfg.batch_size):
                hypothesis[batch_idx] = clean_hypo(hypothesis[batch_idx])
            for batch_idx in range(gpt2_net_cfg.batch_size):
                hypothesis[batch_idx] = hypothesis[batch_idx].lower()
                ref[batch_idx] = ref[batch_idx].lower()
            callback.update(hypothesis, ref)
            num_data += 1

        print("\n\n")
        print("**********************************************************")
        eval_result_print(metric, callback)
        print("******************** Testing Finished ********************")
    else:
        raise ValueError("metric method not supported in summarization, support: [Rouge]")

def run_summarization():
    """
    Run the summarization task.
    """
    # set argument parser
    parser = argparse.ArgumentParser(description="Finetune and Evaluate Summarization")
    # context and task settings
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=4,
                        help="ID of the target device.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable training. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="finetuned",
                        help="The evaluation type, one of [zero-shot, finetuned]. Default: finetuned.")
    parser.add_argument("--metric_method", type=str, default="Rouge",
                        help="The evaluation metric, one of [Rouge (Rouge1, Rouge2, RougeL, Rouge Avg)]. Default: Rouge.")
    parser.add_argument("--epoch_num", type=int, default=2,
                        help="Epoch number. Default: 2.")
    # dataset and params_dict file settings
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle. Default: true.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Path where the finetuned checkpoint will be saved.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Path of the pretrained checkpoint file to load.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Path of the finetuned checkpoint file to load.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Train data path; an absolute path is recommended.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Eval data path; an absolute path is recommended.")
    # sampling settings
    parser.add_argument("--top_k", type=int, default=2,
                        help="Number of top-k tokens kept for sampling.")
    parser.add_argument("--top_p", type=str, default="1.0",
                        help="Top-p (nucleus) cumulative probability threshold for sampling.")
    parser.add_argument("--generate_length", type=int, default=100,
                        help="The number of generated tokens.")
    parser.add_argument("--temperature", type=str, default="1.0",
                        help="Temperature applied to the logits for sampling.")
    parser.add_argument("--tokenizer_file_path", type=str, default="",
                        help="Vocab & merge file path.")
    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
    eval_type = args_opt.eval_type
    tokenizer_file = args_opt.tokenizer_file_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when running the finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when running the evaluation task")

    device = args_opt.device_target
    if device == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device, args_opt.device_id))
    else:
        raise Exception("Unsupported device target; only Ascend is supported.")

    if args_opt.do_train.lower() == "true":
        get_train_setting(cfg)
        get_model_setting(cfg, gpt2_net_cfg)
        train_data_file_path = args_opt.train_data_file_path
        gpt2_loss = GPT2Summarization(config=gpt2_net_cfg,
                                      is_training=True,
                                      use_one_hot_embeddings=False)
        print("============== Start Loading Train Dataset ============")
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        get_model_setting(cfg, gpt2_net_cfg)
        eval_dataset_file_path = args_opt.eval_data_file_path
        print("============== Start Loading Evaluation Dataset ============")
        eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                     dataset_path=eval_dataset_file_path)
        do_eval(eval_dataset, GPT2SummarizationModel, metric, load_finetune_ckpt_path, eval_type, tokenizer_file,
                args_opt.top_k, args_opt.top_p, args_opt.temperature, args_opt.generate_length)

if __name__ == "__main__":
    print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_summarization()
    print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
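# Example invocation (a sketch; the checkpoint, dataset, and tokenizer paths below are
# placeholders and must point to your own files):
#
#   python run_summarization.py --device_target="Ascend" --device_id=0 \
#       --do_train="true" --do_eval="true" --eval_type="finetuned" \
#       --epoch_num=2 --metric_method="Rouge" \
#       --train_data_file_path="/path/to/train.mindrecord" \
#       --eval_data_file_path="/path/to/test.mindrecord" \
#       --load_pretrain_ckpt_path="/path/to/gpt2_pretrained.ckpt" \
#       --save_finetune_ckpt_path="/path/to/save_dir/" \
#       --load_finetune_ckpt_path="/path/to/gpt2_finetuned.ckpt" \
#       --tokenizer_file_path="/path/to/tokenizer_dir/" \
#       --top_k=2 --top_p="1.0" --temperature="1.0" --generate_length=100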