You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

finetune.py 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. '''
  16. Bert finetune script.
  17. '''
  18. import os
  19. import argparse
  20. from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell
  21. from src.finetune_config import cfg, bert_net_cfg, tag_to_index
  22. import mindspore.common.dtype as mstype
  23. from mindspore import context
  24. import mindspore.dataset as de
  25. import mindspore.dataset.transforms.c_transforms as C
  26. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  27. from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
  28. from mindspore.train.model import Model
  29. from mindspore.train.callback import Callback
  30. from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
  31. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  32. class LossCallBack(Callback):
  33. '''
  34. Monitor the loss in training.
  35. If the loss is NAN or INF, terminate training.
  36. Note:
  37. If per_print_times is 0, do not print loss.
  38. Args:
  39. per_print_times (int): Print loss every times. Default: 1.
  40. '''
  41. def __init__(self, per_print_times=1):
  42. super(LossCallBack, self).__init__()
  43. if not isinstance(per_print_times, int) or per_print_times < 0:
  44. raise ValueError("print_step must be in and >= 0.")
  45. self._per_print_times = per_print_times
  46. def step_end(self, run_context):
  47. cb_params = run_context.original_args()
  48. with open("./loss.log", "a+") as f:
  49. f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
  50. str(cb_params.net_outputs)))
  51. f.write("\n")
  52. def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
  53. '''
  54. get dataset
  55. '''
  56. ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
  57. "segment_ids", "label_ids"])
  58. type_cast_op = C.TypeCast(mstype.int32)
  59. ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
  60. ds = ds.map(input_columns="input_mask", operations=type_cast_op)
  61. ds = ds.map(input_columns="input_ids", operations=type_cast_op)
  62. ds = ds.map(input_columns="label_ids", operations=type_cast_op)
  63. ds = ds.repeat(repeat_count)
  64. # apply shuffle operation
  65. buffer_size = 960
  66. ds = ds.shuffle(buffer_size=buffer_size)
  67. # apply batch operations
  68. ds = ds.batch(batch_size, drop_remainder=True)
  69. return ds
  70. def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
  71. '''
  72. get SQuAD dataset
  73. '''
  74. ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
  75. "start_positions", "end_positions",
  76. "unique_ids", "is_impossible"])
  77. type_cast_op = C.TypeCast(mstype.int32)
  78. ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
  79. ds = ds.map(input_columns="input_ids", operations=type_cast_op)
  80. ds = ds.map(input_columns="input_mask", operations=type_cast_op)
  81. ds = ds.map(input_columns="start_positions", operations=type_cast_op)
  82. ds = ds.map(input_columns="end_positions", operations=type_cast_op)
  83. ds = ds.repeat(repeat_count)
  84. buffer_size = 960
  85. ds = ds.shuffle(buffer_size=buffer_size)
  86. ds = ds.batch(batch_size, drop_remainder=True)
  87. return ds
  88. def test_train():
  89. '''
  90. finetune function
  91. '''
  92. target = args_opt.device_target
  93. if target == "Ascend":
  94. devid = int(os.getenv('DEVICE_ID'))
  95. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
  96. elif target == "GPU":
  97. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  98. else:
  99. raise Exception("Target error, GPU or Ascend is supported.")
  100. #BertCLSTrain for classification
  101. #BertNERTrain for sequence labeling
  102. if cfg.task == 'NER':
  103. if cfg.use_crf:
  104. netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True,
  105. tag_to_index=tag_to_index, dropout_prob=0.1)
  106. else:
  107. netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
  108. elif cfg.task == 'SQUAD':
  109. netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
  110. else:
  111. netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
  112. if cfg.task == 'SQUAD':
  113. dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
  114. else:
  115. dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
  116. # optimizer
  117. steps_per_epoch = dataset.get_dataset_size()
  118. if cfg.optimizer == 'AdamWeightDecayDynamicLR':
  119. optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
  120. decay_steps=steps_per_epoch * cfg.epoch_num,
  121. learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
  122. end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
  123. power=cfg.AdamWeightDecayDynamicLR.power,
  124. warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
  125. weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
  126. eps=cfg.AdamWeightDecayDynamicLR.eps)
  127. elif cfg.optimizer == 'Lamb':
  128. optimizer = Lamb(netwithloss.trainable_params(), decay_steps=steps_per_epoch * cfg.epoch_num,
  129. start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
  130. power=cfg.Lamb.power, weight_decay=cfg.Lamb.weight_decay,
  131. warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1), decay_filter=cfg.Lamb.decay_filter)
  132. elif cfg.optimizer == 'Momentum':
  133. optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
  134. momentum=cfg.Momentum.momentum)
  135. else:
  136. raise Exception("Optimizer not supported.")
  137. # load checkpoint into network
  138. ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
  139. ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config)
  140. param_dict = load_checkpoint(cfg.pre_training_ckpt)
  141. load_param_into_net(netwithloss, param_dict)
  142. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
  143. if cfg.task == 'SQUAD':
  144. netwithgrads = BertSquadCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
  145. else:
  146. netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
  147. model = Model(netwithgrads)
  148. model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
  149. parser = argparse.ArgumentParser(description='Bert finetune')
  150. parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
  151. args_opt = parser.parse_args()
  152. if __name__ == "__main__":
  153. test_train()