You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_python_pass.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  # Copyright 2020 Huawei Technologies Co., Ltd
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ============================================================================
  import numpy as np
  import mindspore
  import mindspore.nn as nn
  from mindspore import context
  from mindspore.common.tensor import Tensor
  from mindspore.ops import operations as P
  from mindspore.graph_utils.python_pass import (
      registe_pass,
      unregiste_pass,
      _set_renorm,
      gen_new_parameter,
      cancel_new_parameter,
      _set_reopt,
  )
  from mindspore.common.api import _generate_pip_args
  from mindspore._c_expression import generate_key, Executor_
  from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm

  # Switch MindSpore to graph mode at import time; the tests below inspect the
  # function graph produced by compilation.
  context.set_context(mode=context.GRAPH_MODE)
  def get_func_graph(obj, *args, phase="validate"):
      """
      Compile `obj` with `args` and return the function graph of that compilation.

      Args:
          obj: Object handed to the executor for compilation. It must expose a
              `create_time` attribute, which becomes part of the phase name.
          *args: Inputs forwarded, together with `obj`, to `_generate_pip_args`.
          phase (str): Base phase name. Default: "validate".

      Returns:
          The value returned by `Executor_.get_func_graph` for the composed phase name.
      """
      args_names, args_list = _generate_pip_args(obj, *args)
      key = generate_key(phase, dict(zip(args_names, args_list)))
      phase_prefix = str(key[1])
      create_time = str(obj.create_time)
      # 'export' keeps the phase name in front; any other phase is prefixed by the key.
      if phase == 'export':
          phase = phase + '.' + phase_prefix + '.' + create_time
      else:
          phase = phase_prefix + phase + '.' + create_time
      executor = Executor_.get_instance()
      executor.compile(obj, args_list, phase, False)
      return executor.get_func_graph(phase)
  def test_softmax_relu():
      """
      Use python pass to transform from Softmax to ReLU.
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()

      @registe_pass(run_only_once=True)
      def softmax_relu_pass():
          x = Any()
          pattern = Call(P.Softmax(), [x])
          target = Call(P.ReLU(), [x])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
      finally:
          # Unregister even when compilation raises, so the pass cannot leak into other tests.
          unregiste_pass(softmax_relu_pass)
      assert "ReLU" in transformed_repr
      assert "Softmax" not in transformed_repr
  def test_prim():
      """
      Use a Prim pattern built from a list of primitives (Sigmoid, Softmax) as the
      callee of the matched call, and replace the match with ReLU.
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()

      @registe_pass(run_only_once=True)
      def softmax_relu_pass():
          x = Any()
          sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
          pattern = Call(sigmoid_softmax_pattern, [x])
          target = Call(P.ReLU(), [x])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
      finally:
          # Unregister even when compilation raises, so the pass cannot leak into other tests.
          unregiste_pass(softmax_relu_pass)
      assert "ReLU" in transformed_repr
      assert "Softmax" not in transformed_repr
  def test_softmax_relu_sigmoid():
      """
      Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).

      NOTE:
          Sigmoid pattern only exists in the target.
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()

      @registe_pass(run_only_once=True)
      def softmax_relu_pass():
          x = Any()
          softmax_pattern = Prim(P.Softmax())
          pattern = Call(softmax_pattern, [x])
          sigmoid_pattern = Prim(P.Sigmoid())
          call_sigmoid = Call(sigmoid_pattern, [x])
          relu_pattern = Prim(P.ReLU())
          target = Call(relu_pattern, [call_sigmoid])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
      finally:
          # Unregister even when compilation raises, so the pass cannot leak into other tests.
          unregiste_pass(softmax_relu_pass)
      assert "ReLU" in transformed_repr
      assert "Sigmoid" in transformed_repr
      assert "Softmax" not in transformed_repr
  def test_isin_pattern_0():
      """
      Test the OneOf pattern, which expresses the IsIn/OneOf semantics.

      The source pattern is OneOf([Softmax(x), ReLU(x)]) and the target is ReLU6(x).
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()

      @registe_pass(run_only_once=True)
      def softmax_relu_pass():
          x = Any()
          softmax_pattern = Prim(P.Softmax())
          call_softmax = Call(softmax_pattern, [x])
          relu_pattern = Prim(P.ReLU())
          call_relu = Call(relu_pattern, [x])
          pattern = OneOf([call_softmax, call_relu])
          relu6_pattern = Prim(P.ReLU6())
          target = Call(relu6_pattern, [x])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
      finally:
          # Unregister even when compilation raises, so the pass cannot leak into other tests.
          unregiste_pass(softmax_relu_pass)
      assert "ReLU6" in transformed_repr
      assert "Softmax" not in transformed_repr
  def test_isin_pattern_1():
      """
      Test the OneOf (IsIn) pattern used as a nested input of the target.

      The target is Neg(pattern), so the matched Softmax call is expected to stay
      in the graph underneath the new Neg node.
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()

      @registe_pass(run_only_once=True)
      def softmax_neg_pass():
          x = Any()
          softmax_pattern = Prim(P.Softmax())
          call_softmax = Call(softmax_pattern, [x])
          relu_pattern = Prim(P.ReLU())
          call_relu = Call(relu_pattern, [x])
          pattern = OneOf([call_softmax, call_relu])
          neg_ops = Prim(P.Neg())
          target = Call(neg_ops, [pattern])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
      finally:
          # Unregister even when compilation raises, so the pass cannot leak into other tests.
          unregiste_pass(softmax_neg_pass)
      assert "Neg" in transformed_repr
      assert "Softmax" in transformed_repr
  def test_isnot_pattern_0():
      """
      Test IsNot pattern which expresses the IsNot semantics.
      Case: IsNot pass failed to match
      """
      _set_renorm(False)
      _set_reopt(False)

      class ConvBN(nn.Cell):
          """Conv2D followed by BatchNorm, both fed with constant all-ones tensors."""

          def __init__(self):
              super(ConvBN, self).__init__()
              self.conv = P.Conv2D(32, 3)
              self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
              self.scale = Tensor(np.ones([32]), mindspore.float32)
              self.bias = Tensor(np.ones([32]), mindspore.float32)
              self.mean = Tensor(np.ones([32]), mindspore.float32)
              self.variance = Tensor(np.ones([32]), mindspore.float32)
              self.bn = P.BatchNorm()

          def construct(self, x):
              x = self.conv(x, self.conv_weight)
              x = self.bn(x, self.scale, self.bias, self.mean, self.variance)
              return x

      inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
      conv_bn_model = ConvBN()

      @registe_pass(requires_grad=False, run_only_once=True)
      def single_bn_pass():
          """
          Sub a BN which does NOT take Conv as inputs to ReLU6.
          """
          conv2d_prim = Prim("Conv2D")
          conv2d = Call(conv2d_prim)
          pattern_0 = NoneOf(conv2d)
          pattern = Call(P.BatchNorm(), [pattern_0])
          target = Call(P.ReLU6(), [pattern_0])
          return pattern, target

      @registe_pass(requires_grad=False, run_only_once=True)
      def bn_pass():
          """
          Sub a BN to Softmax.
          """
          pattern = Call(P.BatchNorm())
          target = Call(P.Softmax())
          return pattern, target

      try:
          transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
      finally:
          # Clean up global state even when compilation raises, so neither the
          # passes nor the disabled renormalization leak into other tests.
          unregiste_pass(single_bn_pass)
          unregiste_pass(bn_pass)
          # NOTE(review): only renorm is re-enabled, as in the original test; reopt
          # stays disabled. Confirm whether `_set_reopt(True)` should follow.
          _set_renorm(True)
      assert "ReLU6" not in transformed_repr
      assert "Softmax" in transformed_repr
  def test_isnot_pattern_1():
      """
      Test IsNot pattern which expresses the IsNot semantics.
      Case: IsNot pattern matches with the graph
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()

      @registe_pass(run_only_once=True)
      def single_bn_pass():
          """
          Sub a Softmax which does NOT take MatMul as inputs to ReLU6.
          """
          matmul = Prim("MatMul")
          pattern_0 = NoneOf(matmul)
          softmax = P.Softmax()
          pattern = Call(softmax, [pattern_0])
          relu6 = P.ReLU6()
          target = Call(relu6, [pattern_0])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
      finally:
          # Unregister even when compilation raises, so the pass cannot leak into other tests.
          unregiste_pass(single_bn_pass)
      assert "ReLU6" in transformed_repr
      assert "Softmax" not in transformed_repr
  def test_newtensor_pattern():
      """
      Test NewTensor pattern in the target

      Softmax(x) is replaced by AddN(x, new_weight), where new_weight is a
      NewTensor wrapping an all-zeros float16 tensor.
      """
      _set_renorm(False)
      _set_reopt(False)
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()

      @registe_pass(requires_grad=False, run_only_once=True)
      def softmax_addn_pass():
          x = Any()
          pattern = Call(P.Softmax(), [x])
          weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
          new_weight = NewTensor(weight_tensor)
          target = Call(P.AddN(), [x, new_weight])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
      finally:
          # Clean up global state even when compilation raises.
          unregiste_pass(softmax_addn_pass)
          # NOTE(review): only renorm is re-enabled, as in the original test; reopt
          # stays disabled. Confirm whether `_set_reopt(True)` should follow.
          _set_renorm(True)
      assert "AddN" in transformed_repr
      assert "Softmax" not in transformed_repr
  def test_newparameter_pattern():
      """
      Test NewParameter pattern in the target

      Softmax(x) is replaced by make_tuple(MatMul(Merlin, Arthur)), where both
      MatMul operands are NewParameter patterns with all-ones default tensors.
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()
      # NOTE(review): renorm/reopt are switched off here and never switched back
      # on in this test (unchanged from the original) - confirm this is intended.
      _set_renorm(False)
      _set_reopt(False)

      @registe_pass(requires_grad=False, run_only_once=True)
      def softmax_addn_pass():
          x = Any()
          pattern = Call(P.Softmax(), [x])
          default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
          default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
          new_para_0 = NewParameter("Merlin", default_tensor0)
          new_para_1 = NewParameter("Arthur", default_tensor1)
          target_0 = Call(P.MatMul(), [new_para_0, new_para_1])
          target = Call("make_tuple", [target_0])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
      finally:
          # Unregister even when compilation raises, so the pass cannot leak into other tests.
          unregiste_pass(softmax_addn_pass)
      assert "MatMul" in transformed_repr
      assert "make_tuple" in transformed_repr
      assert "Softmax" not in transformed_repr
  def test_imm_target():
      """
      Test Imm pattern in the target

      Softmax(x) is replaced by tuple_getitem(make_tuple(Softmax(x)), Imm(0)), so
      the Softmax node is expected to remain in the transformed graph.
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()
      # NOTE(review): renorm/reopt are switched off here and never switched back
      # on in this test (unchanged from the original) - confirm this is intended.
      _set_renorm(False)
      _set_reopt(False)

      @registe_pass(requires_grad=False, run_only_once=True)
      def softmax_pass():
          x = Any()
          pattern = Call(P.Softmax(), [x])
          imm = Imm(0)
          target_0 = Call("make_tuple", [pattern])
          target = Call("tuple_getitem", [target_0, imm])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
      finally:
          # Unregister even when compilation raises, so the pass cannot leak into other tests.
          unregiste_pass(softmax_pass)
      assert "make_tuple" in transformed_repr
      assert "tuple_getitem" in transformed_repr
      assert "Softmax" in transformed_repr
  def test_gen_new_parameter():
      """
      Test gen_new_parameter

      The parameter "Merlin" is generated up front and used in the target of a
      registered pass. After the pass is unregistered and the parameter is
      cancelled, a second compilation must no longer mention "Merlin".
      """
      inputs = Tensor(np.ones([42]), mindspore.float16)
      softmax_model = nn.Softmax()
      default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
      new_para = NewParameter("Merlin", default_tensor)
      # NOTE(review): renorm/reopt are switched off here and never switched back
      # on in this test (unchanged from the original) - confirm this is intended.
      _set_renorm(False)
      _set_reopt(False)
      gen_new_parameter(new_para)

      @registe_pass(requires_grad=False, run_only_once=True)
      def softmax_make_tuple_pass():
          x = Any()
          softmax = P.Softmax()
          pattern = Call(softmax, [x])
          target = Call("make_tuple", [pattern, new_para])
          return pattern, target

      try:
          transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
      finally:
          # Undo both registrations even when compilation raises, so neither the
          # pass nor the generated parameter leaks into other tests.
          unregiste_pass(softmax_make_tuple_pass)
          cancel_new_parameter(new_para)
      assert "Merlin" in transformed_repr
      # With the pass and the parameter gone, a fresh compile must not reference it.
      transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
      assert "Merlin" not in transformed_repr