# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. GPT-2 finetune and evaluation script for LAMBADA task.
  17. """
import argparse
import math
import time

from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Lambada
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import LastWordAccuracy
from src.dataset import create_language_model_dataset, create_lambada_control_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.task_utils import get_final_word_label, calculate_final_word_loss
from src.utils.tokenization import Tokenizer
from src.GPT2_generation import GenerateForLambada
from src.utils.CrossEntropy import CrossEntropyCalculationWithMask
from src.utils.get_config_setting import get_train_setting, get_model_setting


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train
    Args:
        dataset: the train dataset.
        network: the network with loss.
        load_checkpoint_path: the file path which saved pretrain model checkpoint.
        save_checkpoint_path: the file path which will save finetune model checkpoint.
        epoch_num: the number of epochs.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")

    steps_per_epoch = dataset.get_dataset_size()

    # optimizer
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_" + "lambada_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
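    # Remap the pretrained checkpoint keys under the 'gpt2.gpt2.' prefix expected by the
    # GPT2Lambada wrapper; the embedding table is also reused for 'gpt2.dense1.weight'
    # (the output projection shares weights with the input embedding).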
    final_param_dict = {}
    for name, _ in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
    final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
    load_param_into_net(network, final_param_dict)
    print("Load pretrained parameter successfully!\n")
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)
    loss_cb = LossMonitor(per_print_times=1)

    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("==================== Starting Finetuning ====================")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("==================== Finetuning Success ====================")


def eval_result_print(metric="accuracy", callback=None):
    """
    Print eval result.
    """
    if metric.lower() == "accuracy":
        print("acc_num {}, total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                 callback.acc_num / callback.total_num))
    else:
        raise ValueError("metric method not supported, support: [accuracy]")


def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, stop_word_file="",
            generate_length_dynamic=True, tokenizer_file_path=""):
    """
    Do eval
    Args:
        dataset: the eval dataset.
        network: the network with loss.
        metric: the evaluation method.
        load_checkpoint_path: the file path which saved finetune model checkpoint.
        eval_type: the eval type, i.e. zero-shot, finetuned.
        generate_length_dynamic (bool): True if the generation length is dynamic, False if fixed. Default: True.
        tokenizer_file_path: the tokenizer file path for vocab file and merge file.
        stop_word_file: stop word file for calculating Accuracy.
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetune model missed, evaluation task must load finetune model!")

    tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json',
                          merge_file=tokenizer_file_path + 'gpt2-merges.txt')
    gpt2_lambada = network(config=gpt2_net_cfg,
                           is_training=False,
                           use_one_hot_embeddings=False)
    gpt2_lambada.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
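
    # "zero-shot" evaluates the pretrained GPT-2 directly, so the checkpoint keys are remapped
    # the same way as in do_train; "finetuned" loads a checkpoint already saved in GPT2Lambada format.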
    if eval_type == "zero-shot":
        final_param_dict = {}
        for name, _ in param_dict.items():
            final_param_dict['gpt2.gpt2.' + name] = param_dict[name]
        final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']
        load_param_into_net(gpt2_lambada, final_param_dict)
        print("load pretrained parameter successfully!\n")
    elif eval_type == "finetuned":
        load_param_into_net(gpt2_lambada, param_dict)
        print("load finetuned parameter successfully!\n")

    model = Model(gpt2_lambada)

    if metric.lower() == "accuracy":
        print("Prepare to calculate the accuracy score ...")
        callback = LastWordAccuracy()
        columns_list = ["input_ids", "input_mask", "input_length"]
        print("==================== [ACC] Testing ====================")
        lambada_generator = GenerateForLambada(decoder=model,
                                               config=gpt2_net_cfg,
                                               tokenizer=tokenizer,
                                               generate_length_dynamic=generate_length_dynamic,
                                               max_iterations=200,
                                               stop_word_file=stop_word_file)
        num_data = 1
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, input_mask, input_length = input_data
            print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
            logits = model.predict(input_ids, input_mask)
            predict_str = lambada_generator.generate_for_lambada(input_ids=input_ids,
                                                                 logits=logits,
                                                                 input_length=input_length)
            label_str = get_final_word_label(input_ids=input_ids, input_length=input_length, tokenizer=tokenizer)
            callback.update(predict_str, label_str)
            eval_result_print(metric, callback)
            num_data += 1

        print("\n\n")
        print("**********************************************************")
        eval_result_print(metric, callback)
        print("******************** Testing Finished ********************")

    elif metric.lower() == "ppl":
        print("Prepare to calculate the ppl score ...")
        cross_entropy = CrossEntropyCalculationWithMask(is_training=True,
                                                        num_labels=gpt2_net_cfg.vocab_size,
                                                        config=gpt2_net_cfg)
        columns_list = ["input_ids", "input_mask", "input_length"]
        num_data = 1
        total_loss = 0.0
        print("==================== [PPL] Testing ====================")
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(data[i])
            input_ids, input_mask, input_length = input_data
            print("| [PPL] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
            logits = model.predict(input_ids, input_mask)  # (batch_size, seq_len, vocab_size)
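            # Cross-entropy is computed only over the final word of each sample; the running
            # average of these losses is exponentiated to report perplexity.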
            avg_batch_loss = calculate_final_word_loss(logits,
                                                       gpt2_net_cfg.batch_size,
                                                       input_ids,
                                                       input_length,
                                                       cross_entropy)
            total_loss += avg_batch_loss
            avg_total_loss = total_loss / num_data
            print(" | Current AVG loss:", avg_total_loss)
            print(" | Current AVG ppl:", math.exp(avg_total_loss))
            num_data += 1

        print("\n\n")
        print("**********************************************************")
        print("Average PPL: {:.6f}".format(math.exp(avg_total_loss)))
        print("******************** Testing Finished ********************")

    else:
        raise ValueError("metric method not supported, support: [accuracy, ppl]")


def run_lambada():
    """
    Run LAMBADA task.
    """
    parser = argparse.ArgumentParser(description="Finetune and evaluate language model on the LAMBADA task")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=2,
                        help="ID of target device. Default: 2.")
    parser.add_argument("--metric_method", type=str, default="PPL",
                        help="The eval method including [Accuracy, PPL]. Default: PPL.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="finetuned",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: finetuned.")
    parser.add_argument("--epoch_num", type=int, default=3,
                        help="Epoch number. Default: 3.")
    parser.add_argument("--train_data_shuffle", type=str, default="false",
                        help="Enable train data shuffle. Default: false.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--generate_length_dynamically", type=str, default="true",
                        help="Enable generating the length dynamically. Default: true.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Save the finetuned checkpoint path.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="Load the pretrained checkpoint file path.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="Load the finetuned checkpoint file path.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Train data path, it is better to use absolute path.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Eval data path, it is better to use absolute path.")
    parser.add_argument("--tokenizer_file_path", type=str, default="",
                        help="Pretrained vocab and merge file path.")
    parser.add_argument("--stop_word_file_path", type=str, default="",
                        help="The stop word file path.")
    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when doing the finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when doing the evaluation task")

    device = args_opt.device_target
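    # Only single-device (stand-alone) execution on Ascend in graph mode is supported by this script.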
    if device == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device, args_opt.device_id))
    else:
        raise Exception("Device target error, Ascend is supported.")

    gpt2_loss = GPT2Lambada(config=gpt2_net_cfg,
                            is_training=True,
                            use_one_hot_embeddings=False)

    if args_opt.do_train.lower() == "true":
        get_train_setting(cfg)
        get_model_setting(cfg, gpt2_net_cfg)
        print("============== Start Loading Train Dataset ============")
        print(" | Train Dataset: {}".format(args_opt.train_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path))
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        get_model_setting(cfg, gpt2_net_cfg)
        print("============== Start Loading Evaluation Dataset ============")
        print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path))
        print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path))
        eval_dataset = create_lambada_control_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                      dataset_path=args_opt.eval_data_file_path)
        do_eval(eval_dataset, GPT2Lambada, metric, load_finetune_ckpt_path, args_opt.eval_type,
                args_opt.stop_word_file_path, (args_opt.generate_length_dynamically.lower() == "true"),
                args_opt.tokenizer_file_path)


if __name__ == "__main__":
    print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_lambada()
    print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))