You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_serialization.py 12 kB


  1. import pickle
  2. from collections import defaultdict
  3. from functools import wraps
  4. from tempfile import TemporaryFile
  5. import numpy as np
  6. import megengine.functional as F
  7. import megengine.module as M
  8. import megengine.traced_module.expr as Expr
  9. import megengine.traced_module.serialization as S
  10. from megengine import Tensor
  11. from megengine.core._imperative_rt.core2 import apply
  12. from megengine.core.ops import builtin
  13. from megengine.core.ops.builtin import Elemwise
  14. from megengine.module import Module
  15. from megengine.traced_module import trace_module
  16. from megengine.traced_module.expr import CallMethod, Constant
  17. from megengine.traced_module.node import TensorNode
  18. from megengine.traced_module.serialization import (
  19. register_functional_loader,
  20. register_module_loader,
  21. register_opdef_loader,
  22. register_tensor_method_loader,
  23. )
  24. from megengine.traced_module.utils import _convert_kwargs_to_args
  25. def _check_id(traced_module):
  26. _total_ids = traced_module.graph._total_ids
  27. node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
  28. assert len(set(node_ids)) == len(node_ids)
  29. assert max(node_ids) + 1 == _total_ids[0]
  30. expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
  31. assert len(set(expr_ids)) == len(expr_ids)
  32. assert max(expr_ids) + 1 == _total_ids[1]
  33. def _check_name(flatened_module):
  34. node_names = [n._name for n in flatened_module.graph.nodes().as_list()]
  35. assert len(set(node_names)) == len(node_names)
  36. def _check_expr_users(traced_module):
  37. node_user = defaultdict(list)
  38. for expr in traced_module.graph._exprs:
  39. for node in expr.inputs:
  40. node_user[node].append(expr)
  41. if isinstance(expr, CallMethod) and expr.graph:
  42. _check_expr_users(expr.inputs[0].owner)
  43. for node in traced_module.graph.nodes(False):
  44. node.users.sort(key=lambda m: m._id)
  45. node_user[node].sort(key=lambda m: m._id)
  46. assert node.users == node_user[node]
  47. class MyBlock(Module):
  48. def __init__(self, in_channels, channels):
  49. super(MyBlock, self).__init__()
  50. self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
  51. self.bn1 = M.BatchNorm2d(channels)
  52. def forward(self, x):
  53. x = self.conv1(x)
  54. x = self.bn1(x)
  55. x = F.relu(x) + 1
  56. return x
  57. class MyModule(Module):
  58. def __init__(self):
  59. super(MyModule, self).__init__()
  60. self.block0 = MyBlock(8, 4)
  61. self.block1 = MyBlock(4, 2)
  62. def forward(self, x):
  63. x = self.block0(x)
  64. x = self.block1(x)
  65. return x
  66. def test_dump_and_load():
  67. module = MyModule()
  68. x = Tensor(np.ones((1, 8, 14, 14)))
  69. expect = module(x)
  70. traced_module = trace_module(module, x)
  71. np.testing.assert_array_equal(expect, traced_module(x))
  72. obj = pickle.dumps(traced_module)
  73. new_tm = pickle.loads(obj)
  74. _check_id(new_tm)
  75. _check_expr_users(new_tm)
  76. traced_module.graph._reset_ids()
  77. old_nodes = traced_module.graph.nodes().as_list()
  78. new_nodes = new_tm.graph.nodes().as_list()
  79. old_exprs = traced_module.graph.exprs().as_list()
  80. new_exprs = new_tm.graph.exprs().as_list()
  81. assert len(old_nodes) == len(new_nodes)
  82. for i, j in zip(old_nodes, new_nodes):
  83. assert i._name == j._name
  84. assert i._qualname == j._qualname
  85. assert i._id == j._id
  86. assert len(old_exprs) == len(new_exprs)
  87. for i, j in zip(old_exprs, new_exprs):
  88. assert i._id == j._id
  89. np.testing.assert_array_equal(expect, traced_module(x))
  90. def test_opdef_loader():
  91. class MyModule1(Module):
  92. def forward(self, x, y):
  93. op = Elemwise("ADD")
  94. return apply(op, x, y)[0]
  95. m = MyModule1()
  96. x = Tensor(np.ones((20)))
  97. y = Tensor(np.ones((20)))
  98. traced_module = trace_module(m, x, y)
  99. orig_loader_dict = S.OPDEF_LOADER
  100. S.OPDEF_LOADER = {}
  101. @register_opdef_loader(Elemwise)
  102. def add_opdef_loader(expr):
  103. if expr.opdef_state["mode"] == "ADD":
  104. expr.opdef_state["mode"] = "MUL"
  105. node = expr.inputs[1]
  106. astype_expr = CallMethod(node, "astype")
  107. oup = TensorNode(
  108. astype_expr,
  109. shape=node.shape,
  110. dtype=expr.inputs[0].dtype,
  111. qparams=node.qparams,
  112. )
  113. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  114. astype_expr.return_val = (oup,)
  115. expr.inputs[1] = oup
  116. obj = pickle.dumps(traced_module)
  117. new_module = pickle.loads(obj)
  118. _check_id(new_module)
  119. _check_expr_users(new_module)
  120. _check_name(new_module.flatten())
  121. assert (
  122. isinstance(new_module.graph._exprs[0], CallMethod)
  123. and new_module.graph._exprs[1].opdef.mode == "MUL"
  124. and len(new_module.graph._exprs) == 2
  125. )
  126. result = new_module(x, y)
  127. np.testing.assert_equal(result.numpy(), x.numpy())
  128. S.OPDEF_LOADER = orig_loader_dict
  129. def test_functional_loader():
  130. class MyModule2(Module):
  131. def forward(self, x, y):
  132. return F.conv2d(x, y)
  133. m = MyModule2()
  134. x = Tensor(np.random.random((1, 3, 32, 32)))
  135. y = Tensor(np.random.random((3, 3, 3, 3)))
  136. traced_module = trace_module(m, x, y)
  137. orig_loader_dict = S.FUNCTIONAL_LOADER
  138. S.FUNCTIONAL_LOADER = {}
  139. @register_functional_loader(("megengine.functional.nn", "conv2d"))
  140. def conv2df_loader(expr):
  141. # expr.func = ("megengine.functional.nn","conv2d")
  142. kwargs = expr.kwargs
  143. orig_weight = expr.named_args["weight"]
  144. astype_expr = CallMethod(orig_weight, "astype")
  145. oup = TensorNode(
  146. astype_expr,
  147. shape=orig_weight.shape,
  148. dtype=orig_weight.dtype,
  149. qparams=orig_weight.qparams,
  150. )
  151. astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
  152. astype_expr.return_val = (oup,)
  153. expr.set_arg("weight", oup)
  154. obj = pickle.dumps(traced_module)
  155. new_module = pickle.loads(obj)
  156. _check_expr_users(new_module)
  157. _check_id(new_module)
  158. result = new_module(x, y)
  159. gt = m(x, y)
  160. assert (
  161. isinstance(new_module.graph._exprs[0], CallMethod)
  162. and len(new_module.graph._exprs) == 2
  163. )
  164. np.testing.assert_equal(result.numpy(), gt.numpy())
  165. S.FUNCTIONAL_LOADER = orig_loader_dict
  166. def test_tensor_method_loader():
  167. class MyModule3(Module):
  168. def forward(self, x):
  169. return x + 1
  170. m = MyModule3()
  171. x = Tensor(np.ones((20)))
  172. traced_module = trace_module(m, x)
  173. orig_loader_dict = S.TENSORMETHOD_LOADER
  174. S.TENSORMETHOD_LOADER = {}
  175. @register_tensor_method_loader("__add__")
  176. def add_loader(expr):
  177. args = list(expr.args)
  178. if not isinstance(args[1], TensorNode):
  179. args[1] = Tensor(args[1])
  180. node = Constant(args[1], "const").outputs[0]
  181. astype_expr = CallMethod(node, "astype")
  182. oup = TensorNode(
  183. astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
  184. )
  185. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  186. astype_expr.return_val = (oup,)
  187. add_expr = CallMethod(oup, "__add__")
  188. add_expr.set_args_kwargs(oup, oup)
  189. oup1 = TensorNode(
  190. add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams,
  191. )
  192. add_expr.return_val = oup1
  193. args[1] = oup1
  194. expr.set_args_kwargs(*args)
  195. obj = pickle.dumps(traced_module)
  196. new_module = pickle.loads(obj)
  197. _check_expr_users(new_module)
  198. _check_id(new_module)
  199. result = new_module(x)
  200. gt = m(x)
  201. assert (
  202. isinstance(new_module.graph._exprs[0], Constant)
  203. and len(new_module.graph._exprs) == 4
  204. )
  205. np.testing.assert_equal(result.numpy(), (x + 2).numpy())
  206. S.TENSORMETHOD_LOADER = orig_loader_dict
  207. def test_module_loader():
  208. class MyModule4(Module):
  209. def __init__(self):
  210. super().__init__()
  211. self.conv = M.Conv2d(3, 3, 3)
  212. def forward(self, x):
  213. return self.conv(x)
  214. m = MyModule4()
  215. x = Tensor(np.random.random((1, 3, 32, 32)))
  216. traced_module = trace_module(m, x)
  217. orig_loader_dict = S.MODULE_LOADER
  218. S.MODULE_LOADER = {}
  219. @register_module_loader(("megengine.module.conv", "Conv2d"))
  220. def conv2dm_loader(expr):
  221. module = expr.inputs[0].owner
  222. args = list(expr.args)
  223. orig_inp = args[1]
  224. astype_expr = CallMethod(orig_inp, "astype")
  225. oup = TensorNode(
  226. astype_expr,
  227. shape=orig_inp.shape,
  228. dtype=orig_inp.dtype,
  229. qparams=orig_inp.qparams,
  230. )
  231. astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
  232. astype_expr.return_val = (oup,)
  233. args[1] = oup
  234. expr.set_args_kwargs(*args)
  235. obj = pickle.dumps(traced_module)
  236. new_module = pickle.loads(obj)
  237. result = new_module(x)
  238. gt = m(x)
  239. assert (
  240. isinstance(new_module.graph._exprs[1], CallMethod)
  241. and len(new_module.graph._exprs) == 3
  242. )
  243. np.testing.assert_equal(result.numpy(), gt.numpy())
  244. S.MODULE_LOADER = orig_loader_dict
  245. def test_shared_module():
  246. class MyModule(M.Module):
  247. def __init__(self):
  248. super().__init__()
  249. self.a = M.Elemwise("ADD")
  250. self.b = self.a
  251. def forward(self, x, y):
  252. z = self.a(x, y)
  253. z = self.b(z, y)
  254. return z
  255. x = Tensor(1)
  256. y = Tensor(2)
  257. m = MyModule()
  258. tm = trace_module(m, x, y)
  259. obj = pickle.dumps(tm)
  260. load_tm = pickle.loads(obj)
  261. _check_expr_users(load_tm)
  262. _check_name(load_tm.flatten())
  263. _check_id(load_tm)
  264. assert load_tm.a is load_tm.b
  265. def test_convert_kwargs_to_args():
  266. def func(a, b, c=4, *, d, e=3, f=4):
  267. pass
  268. args = (1,)
  269. kwargs = {"b": 1, "d": 6}
  270. new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs)
  271. assert new_args == (1, 1, 4)
  272. assert new_kwargs == {"d": 6, "e": 3, "f": 4}
  273. args = (1,)
  274. kwargs = {"d": 6}
  275. new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs, is_bounded=True)
  276. assert new_args == (1, 4)
  277. assert new_kwargs == {"d": 6, "e": 3, "f": 4}
  278. def func1(a, b, c, d, e, *, f):
  279. pass
  280. args = ()
  281. kwargs = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6}
  282. new_args, new_kwargs = _convert_kwargs_to_args(func1, args, kwargs)
  283. assert new_args == (1, 2, 3, 4, 5)
  284. assert new_kwargs == {"f": 6}
  285. def test_opdef_serialization():
  286. with TemporaryFile() as f:
  287. x = builtin.Elemwise(mode="Add")
  288. pickle.dump(x, f)
  289. f.seek(0)
  290. load_x = pickle.load(f)
  291. assert x == load_x
  292. with TemporaryFile() as f:
  293. x = builtin.Convolution(stride_h=9, compute_mode="float32")
  294. x.strategy = (
  295. builtin.Convolution.Strategy.PROFILE
  296. | builtin.Convolution.Strategy.HEURISTIC
  297. | builtin.Convolution.Strategy.REPRODUCIBLE
  298. )
  299. pickle.dump(x, f)
  300. f.seek(0)
  301. load_x = pickle.load(f)
  302. assert x.strategy == load_x.strategy
  303. assert x == load_x
  304. def test_square_function_compat():
  305. @wraps(F.elemwise.square)
  306. def origin_square(x):
  307. return F.pow(x, 2)
  308. new_square = F.elemwise.square
  309. F.elemwise.square = origin_square
  310. current_version = Expr.__version__
  311. Expr.__version__ = "1.11.1"
  312. class old_square(M.Module):
  313. def forward(self, x):
  314. x = F.relu(x)
  315. x = F.elemwise.square(x)
  316. return x * 2
  317. m = trace_module(old_square(), Tensor([1, 2, 4, 6]))
  318. float_m = trace_module(old_square(), Tensor([1.0, 2.0, 4.0, 6.0]))
  319. # dump old version square
  320. obj = pickle.dumps(m)
  321. f_obj = pickle.dumps(float_m)
  322. # load in new version
  323. F.elemwise.square = new_square
  324. Expr.__version__ = current_version
  325. new_m = pickle.loads(obj)
  326. new_float_m = pickle.loads(f_obj)
  327. assert len(new_m.graph._exprs) == 4 and len(new_float_m.graph._exprs) == 3
  328. assert new_m(Tensor([1, 2, 4, 6])).dtype == np.float32