# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Functional Cells used in Bert finetune and evaluation.
"""

import os
import math
import numpy as np
import mindspore.nn as nn
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR

get_square_sum = C.MultitypeFuncGraph("get_square_sum")


@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    """Return the sum of squares of one gradient tensor as a 1-element float32 tensor."""
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    """Scale one gradient tensor by clip_norm / global_norm."""
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(nn.Cell):
    """
    Calculate the global norm value of given tensors.
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms


class ClipByGlobalNorm(nn.Cell):
    """
    Clip grads by global norm.
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads
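
# Usage sketch (illustrative only, not part of the original training scripts; the
# `grads` name is an assumption): clip a tuple of gradient tensors before they reach
# the optimizer. Because the global norm is floored at `clip_norm` via F.select,
# gradients are only ever scaled down, never amplified.
#
#     grad_clipper = ClipByGlobalNorm(clip_norm=1.0)
#     clipped_grads = grad_clipper(grads)   # grads: tuple of mindspore Tensors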


class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss.
    """
    def __init__(self, is_training=True):
        super(CrossEntropyCalculation, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.is_training = is_training

    def construct(self, logits, label_ids, num_labels):
        if self.is_training:
            label_ids = self.reshape(label_ids, self.last_idx)
            one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
            per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
            loss = self.reduce_mean(per_example_loss, self.last_idx)
            return_value = self.cast(loss, mstype.float32)
        else:
            return_value = logits * 1.0
        return return_value
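
# Usage sketch (illustrative; tensor names and shapes are assumptions): with `logits`
# of shape [batch, num_labels] holding log-probabilities and integer `label_ids` of
# shape [batch], the cell returns the mean negative log-likelihood in training mode
# and the raw logits unchanged in evaluation mode.
#
#     loss_fn = CrossEntropyCalculation(is_training=True)
#     loss = loss_fn(logits, label_ids, num_labels)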


def make_directory(path: str):
    """Make directory."""
    if path is None or not isinstance(path, str) or path.strip() == "":
        logger.error("The path(%r) is invalid type.", path)
        raise TypeError("Input path is of invalid type")

    # convert relative paths to absolute
    path = os.path.realpath(path)
    logger.debug("The abs path is %r", path)

    # check whether the path exists and is writable
    if os.path.exists(path):
        real_path = path
    else:
        # catch permission errors, since directory creation may be restricted
        logger.debug("The directory(%s) doesn't exist, will create it", path)
        try:
            os.makedirs(path, exist_ok=True)
            real_path = path
        except PermissionError as e:
            logger.error("No write permission on the directory(%r), error = %r", path, e)
            raise TypeError("No write permission on the directory.")
    return real_path
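
# Usage sketch (the path value is an assumed example): create the directory if it
# does not exist yet and get its absolute path back.
#
#     ckpt_dir = make_directory("./checkpoints")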


class LossCallBack(Callback):
    """
    Monitor the loss during training.
    If the loss is NAN or INF, training should be terminated.
    Note:
        The loss is printed after every step.
    Args:
        dataset_size (int): Number of steps in one epoch, used to report the current
            epoch and the progress within it. Default: 1.
    """
    def __init__(self, dataset_size=1):
        super(LossCallBack, self).__init__()
        self._dataset_size = dataset_size

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
        print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
              .format(epoch_num, "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
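
# Usage sketch (illustrative; `model`, `train_dataset` and `epoch_num` are assumptions):
# pass the callback to Model.train so the loss is printed after every step.
#
#     loss_cb = LossCallBack(dataset_size=train_dataset.get_dataset_size())
#     model.train(epoch_num, train_dataset, callbacks=[loss_cb])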


def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
    """
    Find the newest checkpoint file generated by finetuning, so it can be loaded into the eval network.
    """
    files = os.listdir(load_finetune_checkpoint_dir)
    pre_len = len(prefix)
    max_num = 0
    for filename in files:
        name_ext = os.path.splitext(filename)
        if name_ext[-1] != ".ckpt":
            continue
        if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
            # parse the epoch and step numbers encoded in the checkpoint file name
            index = filename[pre_len:].find("-")
            if index == 0 and max_num == 0:
                load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
            elif index not in (0, -1):
                name_split = name_ext[-2].split('_')
                if (steps_per_epoch != int(name_split[len(name_split) - 1])) \
                        or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])):
                    continue
                num = filename[pre_len + 1:pre_len + index]
                if int(num) > max_num:
                    max_num = int(num)
                    load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
    return load_finetune_checkpoint_path
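
# Usage sketch (illustrative; the directory, prefix and `eval_net` are assumptions):
# locate the newest finetune checkpoint and load it into the evaluation network with
# the standard mindspore serialization helpers.
#
#     from mindspore.train.serialization import load_checkpoint, load_param_into_net
#     ckpt_path = LoadNewestCkpt("./finetune_ckpt", steps_per_epoch, epoch_num, prefix="classifier")
#     param_dict = load_checkpoint(ckpt_path)
#     load_param_into_net(eval_net, param_dict)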


class BertLearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for Bert network.
    """
    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(BertLearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        decay_lr = self.decay_lr(global_step)
        if self.warmup_flag:
            # is_warmup is 1.0 while global_step < warmup_steps and 0.0 afterwards,
            # so the schedule switches from the warmup lr to the decay lr
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr
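
# Usage sketch (illustrative; hyper-parameter values and `net` are assumptions): build
# the warmup-decay schedule and hand it to an optimizer such as nn.AdamWeightDecay.
#
#     lr_schedule = BertLearningRate(learning_rate=2e-5, end_learning_rate=0.0,
#                                    warmup_steps=100, decay_steps=1000, power=1.0)
#     optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=lr_schedule)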