You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_strategy_checkpoint.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from mindspore import context
  16. from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
  17. import mindspore.nn as nn
  18. from mindspore.ops import operations as P
  19. from mindspore import Tensor, Parameter
  20. from tests.ut.python.ops.test_math_ops import VirtualLoss
  21. import mindspore as ms
  22. from mindspore.common.api import _executor
  23. from mindspore.ops import composite as C
  24. # model_parallel test
  25. def test_six_matmul_save():
  26. class NetWithLoss(nn.Cell):
  27. def __init__(self, network):
  28. super(NetWithLoss, self).__init__()
  29. self.loss = VirtualLoss()
  30. self.network = network
  31. def construct(self, x1, x6):
  32. predict = self.network(x1, x6)
  33. return self.loss(predict)
  34. class GradWrap(nn.Cell):
  35. def __init__(self, network):
  36. super(GradWrap, self).__init__()
  37. self.network = network
  38. def construct(self, x1, x6):
  39. return C.grad_all(self.network)(x1, x6)
  40. class Net(nn.Cell):
  41. def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
  42. super().__init__()
  43. self.matmul1 = P.MatMul().set_strategy(strategy1)
  44. self.matmul2 = P.MatMul().set_strategy(strategy2)
  45. self.matmul3 = P.MatMul().set_strategy(strategy3)
  46. self.matmul4 = P.MatMul().set_strategy(strategy4)
  47. self.matmul5 = P.MatMul().set_strategy(strategy5)
  48. self.matmul6 = P.MatMul().set_strategy(strategy6)
  49. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  50. self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
  51. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  52. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  53. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  54. def construct(self, x1, x6):
  55. out = self.matmul1(x1, self.weight1)
  56. out = self.matmul2(out, self.weight2)
  57. out = self.matmul3(out, self.weight3)
  58. out = self.matmul4(out, self.weight4)
  59. out = self.matmul5(out, self.weight5)
  60. out = self.matmul6(out, x6)
  61. return out
  62. reset_auto_parallel_context()
  63. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
  64. strategy1 = ((8, 1), (1, 1))
  65. strategy2 = ((1, 8), (8, 1))
  66. strategy3 = ((2, 2), (2, 2))
  67. strategy4 = ((1, 1), (1, 8))
  68. strategy5 = ((4, 2), (2, 1))
  69. strategy6 = ((4, 1), (1, 2))
  70. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
  71. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  72. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  73. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  74. _executor.compile(net, x1, x6)
  75. # remove matmul2, add matmul7
  76. def test_six_matmul_load():
  77. class NetWithLoss(nn.Cell):
  78. def __init__(self, network):
  79. super(NetWithLoss, self).__init__()
  80. self.loss = VirtualLoss()
  81. self.network = network
  82. def construct(self, x1, x6, x7):
  83. predict = self.network(x1, x6, x7)
  84. return self.loss(predict)
  85. class GradWrap(nn.Cell):
  86. def __init__(self, network):
  87. super(GradWrap, self).__init__()
  88. self.network = network
  89. def construct(self, x1, x6, x7):
  90. return C.grad_all(self.network)(x1, x6, x7)
  91. class Net(nn.Cell):
  92. def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
  93. super().__init__()
  94. self.matmul1 = P.MatMul().set_strategy(strategy1)
  95. self.matmul3 = P.MatMul().set_strategy(strategy3)
  96. self.matmul4 = P.MatMul().set_strategy(strategy4)
  97. self.matmul5 = P.MatMul().set_strategy(strategy5)
  98. self.matmul6 = P.MatMul().set_strategy(strategy6)
  99. self.matmul7 = P.MatMul().set_strategy(strategy7)
  100. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  101. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  102. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  103. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  104. def construct(self, x1, x6, x7):
  105. out = self.matmul1(x1, self.weight1)
  106. out = self.matmul3(out, self.weight3)
  107. out = self.matmul4(out, self.weight4)
  108. out = self.matmul5(out, self.weight5)
  109. out = self.matmul6(out, x6)
  110. out = self.matmul7(out, x7)
  111. return out
  112. reset_auto_parallel_context()
  113. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
  114. strategy1 = ((8, 1), (1, 1))
  115. strategy3 = ((8, 1), (1, 1))
  116. strategy4 = ((8, 1), (1, 1))
  117. strategy5 = ((8, 1), (1, 1))
  118. strategy6 = ((8, 1), (1, 1))
  119. strategy7 = ((8, 1), (1, 1))
  120. net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
  121. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  122. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  123. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  124. x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  125. _executor.compile(net, x1, x6, x7)
  126. # model_parallel test
  127. def test_six_matmul_save_auto():
  128. class NetWithLoss(nn.Cell):
  129. def __init__(self, network):
  130. super(NetWithLoss, self).__init__()
  131. self.loss = VirtualLoss()
  132. self.network = network
  133. def construct(self, x1, x6):
  134. predict = self.network(x1, x6)
  135. return self.loss(predict)
  136. class GradWrap(nn.Cell):
  137. def __init__(self, network):
  138. super(GradWrap, self).__init__()
  139. self.network = network
  140. def construct(self, x1, x6):
  141. return C.grad_all(self.network)(x1, x6)
  142. class Net(nn.Cell):
  143. def __init__(self):
  144. super().__init__()
  145. self.matmul1 = P.MatMul()
  146. self.matmul2 = P.MatMul()
  147. self.matmul3 = P.MatMul()
  148. self.matmul4 = P.MatMul()
  149. self.matmul5 = P.MatMul()
  150. self.matmul6 = P.MatMul()
  151. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  152. self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
  153. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  154. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  155. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  156. def construct(self, x1, x6):
  157. out = self.matmul1(x1, self.weight1)
  158. out = self.matmul2(out, self.weight2)
  159. out = self.matmul3(out, self.weight3)
  160. out = self.matmul4(out, self.weight4)
  161. out = self.matmul5(out, self.weight5)
  162. out = self.matmul6(out, x6)
  163. return out
  164. reset_auto_parallel_context()
  165. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
  166. net = GradWrap(NetWithLoss(Net()))
  167. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  168. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  169. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  170. _executor.compile(net, x1, x6)
  171. # remove matmul2, add matmul7
  172. def test_six_matmul_load_auto():
  173. class NetWithLoss(nn.Cell):
  174. def __init__(self, network):
  175. super(NetWithLoss, self).__init__()
  176. self.loss = VirtualLoss()
  177. self.network = network
  178. def construct(self, x1, x6, x7):
  179. predict = self.network(x1, x6, x7)
  180. return self.loss(predict)
  181. class GradWrap(nn.Cell):
  182. def __init__(self, network):
  183. super(GradWrap, self).__init__()
  184. self.network = network
  185. def construct(self, x1, x6, x7):
  186. return C.grad_all(self.network)(x1, x6, x7)
  187. class Net(nn.Cell):
  188. def __init__(self, strategy1, strategy3, strategy4, strategy5):
  189. super().__init__()
  190. self.matmul1 = P.MatMul().set_strategy(strategy1)
  191. self.matmul3 = P.MatMul().set_strategy(strategy3)
  192. self.matmul4 = P.MatMul().set_strategy(strategy4)
  193. self.matmul5 = P.MatMul().set_strategy(strategy5)
  194. self.matmul6 = P.MatMul()
  195. self.matmul7 = P.MatMul()
  196. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  197. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  198. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  199. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  200. def construct(self, x1, x6, x7):
  201. out = self.matmul1(x1, self.weight1)
  202. out = self.matmul3(out, self.weight3)
  203. out = self.matmul4(out, self.weight4)
  204. out = self.matmul5(out, self.weight5)
  205. out = self.matmul6(out, x6)
  206. out = self.matmul7(out, x7)
  207. return out
  208. reset_auto_parallel_context()
  209. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
  210. strategy1 = ((2, 2), (2, 2))
  211. strategy3 = ((2, 2), (2, 2))
  212. strategy4 = ((2, 2), (2, 2))
  213. strategy5 = ((2, 2), (2, 2))
  214. net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
  215. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  216. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  217. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  218. x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  219. _executor.compile(net, x1, x6, x7)