
test_bert_cell.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test bert of graph compile """
import functools

import numpy as np

import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
    EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
    RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
    BertEncoderCell, BertTransformer, CreateAttentionMaskFromInputMask, BertModel
from mindspore.nn.layer.basic import Norm
from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR

from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward import \
    pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.gradient.compile_gradient import \
    pipeline_for_compile_grad_ge_graph_for_case_by_case_config
from ....ops_common import convert

def bert_trans():
    """bert_trans"""
    net = BertTransformer(batch_size=1,
                          hidden_size=768,
                          seq_length=128,
                          num_hidden_layers=1,
                          num_attention_heads=12,
                          intermediate_size=768,
                          attention_probs_dropout_prob=0.1,
                          use_one_hot_embeddings=False,
                          initializer_range=0.02,
                          use_relative_positions=False,
                          hidden_act="gelu",
                          compute_type=mstype.float32,
                          return_all_encoders=True)
    net.set_train()
    return net

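# Helper that switches a cell into training mode and returns it, so test cases
# below can wrap networks inline when the case table is built.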
def set_train(net):
    net.set_train()
    return net

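# Minimal single-layer dense network used as the target for the Adam-based
# train-step wrappers defined below.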
class NetForAdam(nn.Cell):
    def __init__(self):
        super(NetForAdam, self).__init__()
        self.dense = nn.Dense(64, 10)

    def construct(self, x):
        x = self.dense(x)
        return x

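# Wraps a network with AdamWeightDecay: gradients are taken w.r.t. the parameter
# list with an explicit sensitivity input, clipped via clip_grad (clip type 1,
# clip value 1.0), and then applied by the optimizer.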
class TrainStepWrapForAdam(nn.Cell):
    """TrainStepWrapForAdam definition"""

    def __init__(self, network):
        super(TrainStepWrapForAdam, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = AdamWeightDecay(self.weights)
        self.hyper_map = C.HyperMap()

    def construct(self, x, sens):
        weights = self.weights
        grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
        grads = self.hyper_map(F.partial(clip_grad, 1, 1.0), grads)
        return self.optimizer(grads)

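# Same idea as TrainStepWrapForAdam, but with the dynamic-learning-rate optimizer
# variant and a fixed all-ones sensitivity tensor instead of a sens input.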
class TrainStepWrapForAdamDynamicLr(nn.Cell):
    """TrainStepWrapForAdamDynamicLr definition"""

    def __init__(self, network):
        super(TrainStepWrapForAdamDynamicLr, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = AdamWeightDecayDynamicLR(self.weights, 10)
        self.sens = Tensor(np.ones(shape=(1, 10)).astype(np.float32))

    def construct(self, x):
        weights = self.weights
        grads = C.grad_by_list_with_sens(self.network, weights)(x, self.sens)
        return self.optimizer(grads)

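# Thin wrapper that binds two constants (c1, c2) to an op and maps it over the
# inputs with HyperMap; used below to compile clip_grad in isolation.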
class TempC2Wrap(nn.Cell):
    def __init__(self, op, c1=None, c2=None, ):
        super(TempC2Wrap, self).__init__()
        self.op = op
        self.c1 = c1
        self.c2 = c2
        self.hyper_map = C.HyperMap()

    def construct(self, x1):
        x = self.hyper_map(F.partial(self.op, self.c1, self.c2), x1)
        return x

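# Each case maps a name to a dict consumed by the case-by-case pipelines:
# 'block' is the cell under test, 'desc_inputs' lists input shapes or Tensors,
# 'desc_bprop' lists sensitivity shapes for the backward pass, 'skip' excludes
# stages, and 'num_output' is the number of outputs the block returns.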
test_case_cell_ops = [
    ('Norm_keepdims', {
        'block': Norm(keep_dims=True),
        'desc_inputs': [[1, 3, 4, 4]],
        'desc_bprop': [[1]]}),
    ('SaturateCast', {
        'block': SaturateCast(),
        'desc_inputs': [[1, 3, 4, 4]],
        'desc_bprop': [[1, 3, 4, 4]]}),
    ('RelaPosMatrixGenerator_0', {
        'block': RelaPosMatrixGenerator(length=128, max_relative_position=16),
        'desc_inputs': [],
        'desc_bprop': [[128, 128]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_0', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
                                            max_relative_position=16,
                                            initializer_range=0.2),
        'desc_inputs': [],
        'desc_bprop': [[16384, 512]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_1', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
                                            max_relative_position=16,
                                            initializer_range=0.2,
                                            use_one_hot_embeddings=False),
        'desc_inputs': [],
        'desc_bprop': [[128, 128, 512]],
        'skip': ['backward']}),
    ('RelaPosEmbeddingsGenerator_2', {
        'block': RelaPosEmbeddingsGenerator(length=128, depth=64,
                                            max_relative_position=16,
                                            initializer_range=0.2,
                                            use_one_hot_embeddings=False),
        'desc_inputs': [],
        'desc_bprop': [[128, 128, 64]],
        'skip': ['backward']}),
    ('BertAttention_0', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=False,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_1', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_2', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=False,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=True,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertAttention_3', {
        'block': BertAttention(batch_size=64,
                               from_tensor_width=768,
                               to_tensor_width=768,
                               from_seq_length=128,
                               to_seq_length=128,
                               num_attention_heads=12,
                               size_per_head=64,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               has_attention_mask=True,
                               attention_probs_dropout_prob=0.1,
                               use_one_hot_embeddings=False,
                               initializer_range=0.02,
                               do_return_2d_tensor=False,
                               use_relative_positions=True,
                               compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertOutput', {
        'block': BertOutput(in_channels=768,
                            out_channels=768,
                            initializer_range=0.02,
                            dropout_prob=0.1),
        'desc_inputs': [[8192, 768], [8192, 768]],
        'desc_bprop': [[8192, 768]]}),
    ('BertSelfAttention_0', {
        'block': BertSelfAttention(batch_size=64,
                                   seq_length=128,
                                   hidden_size=768,
                                   num_attention_heads=12,
                                   attention_probs_dropout_prob=0.1,
                                   use_one_hot_embeddings=False,
                                   initializer_range=0.02,
                                   hidden_dropout_prob=0.1,
                                   use_relative_positions=False,
                                   compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertEncoderCell', {
        'block': BertEncoderCell(batch_size=64,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.02,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 hidden_dropout_prob=0.1,
                                 use_relative_positions=False,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]],
        'desc_bprop': [[8192, 768]]}),
    ('BertTransformer_0', {
        'block': BertTransformer(batch_size=1,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_hidden_layers=1,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.1,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 use_relative_positions=False,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32,
                                 return_all_encoders=True),
        'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
    ('BertTransformer_1', {
        'block': BertTransformer(batch_size=64,
                                 hidden_size=768,
                                 seq_length=128,
                                 num_hidden_layers=2,
                                 num_attention_heads=12,
                                 intermediate_size=768,
                                 attention_probs_dropout_prob=0.1,
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02,
                                 use_relative_positions=True,
                                 hidden_act="gelu",
                                 compute_type=mstype.float32,
                                 return_all_encoders=False),
        'desc_inputs': [[64, 128, 768], [64, 128, 128]]}),
    ('EmbeddingLookup', {
        'block': EmbeddingLookup(vocab_size=32000,
                                 embedding_size=768,
                                 embedding_shape=[1, 128, 768],
                                 use_one_hot_embeddings=False,
                                 initializer_range=0.02),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32))],
        'desc_bprop': [[1, 128, 768], [1, 128, 768]],
        'num_output': 2}),
    ('EmbeddingPostprocessor', {
        'block': EmbeddingPostprocessor(embedding_size=768,
                                        embedding_shape=[1, 128, 768],
                                        use_token_type=True,
                                        token_type_vocab_size=16,
                                        use_one_hot_embeddings=False,
                                        initializer_range=0.02,
                                        max_position_embeddings=512,
                                        dropout_prob=0.1),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), [1, 128, 768]],
        'desc_bprop': [[1, 128, 768]]}),
    ('CreateAttentionMaskFromInputMask', {
        'block': CreateAttentionMaskFromInputMask(config=BertConfig(batch_size=1)),
        'desc_inputs': [[128]],
        'desc_bprop': [[1, 128, 128]]}),
    ('BertOutput_0', {
        'block': BertOutput(in_channels=768,
                            out_channels=768,
                            initializer_range=0.02,
                            dropout_prob=0.1),
        'desc_inputs': [[1, 768], [1, 768]],
        'desc_bprop': [[1, 768]]}),
    ('BertTransformer_2', {
        'block': bert_trans(),
        'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
    ('BertModel', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False),
                           is_training=True),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertModel_1', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False),
                           is_training=False),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertModel_2', {
        'block': BertModel(config=BertConfig(batch_size=1,
                                             num_hidden_layers=1,
                                             intermediate_size=768,
                                             token_type_ids_from_dataset=False,
                                             input_mask_from_dataset=False),
                           is_training=True),
        'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                        Tensor(np.random.rand(128).astype(np.int32)), [128]],
        'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
        'num_output': 3}),
    ('BertPretrainingLoss', {
        'block': BertPretrainingLoss(config=BertConfig(batch_size=1)),
        'desc_inputs': [[32000], [20, 2], Tensor(np.array([1]).astype(np.int32)),
                        [20], Tensor(np.array([20]).astype(np.int32))],
        'desc_bprop': [[1]],
        'num_output': 1}),
    ('Dense_1', {
        'block': nn.Dense(in_channels=768,
                          out_channels=3072,
                          activation='gelu',
                          weight_init=TruncatedNormal(0.02)),
        'desc_inputs': [[3, 768]],
        'desc_bprop': [[3, 3072]]}),
    ('Dense_2', {
        'block': set_train(nn.Dense(in_channels=768,
                                    out_channels=3072,
                                    activation='gelu',
                                    weight_init=TruncatedNormal(0.02), )),
        'desc_inputs': [[3, 768]],
        'desc_bprop': [[3, 3072]]}),
    ('GetNextSentenceOutput', {
        'block': GetNextSentenceOutput(BertConfig(batch_size=1)),
        'desc_inputs': [[128, 768]],
        'desc_bprop': [[128, 2]]}),
    ('Adam_1', {
        'block': set_train(TrainStepWrapForAdam(NetForAdam())),
        'desc_inputs': [[1, 64], [1, 10]],
        'skip': ['backward']}),
    ('Adam_2', {
        'block': set_train(TrainStepWrapForAdam(GetNextSentenceOutput(BertConfig(batch_size=1)))),
        'desc_inputs': [[128, 768], [128, 2]],
        'skip': ['backward']}),
    ('AdamWeightDecayDynamicLR', {
        'block': set_train(TrainStepWrapForAdamDynamicLr(NetForAdam())),
        'desc_inputs': [[1, 64]],
        'skip': ['backward']}),
    ('ClipGradients', {
        'block': TempC2Wrap(clip_grad, 1, 1.0),
        'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
        'skip': ['backward', 'exec']}),
]

test_case = functools.reduce(lambda x, y: x + y, [test_case_cell_ops])
# use -k to select certain testcase
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
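# Cases tagged with 'skip' are filtered out of the corresponding execution stage.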
test_exec_case = filter(lambda x: 'skip' not in x[1] or
                        'exec' not in x[1]['skip'], test_case)

test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
                                 'backward' not in x[1]['skip'] and 'backward_exec'
                                 not in x[1]['skip'], test_case)

test_check_gradient_case = filter(lambda x: 'skip' not in x[1] or
                                  'backward' not in x[1]['skip'] and 'backward_exec'
                                  not in x[1]['skip'], test_case)

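# The mindspore_test pipelines compile each selected case as a GE graph:
# forward compilation for test_exec, gradient compilation for test_backward_exec.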
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    return test_exec_case


@mindspore_test(pipeline_for_compile_grad_ge_graph_for_case_by_case_config)
def test_backward_exec():
    return test_backward_exec_case