
bert_model.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""
import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .fused_layer_norm import FusedLayerNorm


class BertConfig:
    """
    Configuration for `BertModel`.

    Args:
        batch_size (int): Batch size of input dataset.
        seq_length (int): Length of input sequence. Default: 128.
        vocab_size (int): Size of the vocabulary of token embeddings. Default: 32000.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder cell. Default: 12.
        num_attention_heads (int): Number of attention heads in the BertTransformer encoder cell. Default: 12.
        intermediate_size (int): Size of intermediate layer in the BertTransformer encoder cell. Default: 3072.
        hidden_act (str): Activation function used in the BertTransformer encoder cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this model. Default: 512.
        type_vocab_size (int): Size of token type vocab. Default: 16.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        input_mask_from_dataset (bool): Specifies whether to use the input mask that is loaded from
            the dataset. Default: True.
        token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that are loaded
            from the dataset. Default: True.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        enable_fused_layernorm (bool): Specifies whether to use fused layernorm. Default: False.
    """
    def __init__(self,
                 batch_size,
                 seq_length=128,
                 vocab_size=32000,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 input_mask_from_dataset=True,
                 token_type_ids_from_dataset=True,
                 dtype=mstype.float32,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.input_mask_from_dataset = input_mask_from_dataset
        self.token_type_ids_from_dataset = token_type_ids_from_dataset
        self.use_relative_positions = use_relative_positions
        self.dtype = dtype
        self.compute_type = compute_type
        self.enable_fused_layernorm = enable_fused_layernorm
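
# A minimal construction sketch (illustrative values, not part of the original
# file); any argument not shown falls back to the defaults documented above:
#
#   bert_cfg = BertConfig(batch_size=32,
#                         seq_length=128,
#                         vocab_size=21128,
#                         hidden_size=768,
#                         num_hidden_layers=12,
#                         num_attention_heads=12,
#                         compute_type=mstype.float16)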


class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
            each embedding vector.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 embedding_shape,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embedding_table = Parameter(initializer(TruncatedNormal(initializer_range),
                                                     [vocab_size, embedding_size]),
                                         name='embedding_table')
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.GatherV2()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)

    def construct(self, input_ids):
        extended_ids = self.expand(input_ids, -1)
        flat_ids = self.reshape(extended_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, self.shape)
        return output, self.embedding_table
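
# Shape sketch (assumed toy sizes, not from the original file):
#
#   lookup = EmbeddingLookup(vocab_size=32000, embedding_size=768,
#                            embedding_shape=[32, 128, 768])
#   ids = Tensor(np.ones((32, 128)), mstype.int32)
#   out, table = lookup(ids)
#   # out:   (32, 128, 768)  -- ids gathered from the table and reshaped
#   # table: (32000, 768)    -- the embedding table itself is also returned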


class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessor that applies positional and token type embeddings to word embeddings.

    Args:
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
            each embedding vector.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
        token_type_vocab_size (int): Size of token type vocab. Default: 16.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        max_position_embeddings (int): Maximum length of sequences used in this model. Default: 512.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """
    def __init__(self,
                 embedding_size,
                 embedding_shape,
                 use_relative_positions=False,
                 use_token_type=False,
                 token_type_vocab_size=16,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 max_position_embeddings=512,
                 dropout_prob=0.1):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = use_token_type
        self.token_type_vocab_size = token_type_vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.max_position_embeddings = max_position_embeddings
        self.embedding_table = Parameter(initializer(TruncatedNormal(initializer_range),
                                                     [token_type_vocab_size, embedding_size]),
                                         name='embedding_table')
        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        # Off value of the one-hot encoding; 0.0 yields a proper one-hot matrix for the
        # matmul-based lookup (the 0.1 in the original file looks like a typo).
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)
        self.layernorm = nn.LayerNorm((embedding_size,))
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.gather = P.GatherV2()
        self.use_relative_positions = use_relative_positions
        self.slice = P.StridedSlice()
        self.full_position_embeddings = Parameter(initializer(TruncatedNormal(initializer_range),
                                                              [max_position_embeddings, embedding_size]),
                                                  name='full_position_embeddings')

    def construct(self, token_type_ids, word_embeddings):
        output = word_embeddings
        if self.use_token_type:
            flat_ids = self.reshape(token_type_ids, self.shape_flat)
            if self.use_one_hot_embeddings:
                one_hot_ids = self.one_hot(flat_ids, self.token_type_vocab_size,
                                           self.on_value, self.off_value)
                token_type_embeddings = self.array_mul(one_hot_ids, self.embedding_table)
            else:
                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output


class BertOutput(nn.Cell):
    """
    Apply a linear computation to the hidden state and a residual connection to the input.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        dropout_prob (float): The dropout probability. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 initializer_range=0.02,
                 dropout_prob=0.1,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.dropout_prob = dropout_prob
        self.add = P.TensorAdd()
        if compute_type == mstype.float16:
            self.layernorm = FusedLayerNorm((out_channels,),
                                            use_batch_norm=enable_fused_layernorm).to_float(compute_type)
        else:
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        return output
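
# Data-flow summary of the block above:
#   output = LayerNorm(Dropout(Dense(hidden_status)) + input_tensor)
# i.e. the standard projection + residual + post-layernorm pattern reused by
# both the attention and feed-forward sub-layers of the encoder.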


class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates matrix of relative positions between inputs.

    Args:
        length (int): Length of one dim for the matrix to be generated.
        max_relative_position (int): Max value of relative position.
    """
    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        self.range_length = -length + 1
        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self):
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
        tile_row_out = self.tile(range_vec_row_out, (self._length,))
        tile_col_out = self.tile(range_vec_col_out, (1, self._length))
        range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)
        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)
        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat
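
# Equivalent NumPy sketch of construct() above (illustrative only):
#
#   r = np.arange(length)
#   distance = r[None, :] - r[:, None]                  # distance[i, j] = j - i
#   clipped = np.clip(distance, -max_relative_position, max_relative_position)
#   final = clipped + max_relative_position             # shifted to be >= 0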


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates tensor of size [length, length, depth].

    Args:
        length (int): Length of one dim for the matrix to be generated.
        depth (int): Size of each attention head.
        max_relative_position (int): Maximum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """
    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]),
            name='embeddings_for_position')
        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.shape = P.Shape()
        self.gather = P.GatherV2()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self):
        relative_positions_matrix_out = self.relative_positions_matrix()
        # Generate an embedding for each relative position of dimension depth.
        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a safe saturating cast. It clamps the input to the representable range of the
    destination type before casting, so the value cannot overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
    """
    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)
        min_type = np.finfo(np_type).min
        max_type = np.finfo(np_type).max
        self.tensor_min_type = Tensor([min_type], dtype=src_type)
        self.tensor_max_type = Tensor([max_type], dtype=src_type)
        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)
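
# Equivalent NumPy sketch (illustrative): clamping before an fp32 -> fp16 cast
# keeps large activations from turning into inf.
#
#   info = np.finfo(np.float16)
#   y = np.clip(x, info.min, info.max).astype(np.float16)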


class BertAttention(nn.Cell):
    """
    Apply multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        batch_size (int): Batch size of input datasets.
        from_tensor_width (int): Size of last dim of from_tensor.
        to_tensor_width (int): Size of last dim of to_tensor.
        from_seq_length (int): Length of from_tensor sequence.
        to_seq_length (int): Length of to_tensor sequence.
        num_attention_heads (int): Number of attention heads. Default: 1.
        size_per_head (int): Size of each attention head. Default: 512.
        query_act (str): Activation function for the query transform. Default: None.
        key_act (str): Activation function for the key transform. Default: None.
        value_act (str): Activation function for the value transform. Default: None.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.0.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        do_return_2d_tensor (bool): True to return a 2d tensor; False to return a 3d tensor. Default: False.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 has_attention_mask=False,
                 attention_probs_dropout_prob=0.0,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 compute_type=mstype.float32):
        super(BertAttention, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
        self.use_relative_positions = use_relative_positions
        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, from_tensor_width)
        self.shape_to_2d = (-1, to_tensor_width)
        weight = TruncatedNormal(initializer_range)
        units = num_attention_heads * size_per_head
        self.query_layer = nn.Dense(from_tensor_width,
                                    units,
                                    activation=query_act,
                                    weight_init=weight).to_float(compute_type)
        self.key_layer = nn.Dense(to_tensor_width,
                                  units,
                                  activation=key_act,
                                  weight_init=weight).to_float(compute_type)
        self.value_layer = nn.Dense(to_tensor_width,
                                    units,
                                    activation=value_act,
                                    weight_init=weight).to_float(compute_type)
        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
        self.batch_num = batch_size * num_attention_heads
        self.matmul = P.BatchMatMul()
        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.TensorAdd()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if do_return_2d_tensor:
            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
        else:
            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=to_seq_length,
                                           depth=size_per_head,
                                           max_relative_position=16,
                                           initializer_range=initializer_range,
                                           use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)
        value_out = self.value_layer(to_tensor_2d)
        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)
        attention_scores = self.matmul_trans_b(query_layer, key_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # 'relations_keys' = [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          self.batch_num,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  self.batch_size,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t
        attention_scores = self.multiply(self.scores_mul, attention_scores)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))
            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)
        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # 'relations_values' = [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 self.batch_num,
                 self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    self.batch_size,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t
        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)
        return context_layer
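
# Summary of the computation above in the usual scaled dot-product form.
# The relative_* terms appear only when use_relative_positions is True, and
# the mask term only when has_attention_mask is True:
#
#   scores  = (Q @ K^T + relative_key_scores) / sqrt(size_per_head)
#   scores += (1 - attention_mask) * -10000.0
#   probs   = dropout(softmax(scores))
#   context = probs @ V + relative_value_scores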


class BertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        batch_size (int): Batch size of input dataset.
        seq_length (int): Length of input sequence.
        hidden_size (int): Size of the bert encoder layers.
        num_attention_heads (int): Number of attention heads. Default: 12.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 seq_length,
                 hidden_size,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (hidden_size, num_attention_heads))
        self.size_per_head = int(hidden_size / num_attention_heads)
        self.attention = BertAttention(
            batch_size=batch_size,
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
            from_seq_length=seq_length,
            to_seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            size_per_head=self.size_per_head,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            use_relative_positions=use_relative_positions,
            has_attention_mask=True,
            do_return_2d_tensor=True,
            compute_type=compute_type)
        self.output = BertOutput(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)

    def construct(self, input_tensor, attention_mask):
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output


class BertEncoderCell(nn.Cell):
    """
    Encoder cells used in BertTransformer.

    Args:
        batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        seq_length (int): Length of input sequence. Default: 512.
        num_attention_heads (int): Number of attention heads. Default: 12.
        intermediate_size (int): Size of intermediate layer. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.02.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 hidden_size=768,
                 seq_length=512,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.02,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertEncoderCell, self).__init__()
        self.attention = BertSelfAttention(
            batch_size=batch_size,
            hidden_size=hidden_size,
            seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            use_relative_positions=use_relative_positions,
            compute_type=compute_type,
            enable_fused_layernorm=enable_fused_layernorm)
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=hidden_act,
                                     weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.output = BertOutput(in_channels=intermediate_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)

    def construct(self, hidden_states, attention_mask):
        # self-attention
        attention_output = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output
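
# Per-layer data flow, matching construct() above:
#   x -> BertSelfAttention (attention + residual + layernorm)
#     -> Dense(hidden_size -> intermediate_size, gelu)
#     -> BertOutput (projection back to hidden_size + residual + layernorm)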


class BertTransformer(nn.Cell):
    """
    Multi-layer bert transformer.

    Args:
        batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the encoder layers.
        seq_length (int): Length of input sequence.
        num_hidden_layers (int): Number of hidden layers in encoder cells.
        num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
        intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
    """
    def __init__(self,
                 batch_size,
                 hidden_size,
                 seq_length,
                 num_hidden_layers,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 return_all_encoders=False,
                 enable_fused_layernorm=False):
        super(BertTransformer, self).__init__()
        self.return_all_encoders = return_all_encoders
        layers = []
        for _ in range(num_hidden_layers):
            layer = BertEncoderCell(batch_size=batch_size,
                                    hidden_size=hidden_size,
                                    seq_length=seq_length,
                                    num_attention_heads=num_attention_heads,
                                    intermediate_size=intermediate_size,
                                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                                    use_one_hot_embeddings=use_one_hot_embeddings,
                                    initializer_range=initializer_range,
                                    hidden_dropout_prob=hidden_dropout_prob,
                                    use_relative_positions=use_relative_positions,
                                    hidden_act=hidden_act,
                                    compute_type=compute_type,
                                    enable_fused_layernorm=enable_fused_layernorm)
            layers.append(layer)
        self.layers = nn.CellList(layers)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)
        self.out_shape = (batch_size, seq_length, hidden_size)

    def construct(self, input_tensor, attention_mask):
        prev_output = self.reshape(input_tensor, self.shape)
        all_encoder_layers = ()
        for layer_module in self.layers:
            layer_output = layer_module(prev_output, attention_mask)
            prev_output = layer_output
            if self.return_all_encoders:
                layer_output = self.reshape(layer_output, self.out_shape)
                all_encoder_layers = all_encoder_layers + (layer_output,)
        if not self.return_all_encoders:
            prev_output = self.reshape(prev_output, self.out_shape)
            all_encoder_layers = all_encoder_layers + (prev_output,)
        return all_encoder_layers


class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (Class): Configuration for BertModel.
    """
    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.input_mask = None
        if not self.input_mask_from_dataset:
            self.input_mask = initializer(
                "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (config.batch_size, 1, config.seq_length)
        self.broadcast_ones = initializer(
            "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
        self.batch_matmul = P.BatchMatMul()

    def construct(self, input_mask):
        if not self.input_mask_from_dataset:
            input_mask = self.input_mask
        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        return attention_mask
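
# Shape sketch: input_mask [batch_size, seq_length] is reshaped and cast to an
# attention_mask of shape [batch_size, 1, seq_length]; BertAttention later
# expands it to [batch_size, 1, 1, seq_length] and broadcasts it against
# attention scores of shape [batch_size, num_heads, seq_length, seq_length].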


class BertModel(nn.Cell):
    """
    Bidirectional Encoder Representations from Transformers.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """
    def __init__(self,
                 config,
                 is_training,
                 use_one_hot_embeddings=False):
        super(BertModel, self).__init__()
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0
        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
        self.batch_size = config.batch_size
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None
        self.last_idx = self.num_hidden_layers - 1
        output_embedding_shape = [self.batch_size, self.seq_length, self.embedding_size]
        if not self.token_type_ids_from_dataset:
            self.token_type_ids = initializer(
                "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
        self.bert_embedding_lookup = EmbeddingLookup(
            vocab_size=config.vocab_size,
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range)
        self.bert_embedding_postprocessor = EmbeddingPostprocessor(
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_relative_positions=config.use_relative_positions,
            use_token_type=True,
            token_type_vocab_size=config.type_vocab_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)
        self.bert_encoder = BertTransformer(
            batch_size=self.batch_size,
            hidden_size=self.hidden_size,
            seq_length=self.seq_length,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range,
            hidden_dropout_prob=config.hidden_dropout_prob,
            use_relative_positions=config.use_relative_positions,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type,
            return_all_encoders=True,
            enable_fused_layernorm=config.enable_fused_layernorm)
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()
        self.squeeze_1 = P.Squeeze(axis=1)
        self.dense = nn.Dense(self.hidden_size, self.hidden_size,
                              activation="tanh",
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        # embedding
        if not self.token_type_ids_from_dataset:
            token_type_ids = self.token_type_ids
        word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
        embedding_output = self.bert_embedding_postprocessor(token_type_ids,
                                                             word_embeddings)
        # attention mask [batch_size, 1, seq_length], broadcast inside BertAttention
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)
        # bert encoder
        encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
                                           attention_mask)
        sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
        # pooler: take the hidden state of the first token (typically [CLS]) and apply a tanh dense layer
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (self.batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.dense(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)
        return sequence_output, pooled_output, embedding_tables
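
# Minimal end-to-end usage sketch (illustrative values, not part of the
# original file; assumes a MindSpore context is already configured):
#
#   config = BertConfig(batch_size=2, seq_length=128, vocab_size=32000)
#   model = BertModel(config, is_training=False)
#   ids = Tensor(np.ones((2, 128)), mstype.int32)      # token ids
#   types = Tensor(np.zeros((2, 128)), mstype.int32)   # segment ids
#   mask = Tensor(np.ones((2, 128)), mstype.int32)     # input mask
#   sequence_output, pooled_output, embedding_tables = model(ids, types, mask)
#   # sequence_output: (2, 128, 768), pooled_output: (2, 768)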