
utils.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Functional Cells used in Bert finetune and evaluation.
"""
import os
import math
import collections
import numpy as np
import mindspore.nn as nn
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR


class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss
    """
    def __init__(self, is_training=True):
        super(CrossEntropyCalculation, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.is_training = is_training

    def construct(self, logits, label_ids, num_labels):
        if self.is_training:
            label_ids = self.reshape(label_ids, self.last_idx)
            one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
            per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
            loss = self.reduce_mean(per_example_loss, self.last_idx)
            return_value = self.cast(loss, mstype.float32)
        else:
            return_value = logits * 1.0
        return return_value


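# Illustrative usage sketch (not part of the original file): applying the loss
# cell to dummy log-probability logits. The shapes, values, and num_labels
# below are assumptions for demonstration only.
def _example_cross_entropy_usage():
    logits = Tensor(np.log(np.array([[0.7, 0.2, 0.1],
                                     [0.1, 0.8, 0.1]], dtype=np.float32)))
    label_ids = Tensor(np.array([0, 1], dtype=np.int32))
    loss_fn = CrossEntropyCalculation(is_training=True)
    # Returns the batch-mean negative log-likelihood as a float32 scalar.
    return loss_fn(logits, label_ids, 3)

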
def make_directory(path: str):
    """Make directory."""
    if path is None or not isinstance(path, str) or path.strip() == "":
        logger.error("The path(%r) is invalid type.", path)
        raise TypeError("Input path is invalid type")

    # convert relative paths to absolute
    path = os.path.realpath(path)
    logger.debug("The abs path is %r", path)

    # check whether the path exists; create it if not
    if os.path.exists(path):
        real_path = path
    else:
        # creating the directory may fail, e.g. due to permission limits,
        # so any failure is surfaced with a clear error
        logger.debug("The directory(%s) doesn't exist, will create it", path)
        try:
            os.makedirs(path, exist_ok=True)
            real_path = path
        except PermissionError as e:
            logger.error("No write permission on the directory(%r), error = %r", path, e)
            raise TypeError("No write permission on the directory.")
    return real_path


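# Illustrative sketch (not part of the original file): resolving and creating a
# checkpoint directory before saving. The relative path is hypothetical.
def _example_make_directory():
    return make_directory("./finetune_checkpoints")

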
class LossCallBack(Callback):
    """
    Monitor the loss in training.
    If the loss is NAN or INF, terminate training.

    Note:
        If dataset_size is not positive, epoch progress is not reported.

    Args:
        dataset_size (int): Number of steps in one epoch, used to report
            fractional epoch progress. Default: -1.
    """
    def __init__(self, dataset_size=-1):
        super(LossCallBack, self).__init__()
        self._dataset_size = dataset_size

    def step_end(self, run_context):
        """
        Print loss after each step
        """
        cb_params = run_context.original_args()
        if self._dataset_size > 0:
            percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
            if percent == 0:
                percent = 1
                epoch_num -= 1
            print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
                  flush=True)
        else:
            print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                               str(cb_params.net_outputs)), flush=True)


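# Illustrative sketch (not part of the original file): registering the callback
# with Model.train. `net`, `loss_fn`, `optimizer`, `epoch_num`, and
# `train_dataset` are hypothetical; passing the dataset size enables
# fractional epoch reporting.
#
#     from mindspore import Model
#     model = Model(net, loss_fn, optimizer)
#     model.train(epoch_num, train_dataset,
#                 callbacks=[LossCallBack(dataset_size=train_dataset.get_dataset_size())])

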
def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
    """
    Find the newest checkpoint generated by finetuning and return its path
    for loading into the eval network.
    """
    files = os.listdir(load_finetune_checkpoint_dir)
    pre_len = len(prefix)
    max_num = 0
    load_finetune_checkpoint_path = None
    for filename in files:
        name_ext = os.path.splitext(filename)
        if name_ext[-1] != ".ckpt":
            continue
        if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
            index = filename[pre_len:].find("-")
            if index == 0 and max_num == 0:
                load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
            elif index not in (0, -1):
                name_split = name_ext[-2].split('_')
                if (steps_per_epoch != int(name_split[len(name_split)-1])) \
                        or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])):
                    continue
                num = filename[pre_len + 1:pre_len + index]
                if int(num) > max_num:
                    max_num = int(num)
                    load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
    # guard against referencing the path before assignment when nothing matches
    if load_finetune_checkpoint_path is None:
        raise FileNotFoundError("No matching checkpoint found in %s" % load_finetune_checkpoint_dir)
    return load_finetune_checkpoint_path


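# Illustrative sketch (not part of the original file): locating the newest
# finetune checkpoint and loading it into an eval network. The directory,
# prefix, step/epoch values, and `net` are hypothetical;
# load_checkpoint/load_param_into_net are MindSpore's standard loading helpers.
#
#     from mindspore.train.serialization import load_checkpoint, load_param_into_net
#     ckpt_path = LoadNewestCkpt("./finetune_checkpoints", steps_per_epoch=1000,
#                                epoch_num=3, prefix="classifier")
#     load_param_into_net(net, load_checkpoint(ckpt_path))

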
class BertLearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for Bert network.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            # blend: warmup LR while global_step < warmup_steps, decay LR after
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr


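# Illustrative sketch (not part of the original file): wiring the warmup-decay
# schedule into an optimizer. The step counts, rates, and `net` are assumptions.
#
#     lr_schedule = BertLearningRate(learning_rate=5e-5, end_learning_rate=0.0,
#                                    warmup_steps=100, decay_steps=1000, power=1.0)
#     optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=lr_schedule)

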
def convert_labels_to_index(label_list):
    """
    Convert label_list to indices for NER task.
    """
    label2id = collections.OrderedDict()
    label2id["O"] = 0
    prefix = ["S_", "B_", "M_", "E_"]
    index = 0
    for label in label_list:
        for pre in prefix:
            index += 1
            sub_label = pre + label
            label2id[sub_label] = index
    return label2id


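# Illustrative sketch (not part of the original file): for two entity types the
# mapping expands each label with the S_/B_/M_/E_ prefixes, e.g.
# convert_labels_to_index(["PER", "LOC"]) returns
# {"O": 0, "S_PER": 1, "B_PER": 2, "M_PER": 3, "E_PER": 4,
#  "S_LOC": 5, "B_LOC": 6, "M_LOC": 7, "E_LOC": 8}.
def _example_label_mapping():
    return convert_labels_to_index(["PER", "LOC"])

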
def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
    """
    Generate a learning rate array.

    Args:
        global_step (int): current step
        lr_init (float): initial learning rate
        lr_end (float): end learning rate
        lr_max (float): max learning rate
        warmup_steps (int): number of warmup steps
        total_steps (int): total number of training steps
        poly_power (float): polynomial decay power

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            lr = float(lr_max - lr_end) * (base ** poly_power)
            lr = lr + lr_end
            if lr < 0.0:
                lr = 0.0
        lr_each_step.append(lr)
    learning_rate = np.array(lr_each_step).astype(np.float32)
    current_step = global_step
    learning_rate = learning_rate[current_step:]
    return learning_rate


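# Illustrative sketch (not part of the original file): a tiny schedule with
# 2 warmup steps out of 5 total. The warmup segment climbs linearly from
# lr_init toward lr_max, then the tail decays polynomially toward lr_end:
# approximately [0.0, 0.005, 0.01, 0.00667, 0.00334].
def _example_poly_lr():
    return _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-5, lr_max=0.01,
                        warmup_steps=2, total_steps=5, poly_power=1.0)

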
def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000):
    learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0,
                                 total_steps=lr_total_steps, poly_power=lr_power)
    return Tensor(learning_rate)


def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000):
    damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0,
                           total_steps=damping_total_steps, poly_power=damping_power)
    return Tensor(damping)
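

# Illustrative sketch (not part of the original file): building the THOR
# learning-rate and damping schedules with their default polynomial decay
# over 30000 steps; both wrappers return MindSpore Tensors.
def _example_thor_schedules():
    lr = get_bert_thor_lr()
    damping = get_bert_thor_damping()
    return lr, damping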