
test_python_pass.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter, \
    cancel_new_parameter
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor, \
    NewParameter, Imm

context.set_context(mode=context.GRAPH_MODE)


def get_func_graph(obj, *args, phase="validate"):
    args_names, args_list = _generate_pip_args(obj, *args)
    dic = dict(zip(args_names, args_list))
    key = generate_key(phase, dic)
    phase_prefix = str(key[1])
    if phase == 'export':
        phase = phase + '.' + phase_prefix + '.' + str(obj.create_time)
    else:
        phase = phase_prefix + phase + '.' + str(obj.create_time)
    _executor = Executor_.get_instance()
    _executor.compile(obj, args_list, phase, False)
    return _executor.get_func_graph(phase)


def test_softmax_relu():
    """
    Use python pass to transform from Softmax to ReLU.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @registe_pass(run_only_once=True)
    def softmax_relu_pass():
        x = AnyPattern()
        softmax_pattern = IsPrimTypeOf(P.Softmax())
        pattern = CallWith(softmax_pattern, inputs=[x])
        relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
        target = CallWith(relu_pattern, inputs=[x])
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregiste_pass(softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_softmax_relu_sigmoid():
    """
    Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).

    NOTE:
        Sigmoid pattern only exists in the target.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @registe_pass(run_only_once=True)
    def softmax_relu_pass():
        x = AnyPattern()
        softmax_pattern = IsPrimTypeOf(P.Softmax())
        pattern = CallWith(softmax_pattern, inputs=[x])
        sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
        call_sigmoid = CallWith(sigmoid_pattern, [x])
        relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
        target = CallWith(relu_pattern, inputs=[call_sigmoid])
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
    unregiste_pass(softmax_relu_pass)
    assert "ReLU" in transformed_repr
    assert "Sigmoid" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_isin_pattern_0():
    """
    Test IsIn pattern which expresses the IsIn/OneOf semantics.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @registe_pass(run_only_once=True)
    def softmax_relu_pass():
        x = AnyPattern()
        softmax_pattern = IsPrimTypeOf(P.Softmax())
        call_softmax = CallWith(softmax_pattern, inputs=[x])
        relu_pattern = IsPrimTypeOf(P.ReLU())
        call_relu = CallWith(relu_pattern, inputs=[x])
        pattern = IsIn([call_softmax, call_relu])
        relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False)
        target = CallWith(relu6_pattern, inputs=[x])
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregiste_pass(softmax_relu_pass)
    assert "ReLU6" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_isin_pattern_1():
    """
    Test IsIn. IsIn is used as nested inputs for the target in this case.
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @registe_pass(run_only_once=True)
    def softmax_neg_pass():
        x = AnyPattern()
        softmax_pattern = IsPrimTypeOf(P.Softmax())
        call_softmax = CallWith(softmax_pattern, inputs=[x])
        relu_pattern = IsPrimTypeOf(P.ReLU())
        call_relu = CallWith(relu_pattern, inputs=[x])
        pattern = IsIn([call_softmax, call_relu])
        neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
        target = CallWith(neg_ops, inputs=[pattern])
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
    print(transformed_repr)
    unregiste_pass(softmax_neg_pass)
    assert "Neg" in transformed_repr
    assert "Softmax" in transformed_repr


def test_isnot_pattern_0():
    """
    Test IsNot pattern which expresses the IsNot semantics.
    Case: IsNot pass failed to match
    """
    set_renorm(False)

    class ConvBN(nn.Cell):
        def __init__(self):
            super(ConvBN, self).__init__()
            self.conv = P.Conv2D(32, 3)
            self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
            self.scale = Tensor(np.ones([32]), mindspore.float32)
            self.bias = Tensor(np.ones([32]), mindspore.float32)
            self.mean = Tensor(np.ones([32]), mindspore.float32)
            self.variance = Tensor(np.ones([32]), mindspore.float32)
            self.bn = P.BatchNorm()

        def construct(self, x):
            x = self.conv(x, self.conv_weight)
            x = self.bn(x, self.scale, self.bias, self.mean, self.variance)
            return x

    inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
    conv_bn_model = ConvBN()

    @registe_pass(run_only_once=True)
    def single_bn_pass():
        """
        Sub a BN which does NOT take Conv as inputs to ReLU6.
        """
        conv2d_prim = IsPrimTypeOf("Conv2D")
        conv2d = CallWith(conv2d_prim)
        pattern_0 = IsNot(conv2d)
        pattern = CallWith(P.BatchNorm(), inputs=[pattern_0])
        target = CallWith(P.ReLU6(), inputs=[pattern_0])
        return pattern, target

    @registe_pass(run_only_once=True)
    def bn_pass():
        """
        Sub a BN to Softmax.
        """
        bn = P.BatchNorm()
        pattern = CallWith(bn)
        softmax = P.Softmax()
        target = CallWith(softmax, should_replace=False)
        return pattern, target

    transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
    unregiste_pass(single_bn_pass)
    unregiste_pass(bn_pass)
    assert "ReLU6" not in transformed_repr
    assert "Softmax" in transformed_repr
    set_renorm(True)


def test_isnot_pattern_1():
    """
    Test IsNot pattern which expresses the IsNot semantics.
    Case: IsNot pattern matches with the graph
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @registe_pass(run_only_once=True)
    def single_bn_pass():
        """
        Sub a BN which does NOT take MatMul as inputs to ReLU6.
        """
        matmul = IsPrimTypeOf("MatMul")
        pattern_0 = IsNot(matmul)
        softmax = P.Softmax()
        pattern = CallWith(softmax, inputs=[pattern_0])
        relu6 = P.ReLU6()
        target = CallWith(relu6, inputs=[pattern_0], should_replace=False)
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    unregiste_pass(single_bn_pass)
    assert "ReLU6" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_newtensor_pattern():
    """
    Test NewTensor pattern in the target
    """
    set_renorm(False)
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @registe_pass(run_only_once=True)
    def softmax_addn_pass():
        x = AnyPattern()
        softmax = P.Softmax()
        pattern = CallWith(softmax, inputs=[x])
        weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
        new_weight = NewTensor(weight_tensor)
        addn_ops = P.AddN()
        target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
    unregiste_pass(softmax_addn_pass)
    assert "AddN" in transformed_repr
    assert "Softmax" not in transformed_repr
    set_renorm(True)


def test_newparameter_pattern():
    """
    Test NewParameter pattern in the target
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @registe_pass(run_only_once=True)
    def softmax_addn_pass():
        x = AnyPattern()
        softmax = P.Softmax()
        pattern = CallWith(softmax, inputs=[x])
        default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
        default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
        new_para_0 = NewParameter("Merlin", default_tensor0)
        new_para_1 = NewParameter("Arthur", default_tensor1)
        target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
        target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    print(transformed_repr)
    unregiste_pass(softmax_addn_pass)
    assert "MatMul" in transformed_repr
    assert "make_tuple" in transformed_repr
    assert "Softmax" not in transformed_repr


def test_imm_pattern():
    """
    Test Imm pattern in the target
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    @registe_pass(run_only_once=True)
    def softmax_addn_pass():
        x = AnyPattern()
        softmax = P.Softmax()
        pattern = CallWith(softmax, inputs=[x])
        imm = Imm(0)
        target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
        target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    print(transformed_repr)
    unregiste_pass(softmax_addn_pass)
    assert "make_tuple" in transformed_repr
    assert "tuple_getitem" in transformed_repr
    assert "Softmax" in transformed_repr


def test_gen_new_parameter():
    """
    Test gen_new_parameter
    """
    inputs = Tensor(np.ones([42]), mindspore.float16)
    softmax_model = nn.Softmax()

    default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
    new_para = NewParameter("Merlin", default_tensor, should_replace=True)
    gen_new_parameter(new_para)

    @registe_pass(run_only_once=True)
    def softmax_make_tuple_pass():
        x = AnyPattern()
        softmax = P.Softmax()
        pattern = CallWith(softmax, inputs=[x])
        target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
        return pattern, target

    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    print(transformed_repr)
    assert "Merlin" in transformed_repr
    unregiste_pass(softmax_make_tuple_pass)
    cancel_new_parameter(new_para)
    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
    print(transformed_repr)
    assert "Merlin" not in transformed_repr