
test_python_pass.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter, \
    cancel_new_parameter, set_reopt
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
context.set_context(mode=context.GRAPH_MODE)
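# Helper used by every test below: compile a cell with the given inputs and
# return the resulting func graph, so the rewritten IR can be inspected as a
# string via get_return().expanded_str(depth).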
def get_func_graph(obj, *args, phase="validate"):
    args_names, args_list = _generate_pip_args(obj, *args)
    dic = dict(zip(args_names, args_list))
    key = generate_key(phase, dic)
    phase_prefix = str(key[1])
    if phase == 'export':
        phase = phase + '.' + phase_prefix + '.' + str(obj.create_time)
    else:
        phase = phase_prefix + phase + '.' + str(obj.create_time)
    _executor = Executor_.get_instance()
    _executor.compile(obj, args_list, phase, False)
    return _executor.get_func_graph(phase)
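# Each python pass below is a function decorated with @registe_pass that returns
# a (pattern, target) pair: sub-graphs matching `pattern` are rewritten into
# `target`, and unregiste_pass removes the pass once the test has run.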
def test_softmax_relu():
    """
    Use python pass to transform from Softmax to ReLU.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    @registe_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])
        target = Call(P.ReLU(), [x])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregiste_pass(softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Softmax" not in transformed_repr
def test_prim():
    """
    Test a Prim pattern built from a list of primitives: the Call matches
    if its primitive is any one of them (Sigmoid or Softmax here).
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    @registe_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
        pattern = Call(sigmoid_softmax_pattern, [x])
        target = Call(P.ReLU(), [x])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
    unregiste_pass(softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Softmax" not in transformed_repr
def test_softmax_relu_sigmoid():
    """
    Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).
    NOTE:
        Sigmoid pattern only exists in the target.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    @registe_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        pattern = Call(softmax_pattern, [x])
        sigmoid_pattern = Prim(P.Sigmoid())
        call_sigmoid = Call(sigmoid_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        target = Call(relu_pattern, [call_sigmoid])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
    unregiste_pass(softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Sigmoid" in transformed_repr
    assert "Softmax" not in transformed_repr
def test_isin_pattern_0():
    """
    Test IsIn pattern which expresses the IsIn/OneOf semantics.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    @registe_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        call_softmax = Call(softmax_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        call_relu = Call(relu_pattern, [x])
        pattern = OneOf([call_softmax, call_relu])
        relu6_pattern = Prim(P.ReLU6())
        target = Call(relu6_pattern, [x])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregiste_pass(softmax_relu_pass)
    assert "ReLU6" in transformed_repr
    assert "Softmax" not in transformed_repr
def test_isin_pattern_1():
    """
    Test IsIn. The IsIn pattern is used as a nested input of the target in this case.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    @registe_pass(run_only_once=True)
    def softmax_neg_pass():
        x = Any()
        softmax_pattern = Prim(P.Softmax())
        call_softmax = Call(softmax_pattern, [x])
        relu_pattern = Prim(P.ReLU())
        call_relu = Call(relu_pattern, [x])
        pattern = OneOf([call_softmax, call_relu])
        neg_ops = Prim(P.Neg())
        target = Call(neg_ops, [pattern])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
    unregiste_pass(softmax_neg_pass)
    assert "Neg" in transformed_repr
    assert "Softmax" in transformed_repr
def test_isnot_pattern_0():
    """
    Test IsNot pattern which expresses the IsNot semantics.
    Case: the IsNot pass fails to match.
    """
    set_renorm(False)
    set_reopt(False)
    class ConvBN(nn.Cell):
        def __init__(self):
            super(ConvBN, self).__init__()
            self.conv = P.Conv2D(32, 3)
            self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
            self.scale = Tensor(np.ones([32]), mindspore.float32)
            self.bias = Tensor(np.ones([32]), mindspore.float32)
            self.mean = Tensor(np.ones([32]), mindspore.float32)
            self.variance = Tensor(np.ones([32]), mindspore.float32)
            self.bn = P.BatchNorm()
        def construct(self, x):
            x = self.conv(x, self.conv_weight)
            x = self.bn(x, self.scale, self.bias, self.mean, self.variance)
            return x
    inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
    conv_bn_model = ConvBN()
    @registe_pass(requires_grad=False, run_only_once=True)
    def single_bn_pass():
        """
        Substitute a BatchNorm which does NOT take a Conv2D as input with ReLU6.
        """
        conv2d_prim = Prim("Conv2D")
        conv2d = Call(conv2d_prim)
        pattern_0 = NoneOf(conv2d)
        pattern = Call(P.BatchNorm(), [pattern_0])
        target = Call(P.ReLU6(), [pattern_0])
        return pattern, target
    @registe_pass(requires_grad=False, run_only_once=True)
    def bn_pass():
        """
        Substitute a BatchNorm with Softmax.
        """
        pattern = Call(P.BatchNorm())
        target = Call(P.Softmax())
        return pattern, target
    transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
    unregiste_pass(single_bn_pass)
    unregiste_pass(bn_pass)
    assert "ReLU6" not in transformed_repr
    assert "Softmax" in transformed_repr
    set_renorm(True)
def test_isnot_pattern_1():
    """
    Test IsNot pattern which expresses the IsNot semantics.
    Case: the IsNot pattern matches the graph.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    @registe_pass(run_only_once=True)
    def single_bn_pass():
        """
        Substitute a Softmax which does NOT take a MatMul as input with ReLU6.
        """
        matmul = Prim("MatMul")
        pattern_0 = NoneOf(matmul)
        softmax = P.Softmax()
        pattern = Call(softmax, [pattern_0])
        relu6 = P.ReLU6()
        target = Call(relu6, [pattern_0])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    unregiste_pass(single_bn_pass)
    assert "ReLU6" in transformed_repr
    assert "Softmax" not in transformed_repr
def test_newtensor_pattern():
    """
    Test NewTensor pattern in the target
    """
    set_renorm(False)
    set_reopt(False)
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    @registe_pass(requires_grad=False, run_only_once=True)
    def softmax_addn_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])
        weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
        new_weight = NewTensor(weight_tensor)
        target = Call(P.AddN(), [x, new_weight])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregiste_pass(softmax_addn_pass)
    assert "AddN" in transformed_repr
    assert "Softmax" not in transformed_repr
    set_renorm(True)
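# NewParameter(name, default_tensor) introduces a named Parameter node with the
# given default value that the target sub-graph can consume.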
def test_newparameter_pattern():
    """
    Test NewParameter pattern in the target
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    set_renorm(False)
    set_reopt(False)
    @registe_pass(requires_grad=False, run_only_once=True)
    def softmax_addn_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])
        default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
        default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
        new_para_0 = NewParameter("Merlin", default_tensor0)
        new_para_1 = NewParameter("Arthur", default_tensor1)
        target_0 = Call(P.MatMul(), [new_para_0, new_para_1])
        target = Call("MakeTuple", [target_0])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    unregiste_pass(softmax_addn_pass)
    assert "MatMul" in transformed_repr
    assert "MakeTuple" in transformed_repr
    assert "Softmax" not in transformed_repr
def test_imm_target():
    """
    Test Imm pattern in the target
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    set_renorm(False)
    set_reopt(False)
    @registe_pass(requires_grad=False, run_only_once=True)
    def softmax_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])
        imm = Imm(0)
        target_0 = Call("MakeTuple", [pattern])
        target = Call(Constants.kTupleGetItem, [target_0, imm])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    unregiste_pass(softmax_pass)
    assert "MakeTuple" in transformed_repr
    assert Constants.kTupleGetItem in transformed_repr
    assert "Softmax" in transformed_repr
def test_gen_new_parameter():
    """
    Test gen_new_parameter
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()
    default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
    new_para = NewParameter("Merlin", default_tensor)
    set_renorm(False)
    set_reopt(False)
    gen_new_parameter(new_para)
    @registe_pass(requires_grad=False, run_only_once=True)
    def softmax_make_tuple_pass():
        x = Any()
        softmax = P.Softmax()
        pattern = Call(softmax, [x])
        target = Call("MakeTuple", [pattern, new_para])
        return pattern, target
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    assert "Merlin" in transformed_repr
    unregiste_pass(softmax_make_tuple_pass)
    cancel_new_parameter(new_para)
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    assert "Merlin" not in transformed_repr