
utils.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
Functional Cells used in Bert finetune and evaluation.
'''
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
from CRF import CRF

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)
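
# `tensor_grad_scale` multiplies each gradient by the reciprocal of the loss scale.
# BertFinetuneCell maps it over all gradients with HyperMap (see its construct method)
# to undo the scaling injected through the sens parameter.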


class BertFinetuneCell(nn.Cell):
    """
    Specifically defined for finetuning where only four input tensors are needed.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertFinetuneCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.clip_gradients = ClipGradients()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        weights = self.weights
        init = self.alloc_status()
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        clear_before_grad = self.clear_before_grad(init)
        F.control_depend(loss, init)
        self.depend_parameter_use(clear_before_grad, scaling_sens)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        flag = self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        F.control_depend(grads, flag)
        F.control_depend(flag, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return F.depend(ret, succ)
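
# Note on the overflow handling in BertFinetuneCell.construct: the NPU float-status ops
# detect inf/nan values produced while computing the scaled gradients. If an overflow is
# detected (in distributed mode the AllReduce makes the check global), the optimizer step
# is skipped and the loss-scale update cell may adjust the scale; otherwise the unscaled,
# clipped gradients are applied.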


class BertCLSModel(nn.Cell):
    """
    This class is responsible for classification task evaluation, i.e. XNLI (num_labels=3),
    LCQMC (num_labels=2), Chnsenti (num_labels=2). The returned log probabilities can be
    used in place of the final logits, since log_softmax preserves the ordering of softmax.
    """
    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
        super(BertCLSModel, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)

    def construct(self, input_ids, input_mask, token_type_id):
        _, pooled_output, _ = \
            self.bert(input_ids, token_type_id, input_mask)
        cls = self.cast(pooled_output, self.dtype)
        cls = self.dropout(cls)
        logits = self.dense_1(cls)
        logits = self.cast(logits, self.dtype)
        log_probs = self.log_softmax(logits)
        return log_probs


class BertNERModel(nn.Cell):
    """
    This class is responsible for sequence labeling task evaluation, i.e. NER (num_labels=11).
    The returned log probabilities can be used in place of the final logits, since log_softmax
    preserves the ordering of softmax.
    """
    def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
                 use_one_hot_embeddings=False):
        super(BertNERModel, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)
        self.use_crf = use_crf
        self.origin_shape = (config.batch_size, config.seq_length, self.num_labels)

    def construct(self, input_ids, input_mask, token_type_id):
        sequence_output, _, _ = \
            self.bert(input_ids, token_type_id, input_mask)
        seq = self.dropout(sequence_output)
        seq = self.reshape(seq, self.shape)
        logits = self.dense_1(seq)
        logits = self.cast(logits, self.dtype)
        if self.use_crf:
            return_value = self.reshape(logits, self.origin_shape)
        else:
            return_value = self.log_softmax(logits)
        return return_value
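
# When use_crf is True, BertNERModel returns per-token scores reshaped back to
# (batch_size, seq_length, num_labels) so a CRF loss can consume them; otherwise it
# returns log probabilities of shape (batch_size * seq_length, num_labels).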


class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss
    """
    def __init__(self, is_training=True):
        super(CrossEntropyCalculation, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.is_training = is_training

    def construct(self, logits, label_ids, num_labels):
        if self.is_training:
            label_ids = self.reshape(label_ids, self.last_idx)
            one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
            per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
            loss = self.reduce_mean(per_example_loss, self.last_idx)
            return_value = self.cast(loss, mstype.float32)
        else:
            return_value = logits * 1.0
        return return_value
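
# The training branch above is the standard negative log-likelihood: with one-hot labels
# and log-softmax inputs, reduce_sum(one_hot_labels * logits) picks out the log probability
# of the target class, so per_example_loss = -log p(target) and the returned loss is its
# mean over the batch.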


class BertCLS(nn.Cell):
    """
    Train interface for the classification finetuning task.
    """
    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
        super(BertCLS, self).__init__()
        self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
        self.loss = CrossEntropyCalculation(is_training)
        self.num_labels = num_labels

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        log_probs = self.bert(input_ids, input_mask, token_type_id)
        loss = self.loss(log_probs, label_ids, self.num_labels)
        return loss


class BertNER(nn.Cell):
    """
    Train interface for the sequence labeling finetuning task.
    """
    def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0,
                 use_one_hot_embeddings=False):
        super(BertNER, self).__init__()
        self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings)
        if use_crf:
            if not tag_to_index:
                raise Exception("The dict for tag-index mapping should be provided for CRF.")
            self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training)
        else:
            self.loss = CrossEntropyCalculation(is_training)
        self.num_labels = num_labels
        self.use_crf = use_crf

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        logits = self.bert(input_ids, input_mask, token_type_id)
        if self.use_crf:
            loss = self.loss(logits, label_ids)
        else:
            loss = self.loss(logits, label_ids, self.num_labels)
        return loss
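

# A minimal usage sketch (not part of the original file): it shows how a classification
# loss network could be wrapped in BertFinetuneCell for finetuning. BertConfig,
# DynamicLossScaleManager, and the hyper-parameter values below are assumptions chosen
# for illustration only; the actual finetune scripts configure these elsewhere.
if __name__ == "__main__":
    from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig
    from mindspore.train.loss_scale_manager import DynamicLossScaleManager

    bert_net_cfg = BertConfig(batch_size=16)  # assumption: remaining fields keep their defaults
    netwithloss = BertCLS(bert_net_cfg, is_training=True, num_labels=2, dropout_prob=0.1)
    optimizer = nn.Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
    scale_manager = DynamicLossScaleManager(init_loss_scale=2 ** 16, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer,
                                    scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    # netwithgrads can then be passed to mindspore.train.Model and trained on a dataset
    # that yields (input_ids, input_mask, token_type_id, label_ids) batches.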