You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_strategy_checkpoint.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from mindspore import context
  16. from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
  17. import mindspore.nn as nn
  18. from mindspore.ops import operations as P
  19. from mindspore import Tensor, Parameter
  20. from tests.ut.python.ops.test_math_ops import VirtualLoss
  21. import mindspore as ms
  22. from mindspore.common.api import _executor
  23. from mindspore.ops import composite as C
  24. # model_parallel test
  25. def test_six_matmul_save():
  26. class NetWithLoss(nn.Cell):
  27. def __init__(self, network):
  28. super(NetWithLoss, self).__init__()
  29. self.loss = VirtualLoss()
  30. self.network = network
  31. def construct(self, x1, x6):
  32. predict = self.network(x1, x6)
  33. return self.loss(predict)
  34. class GradWrap(nn.Cell):
  35. def __init__(self, network):
  36. super(GradWrap, self).__init__()
  37. self.network = network
  38. def construct(self, x1, x6):
  39. return C.grad_all(self.network)(x1, x6)
  40. class Net(nn.Cell):
  41. def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
  42. super().__init__()
  43. self.matmul1 = P.MatMul().set_strategy(strategy1)
  44. self.matmul2 = P.MatMul().set_strategy(strategy2)
  45. self.matmul3 = P.MatMul().set_strategy(strategy3)
  46. self.matmul4 = P.MatMul().set_strategy(strategy4)
  47. self.matmul5 = P.MatMul().set_strategy(strategy5)
  48. self.matmul6 = P.MatMul().set_strategy(strategy6)
  49. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  50. self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
  51. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  52. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  53. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  54. self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
  55. def construct(self, x1, x6):
  56. out = self.matmul1(x1, self.weight1)
  57. out = self.matmul2(out, self.weight2)
  58. out = self.matmul3(out, self.weight3)
  59. out = self.matmul4(out, self.weight4)
  60. out = self.matmul5(out, self.weight5)
  61. out = out + self.weight6
  62. out = self.matmul6(out, x6)
  63. return out
  64. reset_auto_parallel_context()
  65. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
  66. strategy1 = ((8, 1), (1, 1))
  67. strategy2 = ((1, 8), (8, 1))
  68. strategy3 = ((2, 2), (2, 2))
  69. strategy4 = ((1, 1), (1, 8))
  70. strategy5 = ((4, 2), (2, 1))
  71. strategy6 = ((4, 1), (1, 2))
  72. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
  73. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  74. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  75. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  76. _executor.compile(net, x1, x6)
  77. # remove matmul2, add matmul7
  78. def test_six_matmul_load():
  79. class NetWithLoss(nn.Cell):
  80. def __init__(self, network):
  81. super(NetWithLoss, self).__init__()
  82. self.loss = VirtualLoss()
  83. self.network = network
  84. def construct(self, x1, x6, x7):
  85. predict = self.network(x1, x6, x7)
  86. return self.loss(predict)
  87. class GradWrap(nn.Cell):
  88. def __init__(self, network):
  89. super(GradWrap, self).__init__()
  90. self.network = network
  91. def construct(self, x1, x6, x7):
  92. return C.grad_all(self.network)(x1, x6, x7)
  93. class Net(nn.Cell):
  94. def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
  95. super().__init__()
  96. self.matmul1 = P.MatMul().set_strategy(strategy1)
  97. self.matmul3 = P.MatMul().set_strategy(strategy3)
  98. self.matmul4 = P.MatMul().set_strategy(strategy4)
  99. self.matmul5 = P.MatMul().set_strategy(strategy5)
  100. self.matmul6 = P.MatMul().set_strategy(strategy6)
  101. self.matmul7 = P.MatMul().set_strategy(strategy7)
  102. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  103. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  104. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  105. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  106. self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
  107. def construct(self, x1, x6, x7):
  108. out = self.matmul1(x1, self.weight1)
  109. out = self.matmul3(out, self.weight3)
  110. out = self.matmul4(out, self.weight4)
  111. out = self.matmul5(out, self.weight5)
  112. out = out + self.weight6
  113. out = self.matmul6(out, x6)
  114. out = self.matmul7(out, x7)
  115. return out
  116. reset_auto_parallel_context()
  117. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
  118. strategy1 = ((8, 1), (1, 1))
  119. strategy3 = ((8, 1), (1, 1))
  120. strategy4 = ((8, 1), (1, 1))
  121. strategy5 = ((8, 1), (1, 1))
  122. strategy6 = ((8, 1), (1, 1))
  123. strategy7 = ((8, 1), (1, 1))
  124. net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
  125. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  126. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  127. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  128. x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  129. _executor.compile(net, x1, x6, x7)
  130. # model_parallel test
  131. def test_six_matmul_save_auto():
  132. class NetWithLoss(nn.Cell):
  133. def __init__(self, network):
  134. super(NetWithLoss, self).__init__()
  135. self.loss = VirtualLoss()
  136. self.network = network
  137. def construct(self, x1, x6):
  138. predict = self.network(x1, x6)
  139. return self.loss(predict)
  140. class GradWrap(nn.Cell):
  141. def __init__(self, network):
  142. super(GradWrap, self).__init__()
  143. self.network = network
  144. def construct(self, x1, x6):
  145. return C.grad_all(self.network)(x1, x6)
  146. class Net(nn.Cell):
  147. def __init__(self):
  148. super().__init__()
  149. self.matmul1 = P.MatMul()
  150. self.matmul2 = P.MatMul()
  151. self.matmul3 = P.MatMul()
  152. self.matmul4 = P.MatMul()
  153. self.matmul5 = P.MatMul()
  154. self.matmul6 = P.MatMul()
  155. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  156. self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
  157. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  158. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  159. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  160. self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
  161. def construct(self, x1, x6):
  162. out = self.matmul1(x1, self.weight1)
  163. out = self.matmul2(out, self.weight2)
  164. out = self.matmul3(out, self.weight3)
  165. out = self.matmul4(out, self.weight4)
  166. out = self.matmul5(out, self.weight5)
  167. out = out + self.weight6
  168. out = self.matmul6(out, x6)
  169. return out
  170. reset_auto_parallel_context()
  171. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
  172. net = GradWrap(NetWithLoss(Net()))
  173. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  174. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  175. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  176. _executor.compile(net, x1, x6)
  177. # remove matmul2, add matmul7
  178. def test_six_matmul_load_auto():
  179. class NetWithLoss(nn.Cell):
  180. def __init__(self, network):
  181. super(NetWithLoss, self).__init__()
  182. self.loss = VirtualLoss()
  183. self.network = network
  184. def construct(self, x1, x6, x7):
  185. predict = self.network(x1, x6, x7)
  186. return self.loss(predict)
  187. class GradWrap(nn.Cell):
  188. def __init__(self, network):
  189. super(GradWrap, self).__init__()
  190. self.network = network
  191. def construct(self, x1, x6, x7):
  192. return C.grad_all(self.network)(x1, x6, x7)
  193. class Net(nn.Cell):
  194. def __init__(self, strategy1, strategy3, strategy4, strategy5):
  195. super().__init__()
  196. self.matmul1 = P.MatMul().set_strategy(strategy1)
  197. self.matmul3 = P.MatMul().set_strategy(strategy3)
  198. self.matmul4 = P.MatMul().set_strategy(strategy4)
  199. self.matmul5 = P.MatMul().set_strategy(strategy5)
  200. self.matmul6 = P.MatMul()
  201. self.matmul7 = P.MatMul()
  202. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  203. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  204. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  205. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  206. self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
  207. def construct(self, x1, x6, x7):
  208. out = self.matmul1(x1, self.weight1)
  209. out = self.matmul3(out, self.weight3)
  210. out = self.matmul4(out, self.weight4)
  211. out = self.matmul5(out, self.weight5)
  212. out = out + self.weight6
  213. out = self.matmul6(out, x6)
  214. out = self.matmul7(out, x7)
  215. return out
  216. reset_auto_parallel_context()
  217. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
  218. strategy1 = ((2, 2), (2, 2))
  219. strategy3 = ((2, 2), (2, 2))
  220. strategy4 = ((2, 2), (2, 2))
  221. strategy5 = ((2, 2), (2, 2))
  222. net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
  223. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  224. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  225. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  226. x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  227. _executor.compile(net, x1, x6, x7)