
gpt.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPT model"""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
from mindspore.common.initializer import TruncatedNormal, initializer, Normal
from mindspore.ops import operations as P
from mindspore.ops import functional as F

class LayerNorm(nn.Cell):
    """
    Layer Normalization

    Args:
        normalized_shape: the corresponding shape of the normalized axes
        eps: epsilon, a small number avoiding zero division

    Inputs:
        x: input tensor

    Returns:
        rescaled_output: Tensor, returned tensor after layernorm
    """
    def __init__(self, normalized_shape, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.gamma = Parameter(initializer('ones', normalized_shape))
        self.beta = Parameter(initializer('zeros', normalized_shape))
        self.mean = P.ReduceMean(keep_dims=True)
        self.eps = eps

    def construct(self, x):
        mean = self.mean(x, -1)
        variance = self.mean(F.square(x - mean), -1)
        output = (x - mean) / F.sqrt(variance + self.eps)
        rescaled_output = output * self.gamma + self.beta
        return rescaled_output

class Softmax(nn.Cell):
    """
    Softmax realization

    Args:
        axis: the axis along which softmax is applied

    Inputs:
        x: input tensor

    Returns:
        output: Tensor, returned tensor after softmax
    """
    def __init__(self, axis=-1):
        super(Softmax, self).__init__()
        self.max = P.ArgMaxWithValue(axis=axis, keep_dims=True)
        self.sum = P.ReduceSum(keep_dims=True)
        self.axis = axis

    def construct(self, x):
        _, max_value = self.max(x)
        exp_x = F.tensor_pow(np.e, x - max_value)
        sum_x = self.sum(exp_x, self.axis)
        output = exp_x / sum_x
        return output
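
# Note: construct above subtracts the per-axis maximum before exponentiating, the
# standard trick for numerical stability. A minimal NumPy sketch of the same
# computation (illustrative only, not part of the network):
#
#     x = np.array([1.0, 2.0, 3.0])
#     e = np.power(np.e, x - x.max())
#     probs = e / e.sum()          # -> approximately [0.090, 0.245, 0.665]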

class Mapping(nn.Cell):
    """
    A mapping function with a 3d input

    Args:
        input_size: the size of the last dimension of the input tensor
        output_size: the desired size of the last dimension of the output tensor
        dtype: the compute datatype
        scale: the scale factor for initialization

    Inputs:
        x: the 3d input

    Returns:
        output: Tensor, a 3d tensor after projection
    """
    def __init__(self, input_size, output_size, dtype, scale=1.0):
        super(Mapping, self).__init__()
        self.output_size = output_size
        self.input_size = input_size
        self.weight = Parameter(initializer(Normal(sigma=0.02*scale), [input_size, output_size]))
        self.bias = Parameter(initializer("zeros", [output_size,]))
        self.dtype = dtype
        self.cast = P.Cast()

    def construct(self, x):
        out_shape = P.Shape()(x)[:-1] + (self.output_size,)
        x = P.Reshape()(x, (-1, self.input_size))
        x = nn.MatMul()(x, self.cast(self.weight, self.dtype)) + self.cast(self.bias, self.dtype)
        output = P.Reshape()(x, out_shape)
        return output

class Output(nn.Cell):
    """
    The output mapping module for each layer

    Args:
        config(GPTConfig): the config of network
        scale: scale factor for initialization

    Inputs:
        x: output of the self-attention module

    Returns:
        output: Tensor, the output of this layer after mapping
    """
    def __init__(self, config, scale=1.0):
        super(Output, self).__init__()
        input_size = config.embedding_size
        output_size = config.embedding_size * config.expand_ratio
        self.mapping = Mapping(input_size, output_size, config.compute_dtype)
        self.projection = Mapping(output_size, input_size, config.compute_dtype, scale)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(1 - config.dropout_rate)

    def construct(self, x):
        hidden = self.activation(self.mapping(x))
        output = self.projection(hidden)
        output = self.dropout(output)
        return output

class AttentionMask(nn.Cell):
    """
    Get the attention matrix for the self-attention module

    Args:
        config(GPTConfig): the config of network

    Inputs:
        input_mask: the mask indicating whether each position is a valid input

    Returns:
        attention_mask: the attention mask matrix with shape (batch_size, seq_length, seq_length)
    """
    def __init__(self, config):
        super(AttentionMask, self).__init__()
        self.reshape = P.Reshape()
        self.mul = P.BatchMatMul()
        ones = np.ones(shape=(config.seq_length, config.seq_length))
        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
        self.multiply = P.Mul()

    def construct(self, input_mask):
        input_shape = P.Shape()(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])
        shape_left = input_shape + (1,)
        mask_left = self.reshape(input_mask, shape_left)
        mask_right = self.reshape(input_mask, shape_right)
        attention_mask = self.mul(mask_left, mask_right)
        lower_triangle = P.ExpandDims()(self.lower_triangle_mask, 0)
        attention_mask = self.multiply(attention_mask, lower_triangle)  # bs, seq_length, seq_length
        return attention_mask
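
# Illustration (not part of the network): for a single sample with input_mask
# [1, 1, 1, 0] (three valid tokens followed by padding), construct above is
# equivalent to the following NumPy sketch, which combines the padding mask
# (outer product of the mask with itself) with the causal lower-triangular mask:
#
#     m = np.array([1, 1, 1, 0], dtype=np.float32)
#     pad_mask = np.outer(m, m)                        # (seq, seq) validity mask
#     causal = np.tril(np.ones((4, 4), np.float32))    # lower-triangular mask
#     attn_mask = pad_mask * causal                    # 1 where attention is allowed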

class EmbeddingLookup(nn.Cell):
    """
    The embedding lookup table for vocabulary

    Args:
        config(GPTConfig): the config of network

    Inputs:
        input_ids: the tokenized inputs with datatype int32

    Returns:
        output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size)
        self.embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.embedding_table = Parameter(initializer(TruncatedNormal(0.02), [self.vocab_size, self.embedding_size]))
        self.gather = P.Gather()
        self.shape = (-1, config.seq_length, config.embedding_size)

    def construct(self, input_ids):
        output = self.gather(self.embedding_table, input_ids, 0)
        return output, self.embedding_table

class Attention(nn.Cell):
    """
    Self-Attention module for each layer

    Args:
        config(GPTConfig): the config of network
        scale: scale factor for initialization
        layer_idx: current layer index
    """
    def __init__(self, config, scale=1.0, layer_idx=None):
        super(Attention, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale)
        self.split = P.Split(axis=-1, output_num=3)
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.n_head = config.num_heads
        self.size_per_head = config.embedding_size // self.n_head
        self.concat_k = P.Concat(axis=3)
        self.concat_v = P.Concat(axis=2)
        self.multiply_data = Tensor([-10000.0,], dtype=mstype.float32)
        self.batch_matmul = P.BatchMatMul()
        self.scale = scale
        if self.scale:
            self.scale_factor = Tensor(math.sqrt(self.size_per_head))
        if layer_idx is not None:
            self.coeff = math.sqrt(layer_idx * math.sqrt(self.size_per_head))
            self.coeff = Tensor(self.coeff)
        self.use_past = config.use_past
        self.dropout = nn.Dropout(1 - config.dropout_rate)
        self.prob_dropout = nn.Dropout(1 - config.dropout_rate)
        self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
        self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
        self.dense3 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)

    def construct(self, x, attention_mask, layer_past=None):
        """
        self-attention

        Inputs:
            x: output of previous layer
            attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
            layer_past: the previous feature map

        Returns:
            output: Tensor, the output logit of this layer
            layer_present: Tensor, the feature map of current layer
        """
        original_shape = F.shape(x)
        x = F.reshape(x, (-1, original_shape[-1]))
        query = self.dense1(x)
        key = self.dense2(x)
        value = self.dense3(x)
        # reshape to (bs, n_head, seq_length, size_per_head); key is laid out transposed
        # so that query @ key directly yields the attention scores
        query = self.transpose(F.reshape(query, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3))
        key = self.transpose(F.reshape(key, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 3, 1))
        value = self.transpose(F.reshape(value, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3))
        if self.use_past:
            past_value = layer_past[1]
            past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
            key = self.concat_k((past_key, key))
            value = self.concat_v((past_value, value))
        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
        attention = self._attn(query, key, value, attention_mask)
        attention_merge = self.merge_heads(attention)
        output = self.projection(attention_merge)
        output = self.dropout(output)
        return output, layer_present

    def split_heads(self, x, transpose):
        """
        split 3d tensor to 4d and switch certain axes

        Inputs:
            x: input tensor
            transpose: tuple, the transpose sequence

        Returns:
            x_transpose: the 4d output
        """
        x_size = P.Shape()(x)
        new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
        x = self.reshape(x, new_x_shape)
        x_transpose = self.transpose(x, transpose)
        return x_transpose

    def merge_heads(self, x):
        """
        convert a 4d input to a 3d output

        Inputs:
            x: input tensor

        Returns:
            x_merge: the 3d output
        """
        x = self.transpose(x, (0, 2, 1, 3))  # bs, seq_length, head, size_per_head
        x_shape = P.Shape()(x)
        new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
        x_merge = self.reshape(x, new_shape)
        return x_merge

    def _attn(self, query, key, value, attention_mask):
        """
        Get the weighted score along the seq_length

        Inputs:
            query: the query matrix
            key: the key matrix
            value: the value matrix
            attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)

        Returns:
            weighted_values: Tensor, the weighted sum scores
        """
        if not self.scale:
            query = query / F.cast(self.coeff, F.dtype(query))
            key = key / F.cast(self.coeff, F.dtype(key))
        score = self.batch_matmul(query, key)
        if self.scale:
            score = score / P.Cast()(self.scale_factor, P.DType()(score))
        ori_dtype = P.DType()(score)
        score = P.Cast()(score, mstype.float32)
        # turn the {0, 1} attention mask into a large negative additive bias on
        # the masked positions before softmax
        multiply_out = P.Sub()(P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
                               P.Cast()(attention_mask, P.DType()(score)))
        adder = P.Mul()(multiply_out, self.multiply_data)
        attention_scores = adder + score
        attention_scores = P.Cast()(attention_scores, ori_dtype)
        attention_probs = Softmax()(attention_scores)
        attention_probs = self.prob_dropout(attention_probs)
        weighted_values = self.batch_matmul(attention_probs, value)
        return weighted_values

class Block(nn.Cell):
    """
    The basic block of GPT network

    Args:
        config(GPTConfig): the config of network
        layer_idx: current layer index

    Inputs:
        x: the output of previous layer (input_ids for the first layer)
        attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
        layer_past: the previous feature map

    Returns:
        output: Tensor, the output logit of this layer
        layer_present: Tensor, the feature map of current layer
    """
    def __init__(self, config, layer_idx):
        super(Block, self).__init__()
        scale = 1.0
        self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
        self.attention = Attention(config, scale, layer_idx)
        self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
        self.output = Output(config, scale)
        self.post_layernorm_residual = config.post_layernorm_residual

    def construct(self, x, attention_mask, layer_past=None):
        """basic block of each layer"""
        input_x = self.layernorm1(x)
        attention, layer_present = self.attention(input_x, attention_mask, layer_past)
        if self.post_layernorm_residual:
            x = input_x + attention
        else:
            x = x + attention
        output_x = self.layernorm2(x)
        mlp_logit = self.output(output_x)
        if self.post_layernorm_residual:
            output = output_x + mlp_logit
        else:
            output = x + mlp_logit
        return output, layer_present
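
# The post_layernorm_residual flag selects between the two residual layouts
# implemented in construct above (a rough data-flow sketch, illustrative only):
#
#     flag False: x -> LN -> Attention -> (+ x)     -> LN -> MLP -> (+ x)
#     flag True:  x -> LN -> Attention -> (+ LN(x)) -> LN -> MLP -> (+ LN(x))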

class GPT_Model(nn.Cell):
    """
    The backbone of GPT network

    Args:
        config(GPTConfig): the config of network

    Inputs:
        input_ids: the tokenized inputs with datatype int32
        input_mask: the mask indicating whether each position is a valid input
        layer_past: the previous feature map

    Returns:
        output_state: Tensor, the output logit of backbone
        present_layer: Tensor, the current feature map
        embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
        super(GPT_Model, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        self.word_embedding = EmbeddingLookup(config)
        self.position_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                               embedding_table=TruncatedNormal(0.02))
        self.blocks = nn.CellList()
        for i in range(config.num_layers):
            self.blocks.append(Block(config, i + 1))
        self.layernorm = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
        self.use_past = config.use_past
        self.past = tuple([None] * config.num_layers)
        self.num_layers = config.num_layers

    def construct(self, input_ids, input_mask, layer_past=None):
        """GPT model"""
        if not self.use_past:
            layer_past = self.past
        input_embedding, embedding_table = self.word_embedding(input_ids)
        batch_size, seq_length = F.shape(input_ids)
        input_position = F.tuple_to_array(F.make_range(seq_length))
        input_position = P.Tile()(input_position, (batch_size, 1))
        position_embedding = self.position_embedding(input_position)
        hidden_states = input_embedding + position_embedding
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        attention_mask = self.get_attention_mask(input_mask)
        attention_mask = P.ExpandDims()(attention_mask, 1)
        present_layer = ()
        for i in range(self.num_layers):
            hidden_states, present = self.blocks[i](hidden_states, attention_mask, layer_past)
            present_layer = present_layer + (present,)
        output_state = self.layernorm(hidden_states)
        return output_state, present_layer, embedding_table

class GPT_Head(nn.Cell):
    """
    Head for GPT to get the logits of each token in the vocab

    Args:
        config(GPTConfig): the config of network

    Inputs:
        state: the output of the backbone
        embedding_table: the embedding table of the vocabulary

    Returns:
        logits: Tensor, the logits of the corresponding inputs
    """
    def __init__(self, config):
        super(GPT_Head, self).__init__()
        self.matmul = P.MatMul(transpose_b=True)
        self.embedding_size = config.embedding_size
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.compute_dtype
        self.cast = P.Cast()

    def construct(self, state, embedding_table):
        state = P.Reshape()(state, (-1, self.embedding_size))
        logits = self.matmul(state, self.cast(embedding_table, self.dtype))
        return logits

class GPT(nn.Cell):
    """
    The GPT network, consisting of two parts: the backbone and the head

    Args:
        config(GPTConfig): the config of network

    Inputs:
        input_ids: the tokenized inputs
        input_mask: the mask indicating whether each position is a valid input
        past: the previous feature map

    Returns:
        logits: Tensor, the logits of the corresponding inputs with shape (batch_size * seq_length, vocab_size)
    """
    def __init__(self, config):
        super(GPT, self).__init__()
        self.backbone = GPT_Model(config)
        self.head = GPT_Head(config)

    def construct(self, input_ids, input_mask, past=None):
        output_states, _, embedding_table = self.backbone(input_ids, input_mask, past)
        logits = self.head(output_states, embedding_table)
        return logits
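
# A minimal usage sketch (illustrative only; it assumes a GPTConfig object with the
# fields referenced in this file, e.g. vocab_size, seq_length, embedding_size,
# num_layers, num_heads, expand_ratio, dropout_rate, compute_dtype, use_past and
# post_layernorm_residual, defined elsewhere in the repository):
#
#     config = GPTConfig(...)                    # hypothetical construction
#     model = GPT(config)
#     input_ids = Tensor(np.ones((2, config.seq_length)), mstype.int32)
#     input_mask = Tensor(np.ones((2, config.seq_length)), mstype.float32)
#     logits = model(input_ids, input_mask)      # (batch_size * seq_length, vocab_size)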

class CrossEntropyLoss(nn.Cell):
    """
    Calculate the cross entropy loss

    Args:
        config(GPTConfig): the config of the network

    Inputs:
        logits: the output logits of the backbone
        label: the ground truth label of the sample
        input_mask: the mask indicating whether each position is a valid input

    Returns:
        loss: Tensor, the corresponding cross entropy loss
    """
    def __init__(self, config):
        super(CrossEntropyLoss, self).__init__()
        self.log_softmax = nn.LogSoftmax(axis=-1)
        self.mean = P.ReduceMean()
        self.sum = P.ReduceSum()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.vocab_size = config.vocab_size

    def construct(self, logits, label, input_mask):
        logits = self.log_softmax(P.Cast()(logits, mstype.float32))
        label = P.Reshape()(label, (-1,))
        one_hot_label = self.onehot(label, self.vocab_size, self.on_value, self.off_value)
        loss_sum = P.Neg()(self.sum(logits * one_hot_label, (-1,)))
        input_mask = P.Reshape()(input_mask, (-1,))
        numerator = self.sum(loss_sum * input_mask)
        denominator = self.sum(input_mask) + P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32)
        loss = numerator / denominator
        return loss

class GPTWithLoss(nn.Cell):
    """
    GPT training loss

    Args:
        network: backbone network of GPT2/3
        loss: loss function, e.g., cross entropy
        eos_token: the end_of_sentence token

    Inputs:
        input_ids: the tokenized inputs
        past: the previous feature map

    Returns:
        output: Tensor, the loss of the network
    """
    def __init__(self, network, loss, eos_token=50256):
        super(GPTWithLoss, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
        self.eos_token = eos_token

    def construct(self, input_ids, past=None):
        tokens = input_ids[:, :-1]
        input_mask = F.cast(F.not_equal(tokens, self.eos_token), mstype.float32)
        logits = self.network(tokens, input_mask, past)
        labels = input_ids[:, 1:]
        output = self.loss(logits, labels, input_mask)
        return output
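
# The language-modelling target is the input shifted by one position: tokens[:, t]
# predicts input_ids[:, t + 1]. A tiny illustration (not part of the network):
#
#     input_ids = [[11, 22, 33, 44]]
#     tokens    = [[11, 22, 33]]       # fed to the network
#     labels    = [[22, 33, 44]]       # compared against the logits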

class EvalNet(nn.Cell):
    """
    GPT evaluation net

    Args:
        backbone: backbone network of GPT2/3
        generate: enable generate mode

    Inputs:
        input_ids: the tokenized inputs

    Returns:
        outputs: Tensor, corresponding output for different tasks
    """
    def __init__(self, backbone, generate=False):
        super(EvalNet, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.argmax = P.Argmax()
        self.generate = generate

    def construct(self, input_ids):
        """evaluation net"""
        input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
        logits = self.backbone(input_ids, input_mask)
        outputs = None
        if self.generate:
            outputs = nn.LogSoftmax()(logits)
            outputs = F.tensor_pow(np.e, outputs)
        else:
            outputs = self.argmax(logits)
        return outputs
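
# A minimal end-to-end wiring sketch (illustrative only; GPTConfig and the
# optimizer/training wrappers are assumed to come from the rest of the repository):
#
#     config = GPTConfig(...)                      # hypothetical construction
#     gpt = GPT(config)
#     loss_fn = CrossEntropyLoss(config)
#     net_with_loss = GPTWithLoss(gpt, loss_fn)
#     input_ids = Tensor(np.ones((2, config.seq_length + 1)), mstype.int32)
#     loss = net_with_loss(input_ids)              # scalar training loss
#
#     eval_net = EvalNet(gpt, generate=True)       # per-token probabilities for generation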