You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

utils.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. '''
  16. Functional Cells used in Bert finetune and evaluation.
  17. '''
  18. import mindspore.nn as nn
  19. from mindspore.common.initializer import TruncatedNormal
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import functional as F
  22. from mindspore.ops import composite as C
  23. from mindspore.common.tensor import Tensor
  24. from mindspore.common.parameter import Parameter, ParameterTuple
  25. from mindspore.common import dtype as mstype
  26. from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
  27. from mindspore.train.parallel_utils import ParallelMode
  28. from mindspore.communication.management import get_group_size
  29. from mindspore import context
  30. from .bert_model import BertModel
  31. from .bert_for_pre_training import clip_grad
  32. from .CRF import CRF
# Gradient clipping configuration shared by the finetune cells below.
# GRADIENT_CLIP_TYPE = 1 selects clip-by-norm in `clip_grad` (imported above) —
# TODO confirm against bert_for_pre_training.clip_grad.
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Undo loss scaling: multiply a gradient by the reciprocal of `scale`."""
    return grad * reciprocal(scale)
# Overflow detection used on the GPU path, where NPU float-status ops are
# unavailable: P.FloatStatus returns a flag tensor indicating NaN/Inf.
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    """Return the float status (NaN/Inf indicator) for one gradient tensor."""
    return grad_overflow(grad)
class BertFinetuneCell(nn.Cell):
    """
    Specifically defined for finetuning where only four input tensors are needed.

    Wraps a loss network with dynamic loss scaling, gradient clipping,
    optional distributed gradient reduction, and overflow detection.
    Returns (loss, overflow_cond) and skips the optimizer step on overflow.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertFinetuneCell, self).__init__(auto_prefix=False)
        # The wrapped network must return a scalar loss for the 4 inputs.
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        # Gradient w.r.t. the weight list, with an explicit sens (loss-scale) input.
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            # Mean/degree configure how gradients are averaged across devices.
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.gpu_target = False
        if context.get_context("device_target") == "GPU":
            # GPU overflow check: per-gradient FloatStatus summed via AddN.
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            # Ascend/NPU overflow check via hardware float-status registers.
            self.alloc_status = P.NPUAllocFloatStatus()
            self.get_status = P.NPUGetFloatStatus()
            self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        """One finetune step; returns (loss, overflow_cond)."""
        weights = self.weights
        init = False
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            # No explicit sens: use the managed dynamic loss scale.
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        if not self.gpu_target:
            # Allocate and clear the NPU float-status area before the backward
            # pass; control_depend pins the ordering in graph mode.
            init = self.alloc_status()
            clear_before_grad = self.clear_before_grad(init)
            F.control_depend(loss, init)
            self.depend_parameter_use(clear_before_grad, scaling_sens)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # Unscale, then clip gradients.
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        if not self.gpu_target:
            # Read the float status written during backward; non-zero => overflow.
            flag = self.get_status(init)
            flag_sum = self.reduce_sum(init, (0,))
            F.control_depend(grads, flag)
            F.control_depend(flag, flag_sum)
        else:
            # GPU: derive overflow flag from the gradients themselves.
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            flag_sum = self.reshape(flag_sum, (()))
        if self.is_distributed:
            # Any device overflowing means the whole step overflowed.
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            # Let the loss-scale manager update the scale and decide overflow.
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            # Skip the parameter update on overflow.
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return F.depend(ret, succ)
class BertSquadCell(nn.Cell):
    """
    Specifically defined for SQuAD finetuning, where seven input tensors are needed.

    Same loss-scaling/overflow-skip wrapper as BertFinetuneCell, but for the
    SQuAD loss network (Ascend/NPU float-status path only).
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertSquadCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        # Gradient w.r.t. the weight list, with an explicit sens (loss-scale) input.
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        # NPU float-status ops for hardware overflow detection.
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  start_position,
                  end_position,
                  unique_id,
                  is_impossible,
                  sens=None):
        """One SQuAD finetune step; returns (loss, overflow_cond)."""
        weights = self.weights
        init = self.alloc_status()
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            start_position,
                            end_position,
                            unique_id,
                            is_impossible)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 start_position,
                                                 end_position,
                                                 unique_id,
                                                 is_impossible,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # NOTE(review): clear-status appears after the grad call here; actual
        # execution order in graph mode is pinned by the control_depend edges
        # below, not by statement order. Do not reorder without verifying.
        clear_before_grad = self.clear_before_grad(init)
        F.control_depend(loss, init)
        self.depend_parameter_use(clear_before_grad, scaling_sens)
        # Unscale, then clip gradients.
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        # Read the float status written during backward; non-zero => overflow.
        flag = self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # Any device overflowing means the whole step overflowed.
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        F.control_depend(grads, flag)
        F.control_depend(flag, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            # Skip the parameter update on overflow.
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return F.depend(ret, succ)
  234. class BertCLSModel(nn.Cell):
  235. """
  236. This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
  237. LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final
  238. logits as the results of log_softmax is propotional to that of softmax.
  239. """
  240. def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
  241. super(BertCLSModel, self).__init__()
  242. self.bert = BertModel(config, is_training, use_one_hot_embeddings)
  243. self.cast = P.Cast()
  244. self.weight_init = TruncatedNormal(config.initializer_range)
  245. self.log_softmax = P.LogSoftmax(axis=-1)
  246. self.dtype = config.dtype
  247. self.num_labels = num_labels
  248. self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
  249. has_bias=True).to_float(config.compute_type)
  250. self.dropout = nn.Dropout(1 - dropout_prob)
  251. def construct(self, input_ids, input_mask, token_type_id):
  252. _, pooled_output, _ = \
  253. self.bert(input_ids, token_type_id, input_mask)
  254. cls = self.cast(pooled_output, self.dtype)
  255. cls = self.dropout(cls)
  256. logits = self.dense_1(cls)
  257. logits = self.cast(logits, self.dtype)
  258. log_probs = self.log_softmax(logits)
  259. return log_probs
  260. class BertSquadModel(nn.Cell):
  261. '''
  262. This class is responsible for SQuAD
  263. '''
  264. def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
  265. super(BertSquadModel, self).__init__()
  266. self.bert = BertModel(config, is_training, use_one_hot_embeddings)
  267. self.weight_init = TruncatedNormal(config.initializer_range)
  268. self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init,
  269. has_bias=True).to_float(config.compute_type)
  270. self.num_labels = num_labels
  271. self.dtype = config.dtype
  272. self.log_softmax = P.LogSoftmax(axis=1)
  273. self.is_training = is_training
  274. def construct(self, input_ids, input_mask, token_type_id):
  275. sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
  276. batch_size, seq_length, hidden_size = P.Shape()(sequence_output)
  277. sequence = P.Reshape()(sequence_output, (-1, hidden_size))
  278. logits = self.dense1(sequence)
  279. logits = P.Cast()(logits, self.dtype)
  280. logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels))
  281. logits = self.log_softmax(logits)
  282. return logits
  283. class BertNERModel(nn.Cell):
  284. """
  285. This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
  286. The returned output represents the final logits as the results of log_softmax is propotional to that of softmax.
  287. """
  288. def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
  289. use_one_hot_embeddings=False):
  290. super(BertNERModel, self).__init__()
  291. self.bert = BertModel(config, is_training, use_one_hot_embeddings)
  292. self.cast = P.Cast()
  293. self.weight_init = TruncatedNormal(config.initializer_range)
  294. self.log_softmax = P.LogSoftmax(axis=-1)
  295. self.dtype = config.dtype
  296. self.num_labels = num_labels
  297. self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
  298. has_bias=True).to_float(config.compute_type)
  299. self.dropout = nn.Dropout(1 - dropout_prob)
  300. self.reshape = P.Reshape()
  301. self.shape = (-1, config.hidden_size)
  302. self.use_crf = use_crf
  303. self.origin_shape = (config.batch_size, config.seq_length, self.num_labels)
  304. def construct(self, input_ids, input_mask, token_type_id):
  305. sequence_output, _, _ = \
  306. self.bert(input_ids, token_type_id, input_mask)
  307. seq = self.dropout(sequence_output)
  308. seq = self.reshape(seq, self.shape)
  309. logits = self.dense_1(seq)
  310. logits = self.cast(logits, self.dtype)
  311. if self.use_crf:
  312. return_value = self.reshape(logits, self.origin_shape)
  313. else:
  314. return_value = self.log_softmax(logits)
  315. return return_value
  316. class CrossEntropyCalculation(nn.Cell):
  317. """
  318. Cross Entropy loss
  319. """
  320. def __init__(self, is_training=True):
  321. super(CrossEntropyCalculation, self).__init__()
  322. self.onehot = P.OneHot()
  323. self.on_value = Tensor(1.0, mstype.float32)
  324. self.off_value = Tensor(0.0, mstype.float32)
  325. self.reduce_sum = P.ReduceSum()
  326. self.reduce_mean = P.ReduceMean()
  327. self.reshape = P.Reshape()
  328. self.last_idx = (-1,)
  329. self.neg = P.Neg()
  330. self.cast = P.Cast()
  331. self.is_training = is_training
  332. def construct(self, logits, label_ids, num_labels):
  333. if self.is_training:
  334. label_ids = self.reshape(label_ids, self.last_idx)
  335. one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
  336. per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
  337. loss = self.reduce_mean(per_example_loss, self.last_idx)
  338. return_value = self.cast(loss, mstype.float32)
  339. else:
  340. return_value = logits * 1.0
  341. return return_value
  342. class BertCLS(nn.Cell):
  343. """
  344. Train interface for classification finetuning task.
  345. """
  346. def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
  347. super(BertCLS, self).__init__()
  348. self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
  349. self.loss = CrossEntropyCalculation(is_training)
  350. self.num_labels = num_labels
  351. def construct(self, input_ids, input_mask, token_type_id, label_ids):
  352. log_probs = self.bert(input_ids, input_mask, token_type_id)
  353. loss = self.loss(log_probs, label_ids, self.num_labels)
  354. return loss
  355. class BertNER(nn.Cell):
  356. """
  357. Train interface for sequence labeling finetuning task.
  358. """
  359. def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0,
  360. use_one_hot_embeddings=False):
  361. super(BertNER, self).__init__()
  362. self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings)
  363. if use_crf:
  364. if not tag_to_index:
  365. raise Exception("The dict for tag-index mapping should be provided for CRF.")
  366. self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training)
  367. else:
  368. self.loss = CrossEntropyCalculation(is_training)
  369. self.num_labels = num_labels
  370. self.use_crf = use_crf
  371. def construct(self, input_ids, input_mask, token_type_id, label_ids):
  372. logits = self.bert(input_ids, input_mask, token_type_id)
  373. if self.use_crf:
  374. loss = self.loss(logits, label_ids)
  375. else:
  376. loss = self.loss(logits, label_ids, self.num_labels)
  377. return loss
  378. class BertSquad(nn.Cell):
  379. '''
  380. Train interface for SQuAD finetuning task.
  381. '''
  382. def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
  383. super(BertSquad, self).__init__()
  384. self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
  385. self.loss = CrossEntropyCalculation(is_training)
  386. self.num_labels = num_labels
  387. self.seq_length = config.seq_length
  388. self.is_training = is_training
  389. self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num')
  390. self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num')
  391. self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num')
  392. self.sum = P.ReduceSum()
  393. self.equal = P.Equal()
  394. self.argmax = P.ArgMaxWithValue(axis=1)
  395. self.squeeze = P.Squeeze(axis=-1)
  396. def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible):
  397. logits = self.bert(input_ids, input_mask, token_type_id)
  398. if self.is_training:
  399. unstacked_logits_0 = self.squeeze(logits[:, :, 0:1])
  400. unstacked_logits_1 = self.squeeze(logits[:, :, 1:2])
  401. start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length)
  402. end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length)
  403. total_loss = (start_loss + end_loss) / 2.0
  404. else:
  405. start_logits = self.squeeze(logits[:, :, 0:1])
  406. end_logits = self.squeeze(logits[:, :, 1:2])
  407. total_loss = (unique_id, start_logits, end_logits)
  408. return total_loss