
utils.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
Functional Cells used in Bert finetune and evaluation.
'''
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
from CRF import CRF

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)

class BertFinetuneCell(nn.Cell):
    """
    Specifically defined for finetuning where only four input tensors are needed.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertFinetuneCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        weights = self.weights
        # Allocate the NPU float-status buffer used below for overflow detection.
        init = self.alloc_status()
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        clear_before_grad = self.clear_before_grad(init)
        F.control_depend(loss, init)
        self.depend_parameter_use(clear_before_grad, scaling_sens)
        # Undo the loss scaling, then clip gradients before the (optional) all-reduce.
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        flag = self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        F.control_depend(grads, flag)
        F.control_depend(flag, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        # Skip the parameter update when an overflow was detected in this step.
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return F.depend(ret, succ)

class BertCLSModel(nn.Cell):
    """
    This class is responsible for classification task evaluation, i.e. XNLI (num_labels=3),
    LCQMC (num_labels=2), Chnsenti (num_labels=2). The returned output represents the final
    logits, as the result of log_softmax is proportional to that of softmax.
    """
    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
        super(BertCLSModel, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)

    def construct(self, input_ids, input_mask, token_type_id):
        _, pooled_output, _ = \
            self.bert(input_ids, token_type_id, input_mask)
        cls = self.cast(pooled_output, self.dtype)
        cls = self.dropout(cls)
        logits = self.dense_1(cls)
        logits = self.cast(logits, self.dtype)
        log_probs = self.log_softmax(logits)
        return log_probs

class BertNERModel(nn.Cell):
    """
    This class is responsible for sequence labeling task evaluation, i.e. NER (num_labels=11).
    The returned output represents the final logits, as the result of log_softmax is proportional to that of softmax.
    """
    def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
                 use_one_hot_embeddings=False):
        super(BertNERModel, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)
        self.use_crf = use_crf
        self.origin_shape = (config.batch_size, config.seq_length, self.num_labels)

    def construct(self, input_ids, input_mask, token_type_id):
        sequence_output, _, _ = \
            self.bert(input_ids, token_type_id, input_mask)
        seq = self.dropout(sequence_output)
        seq = self.reshape(seq, self.shape)
        logits = self.dense_1(seq)
        logits = self.cast(logits, self.dtype)
        if self.use_crf:
            return_value = self.reshape(logits, self.origin_shape)
        else:
            return_value = self.log_softmax(logits)
        return return_value

class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss
    """
    def __init__(self, is_training=True):
        super(CrossEntropyCalculation, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.is_training = is_training

    def construct(self, logits, label_ids, num_labels):
        if self.is_training:
            label_ids = self.reshape(label_ids, self.last_idx)
            one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
            per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
            loss = self.reduce_mean(per_example_loss, self.last_idx)
            return_value = self.cast(loss, mstype.float32)
        else:
            return_value = logits * 1.0
        return return_value

class BertCLS(nn.Cell):
    """
    Train interface for classification finetuning task.
    """
    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
        super(BertCLS, self).__init__()
        self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
        self.loss = CrossEntropyCalculation(is_training)
        self.num_labels = num_labels

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        log_probs = self.bert(input_ids, input_mask, token_type_id)
        loss = self.loss(log_probs, label_ids, self.num_labels)
        return loss

class BertNER(nn.Cell):
    """
    Train interface for sequence labeling finetuning task.
    """
    def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0,
                 use_one_hot_embeddings=False):
        super(BertNER, self).__init__()
        self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings)
        if use_crf:
            if not tag_to_index:
                raise Exception("The dict for tag-index mapping should be provided for CRF.")
            self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training)
        else:
            self.loss = CrossEntropyCalculation(is_training)
        self.num_labels = num_labels
        self.use_crf = use_crf

    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        logits = self.bert(input_ids, input_mask, token_type_id)
        if self.use_crf:
            loss = self.loss(logits, label_ids)
        else:
            loss = self.loss(logits, label_ids, self.num_labels)
        return loss
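
For reference, a minimal usage sketch of the cells above (not part of utils.py): it assumes the BertConfig class exported by the same mindspore.model_zoo.Bert_NEZHA.bert_model module, a Momentum optimizer, and a synthetic batch; the batch size, sequence length, learning rate, and loss-scale value are placeholders.

import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig  # assumed config location

# Illustrative config; all other BertConfig fields are left at their defaults.
cfg = BertConfig(batch_size=16, seq_length=128)

# Network-with-loss, wrapped by the finetune cell defined above.
netwithloss = BertCLS(cfg, is_training=True, num_labels=2, dropout_prob=0.1)
optimizer = nn.Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer)
netwithgrads.set_train(True)

# One training step on a synthetic batch; a real run would iterate over a dataset
# and typically pass a scale_update_cell instead of a fixed sens value.
input_ids = Tensor(np.zeros((16, 128), np.int32))
input_mask = Tensor(np.ones((16, 128), np.int32))
token_type_id = Tensor(np.zeros((16, 128), np.int32))
label_ids = Tensor(np.zeros((16,), np.int32))
loss, overflow = netwithgrads(input_ids, input_mask, token_type_id, label_ids,
                              Tensor(1024.0, mstype.float32))

The explicit final argument supplies the loss-scale sens because no scale_update_cell was given; with a dynamic loss-scale manager the cell updates the scale itself when an overflow is detected.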