You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_strategy_checkpoint.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from mindspore import context
  16. from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
  17. import mindspore.nn as nn
  18. from mindspore.ops import operations as P
  19. from mindspore import Tensor, Parameter
  20. from tests.ut.python.ops.test_math_ops import VirtualLoss
  21. import mindspore as ms
  22. from mindspore.common.api import _executor
  23. from mindspore.ops import composite as C
  24. # model_parallel test
  25. def test_six_matmul_save():
  26. class NetWithLoss(nn.Cell):
  27. def __init__(self, network):
  28. super(NetWithLoss, self).__init__()
  29. self.loss = VirtualLoss()
  30. self.network = network
  31. def construct(self, x1, x6):
  32. predict = self.network(x1, x6)
  33. return self.loss(predict)
  34. class GradWrap(nn.Cell):
  35. def __init__(self, network):
  36. super(GradWrap, self).__init__()
  37. self.network = network
  38. def construct(self, x1, x6):
  39. return C.grad_all(self.network)(x1, x6)
  40. class Net(nn.Cell):
  41. def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
  42. super().__init__()
  43. self.matmul1 = P.MatMul().set_strategy(strategy1)
  44. self.matmul2 = P.MatMul().set_strategy(strategy2)
  45. self.matmul3 = P.MatMul().set_strategy(strategy3)
  46. self.matmul4 = P.MatMul().set_strategy(strategy4)
  47. self.matmul5 = P.MatMul().set_strategy(strategy5)
  48. self.matmul6 = P.MatMul().set_strategy(strategy6)
  49. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  50. self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
  51. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  52. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  53. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  54. self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
  55. def construct(self, x1, x6):
  56. out = self.matmul1(x1, self.weight1)
  57. out = self.matmul2(out, self.weight2)
  58. out = self.matmul3(out, self.weight3)
  59. out = self.matmul4(out, self.weight4)
  60. out = self.matmul5(out, self.weight5)
  61. out = out + self.weight6
  62. out = self.matmul6(out, x6)
  63. return out
  64. reset_auto_parallel_context()
  65. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
  66. strategy1 = ((8, 1), (1, 1))
  67. strategy2 = ((1, 8), (8, 1))
  68. strategy3 = ((2, 2), (2, 2))
  69. strategy4 = ((1, 1), (1, 8))
  70. strategy5 = ((4, 2), (2, 1))
  71. strategy6 = ((4, 1), (1, 2))
  72. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
  73. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  74. net.set_auto_parallel()
  75. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  76. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  77. _executor.compile(net, x1, x6)
# Variant of the saved network: matmul2 is removed and matmul7 is added.
  79. def test_six_matmul_load():
  80. class NetWithLoss(nn.Cell):
  81. def __init__(self, network):
  82. super(NetWithLoss, self).__init__()
  83. self.loss = VirtualLoss()
  84. self.network = network
  85. def construct(self, x1, x6, x7):
  86. predict = self.network(x1, x6, x7)
  87. return self.loss(predict)
  88. class GradWrap(nn.Cell):
  89. def __init__(self, network):
  90. super(GradWrap, self).__init__()
  91. self.network = network
  92. def construct(self, x1, x6, x7):
  93. return C.grad_all(self.network)(x1, x6, x7)
  94. class Net(nn.Cell):
  95. def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
  96. super().__init__()
  97. self.matmul1 = P.MatMul().set_strategy(strategy1)
  98. self.matmul3 = P.MatMul().set_strategy(strategy3)
  99. self.matmul4 = P.MatMul().set_strategy(strategy4)
  100. self.matmul5 = P.MatMul().set_strategy(strategy5)
  101. self.matmul6 = P.MatMul().set_strategy(strategy6)
  102. self.matmul7 = P.MatMul().set_strategy(strategy7)
  103. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  104. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  105. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  106. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  107. self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
  108. def construct(self, x1, x6, x7):
  109. out = self.matmul1(x1, self.weight1)
  110. out = self.matmul3(out, self.weight3)
  111. out = self.matmul4(out, self.weight4)
  112. out = self.matmul5(out, self.weight5)
  113. out = out + self.weight6
  114. out = self.matmul6(out, x6)
  115. out = self.matmul7(out, x7)
  116. return out
  117. reset_auto_parallel_context()
  118. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
  119. strategy1 = ((8, 1), (1, 1))
  120. strategy3 = ((8, 1), (1, 1))
  121. strategy4 = ((8, 1), (1, 1))
  122. strategy5 = ((8, 1), (1, 1))
  123. strategy6 = ((8, 1), (1, 1))
  124. strategy7 = ((8, 1), (1, 1))
  125. net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
  126. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  127. net.set_auto_parallel()
  128. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  129. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  130. x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  131. _executor.compile(net, x1, x6, x7)
  132. # model_parallel test
  133. def test_six_matmul_save_auto():
  134. class NetWithLoss(nn.Cell):
  135. def __init__(self, network):
  136. super(NetWithLoss, self).__init__()
  137. self.loss = VirtualLoss()
  138. self.network = network
  139. def construct(self, x1, x6):
  140. predict = self.network(x1, x6)
  141. return self.loss(predict)
  142. class GradWrap(nn.Cell):
  143. def __init__(self, network):
  144. super(GradWrap, self).__init__()
  145. self.network = network
  146. def construct(self, x1, x6):
  147. return C.grad_all(self.network)(x1, x6)
  148. class Net(nn.Cell):
  149. def __init__(self):
  150. super().__init__()
  151. self.matmul1 = P.MatMul()
  152. self.matmul2 = P.MatMul()
  153. self.matmul3 = P.MatMul()
  154. self.matmul4 = P.MatMul()
  155. self.matmul5 = P.MatMul()
  156. self.matmul6 = P.MatMul()
  157. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  158. self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
  159. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  160. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  161. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  162. self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
  163. def construct(self, x1, x6):
  164. out = self.matmul1(x1, self.weight1)
  165. out = self.matmul2(out, self.weight2)
  166. out = self.matmul3(out, self.weight3)
  167. out = self.matmul4(out, self.weight4)
  168. out = self.matmul5(out, self.weight5)
  169. out = out + self.weight6
  170. out = self.matmul6(out, x6)
  171. return out
  172. reset_auto_parallel_context()
  173. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
  174. net = GradWrap(NetWithLoss(Net()))
  175. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  176. net.set_auto_parallel()
  177. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  178. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  179. _executor.compile(net, x1, x6)
# Variant of the saved network: matmul2 is removed and matmul7 is added.
  181. def test_six_matmul_load_auto():
  182. class NetWithLoss(nn.Cell):
  183. def __init__(self, network):
  184. super(NetWithLoss, self).__init__()
  185. self.loss = VirtualLoss()
  186. self.network = network
  187. def construct(self, x1, x6, x7):
  188. predict = self.network(x1, x6, x7)
  189. return self.loss(predict)
  190. class GradWrap(nn.Cell):
  191. def __init__(self, network):
  192. super(GradWrap, self).__init__()
  193. self.network = network
  194. def construct(self, x1, x6, x7):
  195. return C.grad_all(self.network)(x1, x6, x7)
  196. class Net(nn.Cell):
  197. def __init__(self, strategy1, strategy3, strategy4, strategy5):
  198. super().__init__()
  199. self.matmul1 = P.MatMul().set_strategy(strategy1)
  200. self.matmul3 = P.MatMul().set_strategy(strategy3)
  201. self.matmul4 = P.MatMul().set_strategy(strategy4)
  202. self.matmul5 = P.MatMul().set_strategy(strategy5)
  203. self.matmul6 = P.MatMul()
  204. self.matmul7 = P.MatMul()
  205. self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
  206. self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
  207. self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
  208. self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
  209. self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
  210. def construct(self, x1, x6, x7):
  211. out = self.matmul1(x1, self.weight1)
  212. out = self.matmul3(out, self.weight3)
  213. out = self.matmul4(out, self.weight4)
  214. out = self.matmul5(out, self.weight5)
  215. out = out + self.weight6
  216. out = self.matmul6(out, x6)
  217. out = self.matmul7(out, x7)
  218. return out
  219. reset_auto_parallel_context()
  220. set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
  221. strategy1 = ((2, 2), (2, 2))
  222. strategy3 = ((2, 2), (2, 2))
  223. strategy4 = ((2, 2), (2, 2))
  224. strategy5 = ((2, 2), (2, 2))
  225. net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
  226. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  227. net.set_auto_parallel()
  228. x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  229. x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
  230. x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  231. _executor.compile(net, x1, x6, x7)