
test_bert_cell.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ test bert of graph compile """
  16. import functools
  17. import numpy as np
  18. import mindspore.common.dtype as mstype
  19. import mindspore.nn as nn
  20. import mindspore.ops.composite as C
  21. from mindspore.common.initializer import TruncatedNormal
  22. from mindspore.common.parameter import ParameterTuple
  23. from mindspore.common.tensor import Tensor
  24. from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
  25. from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import ClipGradients
  26. from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
  27. EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
  28. RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
  29. BertEncoderCell, BertTransformer, CreateAttentionMaskFromInputMask, BertModel
  30. from mindspore.nn.layer.basic import Norm
  31. from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR
  32. from ....mindspore_test_framework.mindspore_test import mindspore_test
  33. from ....mindspore_test_framework.pipeline.forward.compile_forward import \
  34. pipeline_for_compile_forward_ge_graph_for_case_by_case_config
  35. from ....mindspore_test_framework.pipeline.gradient.compile_gradient import \
  36. pipeline_for_compile_grad_ge_graph_for_case_by_case_config
  37. from ....ops_common import convert
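

# The helpers and wrapper cells below feed the `test_case_cell_ops` table;
# the mindspore_test pipelines at the bottom of the file compile each case's
# forward and backward graphs on the GE backend.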
def bert_trans():
    """bert_trans"""
    net = BertTransformer(batch_size=1,
                          hidden_size=768,
                          seq_length=128,
                          num_hidden_layers=1,
                          num_attention_heads=12,
                          intermediate_size=768,
                          attention_probs_dropout_prob=0.1,
                          use_one_hot_embeddings=False,
                          initializer_range=0.02,
                          use_relative_positions=False,
                          hidden_act="gelu",
                          compute_type=mstype.float32,
                          return_all_encoders=True)
    net.set_train()
    return net
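

# Switch a cell into training mode (enabling dropout) and return it, so it
# can be called inline when building entries in the case table below.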
def set_train(net):
    net.set_train()
    return net


class NetForAdam(nn.Cell):
    def __init__(self):
        super(NetForAdam, self).__init__()
        self.dense = nn.Dense(64, 10)

    def construct(self, x):
        x = self.dense(x)
        return x


class TrainStepWrapForAdam(nn.Cell):
    """TrainStepWrapForAdam definition"""

    def __init__(self, network):
        super(TrainStepWrapForAdam, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = AdamWeightDecay(self.weights)
        self.clip_gradients = ClipGradients()

    def construct(self, x, sens):
        weights = self.weights
        grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
        grads = self.clip_gradients(grads, 1, 1.0)
        return self.optimizer(grads)


class TrainStepWrapForAdamDynamicLr(nn.Cell):
    """TrainStepWrapForAdamDynamicLr definition"""

    def __init__(self, network):
        super(TrainStepWrapForAdamDynamicLr, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = AdamWeightDecayDynamicLR(self.weights, 10)
        self.sens = Tensor(np.ones(shape=(1, 10)).astype(np.float32))

    def construct(self, x):
        weights = self.weights
        grads = C.grad_by_list_with_sens(self.network, weights)(x, self.sens)
        return self.optimizer(grads)


class TempC2Wrap(nn.Cell):
    def __init__(self, op, c1=None, c2=None):
        super(TempC2Wrap, self).__init__()
        self.op = op
        self.c1 = c1
        self.c2 = c2

    def construct(self, x1):
        x = self.op(x1, self.c1, self.c2)
        return x
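

# Each entry below is a (name, spec) pair:
#   'block'       - the cell/operator under test
#   'desc_inputs' - input descriptions: bare shape lists or ready-made Tensors
#   'desc_bprop'  - shapes of the sense (gradient) inputs for the backward pass
#   'skip'        - test phases to skip for this case (e.g. 'backward', 'exec')
#   'num_output'  - number of outputs the block returns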
test_case_cell_ops = [
    ('Norm_keepdims', {
        'block': Norm(keep_dims=True),
        'desc_inputs': [[1, 3, 4, 4]],
        'desc_bprop': [[1]]}),
    ('SaturateCast', {
        'block': SaturateCast(),
        'desc_inputs': [[1, 3, 4, 4]],
        'desc_bprop': [[1, 3, 4, 4]]}),
    ('RelaPosMatrixGenerator_0', {
        'block': RelaPosMatrixGenerator(length=128, max_relative_position=16),
        'desc_inputs': [],
        'desc_bprop': [[128, 128]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_0', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
                                            max_relative_position=16,
                                            initializer_range=0.2),
        'desc_inputs': [],
        'desc_bprop': [[16384, 512]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_1', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
                                            max_relative_position=16,
                                            initializer_range=0.2,
                                            use_one_hot_embeddings=False),
        'desc_inputs': [],
        'desc_bprop': [[128, 128, 512]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_2', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=64,
                                            max_relative_position=16,
                                            initializer_range=0.2,
                                            use_one_hot_embeddings=False),
        'desc_inputs': [],
        'desc_bprop': [[128, 128, 64]],
        'skip': ['backward']}),
    ('BertAttention_0', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=False,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_1', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_2', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=False,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_3', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=False,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertOutput', {
        'block': BertOutput(in_channels=768,
                            out_channels=768,
                            initializer_range=0.02,
                            dropout_prob=0.1),
        'desc_inputs': [[8192, 768], [8192, 768]],
        'desc_bprop': [[8192, 768]]}),
    ('BertSelfAttention_0', {
        'block': BertSelfAttention(batch_size=64,
                                   seq_length=128,
                                   hidden_size=768,
                                   num_attention_heads=12,
                                   attention_probs_dropout_prob=0.1,
                                   use_one_hot_embeddings=False,
                                   initializer_range=0.02,
                                   hidden_dropout_prob=0.1,
                                   use_relative_positions=False,
                                   compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertEncoderCell', {
        'block': BertEncoderCell(batch_size=64,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.02,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 hidden_dropout_prob=0.1,
                                 use_relative_positions=False,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertTransformer_0', {
        'block': BertTransformer(batch_size=1,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_hidden_layers=1,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.1,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 use_relative_positions=False,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32,
                                 return_all_encoders=True),
        'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
    ('BertTransformer_1', {
        'block': BertTransformer(batch_size=64,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_hidden_layers=2,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.1,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 use_relative_positions=True,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32,
                                 return_all_encoders=False),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]]}),
    ('EmbeddingLookup', {
        'block': EmbeddingLookup(vocab_size=32000,
                                 embedding_size=768,
                                 embedding_shape=[1, 128, 768],
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32))],
        'desc_bprop': [[1, 128, 768], [1, 128, 768]],
        'num_output': 2}),
    ('EmbeddingPostprocessor', {
        'block': EmbeddingPostprocessor(embedding_size=768,
                                        embedding_shape=[1, 128, 768],
                                        use_token_type=True,
                                        token_type_vocab_size=16,
                                        use_one_hot_embeddings=False,
                                        initializer_range=0.02,
                                        max_position_embeddings=512,
                                        dropout_prob=0.1),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), [1, 128, 768]],
        'desc_bprop': [[1, 128, 768]]}),
    ('CreateAttentionMaskFromInputMask', {
        'block': CreateAttentionMaskFromInputMask(config=BertConfig(batch_size=1)),
        'desc_inputs': [[128]],
        'desc_bprop': [[1, 128, 128]]}),
    ('BertOutput_0', {
        'block': BertOutput(in_channels=768,
                            out_channels=768,
                            initializer_range=0.02,
                            dropout_prob=0.1),
        'desc_inputs': [[1, 768], [1, 768]],
        'desc_bprop': [[1, 768]]}),
    ('BertTransformer_2', {
        'block': bert_trans(),
        'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
    ('BertModel', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False),
                           is_training=True),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertModel_1', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False),
                           is_training=False),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertModel_2', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False,
                                             input_mask_from_dataset=False),
                           is_training=True),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertPretrainingLoss', {
        'block': BertPretrainingLoss(config=BertConfig(batch_size=1)),
        'desc_inputs': [[32000], [20, 2], Tensor(np.array([1]).astype(np.int32)),
                        [20], Tensor(np.array([20]).astype(np.int32))],
        'desc_bprop': [[1]],
        'num_output': 1}),
    ('Dense_1', {
        'block': nn.Dense(in_channels=768,
                          out_channels=3072,
                          activation='gelu',
                          weight_init=TruncatedNormal(0.02)),
        'desc_inputs': [[3, 768]],
        'desc_bprop': [[3, 3072]]}),
    ('Dense_2', {
        'block': set_train(nn.Dense(in_channels=768,
                                    out_channels=3072,
                                    activation='gelu',
                                    weight_init=TruncatedNormal(0.02))),
        'desc_inputs': [[3, 768]],
        'desc_bprop': [[3, 3072]]}),
    ('GetNextSentenceOutput', {
        'block': GetNextSentenceOutput(BertConfig(batch_size=1)),
        'desc_inputs': [[128, 768]],
        'desc_bprop': [[128, 2]]}),
    ('Adam_1', {
        'block': set_train(TrainStepWrapForAdam(NetForAdam())),
        'desc_inputs': [[1, 64], [1, 10]],
        'skip': ['backward']}),
    ('Adam_2', {
        'block': set_train(TrainStepWrapForAdam(GetNextSentenceOutput(BertConfig(batch_size=1)))),
        'desc_inputs': [[128, 768], [128, 2]],
        'skip': ['backward']}),
    ('AdamWeightDecayDynamicLR', {
        'block': set_train(TrainStepWrapForAdamDynamicLr(NetForAdam())),
        'desc_inputs': [[1, 64]],
        'skip': ['backward']}),
    ('ClipGradients', {
        'block': TempC2Wrap(ClipGradients(), 1, 1.0),
        'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
        'skip': ['backward', 'exec']}),
]
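
# Concatenate all case tables into a single list (only one table is defined here).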
test_case = functools.reduce(lambda x, y: x + y, [test_case_cell_ops])

# use -k to select certain test cases, e.g.
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
test_exec_case = filter(lambda x: 'skip' not in x[1] or
                        'exec' not in x[1]['skip'], test_case)
test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
                                 'backward' not in x[1]['skip'] and 'backward_exec'
                                 not in x[1]['skip'], test_case)
test_check_gradient_case = filter(lambda x: 'skip' not in x[1] or
                                  'backward' not in x[1]['skip'] and 'backward_exec'
                                  not in x[1]['skip'], test_case)
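# NOTE: test_check_gradient_case duplicates the predicate of
# test_backward_exec_case and is not consumed by any pipeline in this file.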


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    return test_exec_case


@mindspore_test(pipeline_for_compile_grad_ge_graph_for_case_by_case_config)
def test_backward_exec():
    return test_backward_exec_case