
test_bert_cell.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ test bert of graph compile """
  16. import functools
  17. import numpy as np
  18. import mindspore.common.dtype as mstype
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
  21. EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
  22. RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
  23. BertEncoderCell, BertTransformer, CreateAttentionMaskFromInputMask, BertModel
  24. from mindspore.nn.layer.basic import Norm
  25. from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
  26. import mindspore.nn as nn
  27. from mindspore.common.initializer import TruncatedNormal
  28. from mindspore.common.parameter import ParameterTuple
  29. from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR
  30. from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
  31. import mindspore.ops.composite as C
  32. from mindspore.ops import functional as F
  33. from ....ops_common import convert
  34. from ....mindspore_test_framework.mindspore_test import mindspore_test
  35. from ....mindspore_test_framework.pipeline.forward.compile_forward import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
  36. from ....mindspore_test_framework.pipeline.gradient.compile_gradient import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
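
# Each entry in the case table below follows the mindspore_test_framework
# case-by-case convention: a (name, config) pair where 'block' is the Cell
# under test, 'desc_inputs' lists the inputs (plain shape lists are
# materialized as random tensors by the framework, explicit Tensor objects
# are passed through as-is), 'desc_bprop' gives the shapes of the sense
# values fed to the backward pass, 'num_output' records how many outputs
# the block returns, and 'skip' names pipeline stages to leave out.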


def bert_trans():
    """bert_trans"""
    net = BertTransformer(batch_size=1,
                          hidden_size=768,
                          seq_length=128,
                          num_hidden_layers=1,
                          num_attention_heads=12,
                          intermediate_size=768,
                          attention_probs_dropout_prob=0.1,
                          use_one_hot_embeddings=False,
                          initializer_range=0.02,
                          use_relative_positions=False,
                          hidden_act="gelu",
                          compute_type=mstype.float32,
                          return_all_encoders=True)
    net.set_train()
    return net


def set_train(net):
    net.set_train()
    return net
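
# set_train() switches the Cell (and its children) into training mode, so
# layers such as dropout behave as they would during pre-training rather
# than as identity ops.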


class NetForAdam(nn.Cell):
    def __init__(self):
        super(NetForAdam, self).__init__()
        self.dense = nn.Dense(64, 10)

    def construct(self, x):
        x = self.dense(x)
        return x


class TrainStepWrapForAdam(nn.Cell):
    """TrainStepWrapForAdam definition"""

    def __init__(self, network):
        super(TrainStepWrapForAdam, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = AdamWeightDecay(self.weights)
        self.clip_gradients = ClipGradients()

    def construct(self, x, sens):
        weights = self.weights
        grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
        grads = self.clip_gradients(grads, 1, 1.0)
        return self.optimizer(grads)
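
# C.grad_by_list_with_sens builds a gradient function that differentiates
# self.network with respect to the listed weights, seeding the backward pass
# with the externally supplied sens tensor; the resulting gradients are
# clipped (clip type 1, bound 1.0) before the AdamWeightDecay update.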


class TrainStepWrapForAdamDynamicLr(nn.Cell):
    """TrainStepWrapForAdamDynamicLr definition"""

    def __init__(self, network):
        super(TrainStepWrapForAdamDynamicLr, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = AdamWeightDecayDynamicLR(self.weights, 10)
        self.sens = Tensor(np.ones(shape=(1, 10)).astype(np.float32))

    def construct(self, x):
        weights = self.weights
        grads = C.grad_by_list_with_sens(self.network, weights)(x, self.sens)
        return self.optimizer(grads)
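
# Same wiring as TrainStepWrapForAdam, but with AdamWeightDecayDynamicLR
# (here 10 decay steps) and a fixed all-ones sens tensor baked into the
# Cell, so construct only takes the network input.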


class TempC2Wrap(nn.Cell):
    def __init__(self, op, c1=None, c2=None):
        super(TempC2Wrap, self).__init__()
        self.op = op
        self.c1 = c1
        self.c2 = c2

    def construct(self, x1):
        x = self.op(x1, self.c1, self.c2)
        return x
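
# TempC2Wrap binds two constants as the second and third call arguments, so
# a three-argument block such as ClipGradients(grads, clip_type, clip_value)
# can be driven through the single-input case table below.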


test_case_cell_ops = [
    ('Norm_keepdims', {
        'block': Norm(keep_dims=True),
        'desc_inputs': [[1, 3, 4, 4]],
        'desc_bprop': [[1]]}),
    ('SaturateCast', {
        'block': SaturateCast(),
        'desc_inputs': [[1, 3, 4, 4]],
        'desc_bprop': [[1, 3, 4, 4]]}),
    ('RelaPosMatrixGenerator_0', {
        'block': RelaPosMatrixGenerator(length=128, max_relative_position=16),
        'desc_inputs': [],
        'desc_bprop': [[128, 128]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_0', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
                                            max_relative_position=16,
                                            initializer_range=0.2),
        'desc_inputs': [],
        'desc_bprop': [[16384, 512]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_1', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
                                            max_relative_position=16,
                                            initializer_range=0.2,
                                            use_one_hot_embeddings=False),
        'desc_inputs': [],
        'desc_bprop': [[128, 128, 512]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_2', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=64,
                                            max_relative_position=16,
                                            initializer_range=0.2,
                                            use_one_hot_embeddings=False),
        'desc_inputs': [],
        'desc_bprop': [[128, 128, 64]],
        'skip': ['backward']}),
    ('BertAttention_0', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=False,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_1', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_2', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=False,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_3', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=False,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
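    # The [8192, 768] sense shape used for the attention cases corresponds to
    # batch_size (64) x seq_length (128) = 8192 flattened rows of hidden
    # size 768.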
    ('BertOutput', {
        'block': BertOutput(in_channels=768,
                            out_channels=768,
                            initializer_range=0.02,
                            dropout_prob=0.1),
        'desc_inputs': [[8192, 768], [8192, 768]],
        'desc_bprop': [[8192, 768]]}),
    ('BertSelfAttention_0', {
        'block': BertSelfAttention(batch_size=64,
                                   seq_length=128,
                                   hidden_size=768,
                                   num_attention_heads=12,
                                   attention_probs_dropout_prob=0.1,
                                   use_one_hot_embeddings=False,
                                   initializer_range=0.02,
                                   hidden_dropout_prob=0.1,
                                   use_relative_positions=False,
                                   compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertEncoderCell', {
        'block': BertEncoderCell(batch_size=64,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.02,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 hidden_dropout_prob=0.1,
                                 use_relative_positions=False,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertTransformer_0', {
        'block': BertTransformer(batch_size=1,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_hidden_layers=1,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.1,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 use_relative_positions=False,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32,
                                 return_all_encoders=True),
        'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
    ('BertTransformer_1', {
        'block': BertTransformer(batch_size=64,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_hidden_layers=2,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.1,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 use_relative_positions=True,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32,
                                 return_all_encoders=False),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]]}),
    ('EmbeddingLookup', {
        'block': EmbeddingLookup(vocab_size=32000,
                                 embedding_size=768,
                                 embedding_shape=[1, 128, 768],
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32))],
        'desc_bprop': [[1, 128, 768], [1, 128, 768]],
        'num_output': 2}),
    ('EmbeddingPostprocessor', {
        'block': EmbeddingPostprocessor(embedding_size=768,
                                        embedding_shape=[1, 128, 768],
                                        use_token_type=True,
                                        token_type_vocab_size=16,
                                        use_one_hot_embeddings=False,
                                        initializer_range=0.02,
                                        max_position_embeddings=512,
                                        dropout_prob=0.1),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), [1, 128, 768]],
        'desc_bprop': [[1, 128, 768]]}),
    ('CreateAttentionMaskFromInputMask', {
        'block': CreateAttentionMaskFromInputMask(config=BertConfig(batch_size=1)),
        'desc_inputs': [[128]],
        'desc_bprop': [[1, 128, 128]]}),
    ('BertOutput_0', {
        'block': BertOutput(in_channels=768,
                            out_channels=768,
                            initializer_range=0.02,
                            dropout_prob=0.1),
        'desc_inputs': [[1, 768], [1, 768]],
        'desc_bprop': [[1, 768]]}),
    ('BertTransformer_2', {
        'block': bert_trans(),
        'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
    ('BertModel', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False),
                           is_training=True),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertModel_1', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False),
                           is_training=False),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertModel_2', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False,
                                             input_mask_from_dataset=False),
                           is_training=True),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertPretrainingLoss', {
        'block': BertPretrainingLoss(config=BertConfig(batch_size=1)),
        'desc_inputs': [[32000], [20, 2], Tensor(np.array([1]).astype(np.int32)),
                        [20], Tensor(np.array([20]).astype(np.int32))],
        'desc_bprop': [[1]],
        'num_output': 1}),
    ('Dense_1', {
        'block': nn.Dense(in_channels=768,
                          out_channels=3072,
                          activation='gelu',
                          weight_init=TruncatedNormal(0.02)),
        'desc_inputs': [[3, 768]],
        'desc_bprop': [[3, 3072]]}),
    ('Dense_2', {
        'block': set_train(nn.Dense(in_channels=768,
                                    out_channels=3072,
                                    activation='gelu',
                                    weight_init=TruncatedNormal(0.02))),
        'desc_inputs': [[3, 768]],
        'desc_bprop': [[3, 3072]]}),
    ('GetNextSentenceOutput', {
        'block': GetNextSentenceOutput(BertConfig(batch_size=1)),
        'desc_inputs': [[128, 768]],
        'desc_bprop': [[128, 2]]}),
    ('Adam_1', {
        'block': set_train(TrainStepWrapForAdam(NetForAdam())),
        'desc_inputs': [[1, 64], [1, 10]],
        'skip': ['backward']}),
    ('Adam_2', {
        'block': set_train(TrainStepWrapForAdam(GetNextSentenceOutput(BertConfig(batch_size=1)))),
        'desc_inputs': [[128, 768], [128, 2]],
        'skip': ['backward']}),
    ('AdamWeightDecayDynamicLR', {
        'block': set_train(TrainStepWrapForAdamDynamicLr(NetForAdam())),
        'desc_inputs': [[1, 64]],
        'skip': ['backward']}),
    ('ClipGradients', {
        'block': TempC2Wrap(ClipGradients(), 1, 1.0),
        'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
        'skip': ['backward', 'exec']}),
]
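
# functools.reduce over a single-element list of lists is a no-op
# concatenation here; it is kept so further case tables can be appended to
# the reduction later.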
test_case = functools.reduce(lambda x, y: x + y, [test_case_cell_ops])
# use -k to select certain test cases
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
test_exec_case = filter(lambda x: 'skip' not in x[1] or
                        'exec' not in x[1]['skip'], test_case)
test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
                                 'backward' not in x[1]['skip'] and 'backward_exec'
                                 not in x[1]['skip'], test_case)
test_check_gradient_case = filter(lambda x: 'skip' not in x[1] or
                                  'backward' not in x[1]['skip'] and 'backward_exec'
                                  not in x[1]['skip'], test_case)
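
# Note the operator precedence in the filters above: 'and' binds tighter
# than 'or', so each lambda reads "the case has no 'skip' key, or neither
# 'backward' nor 'backward_exec' is listed in it". test_check_gradient_case
# is currently identical to test_backward_exec_case and is not consumed by
# any decorated test below.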


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    return test_exec_case


@mindspore_test(pipeline_for_compile_grad_ge_graph_for_case_by_case_config)
def test_backward_exec():
    return test_backward_exec_case
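
# A minimal invocation sketch (the exact path depends on where this file
# sits in the repository checkout):
#   pytest test_bert_cell.py::test_exec -k BertAttention
# The mindspore_test decorator consumes the returned case list and compiles
# a forward (or gradient) GE graph for each case in turn.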