You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

expr.py 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import builtins
  10. import collections
  11. import inspect
  12. from typing import Callable, List
  13. from ...core._imperative_rt import OpDef
  14. from ...core._imperative_rt.core2 import Tensor as RawTensor
  15. from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
  16. from ...core.ops.special import Const
  17. from ...module import Module
  18. from ...tensor import Parameter, Tensor
  19. from .module_tracer import active_module_tracer, module_tracer
  20. from .node import ModuleNode, Node, NodeMixin, TensorNode
  21. from .pytree import TreeDef, tree_flatten
  22. class Expr:
  23. """
  24. ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
  25. """
  26. inputs = None # type: List[Node]
  27. outputs = None # type: List[Node]
  28. const_val = None # type: List[Any]
  29. arg_def = None # type: TreeDef
  30. def add_inputs(self, vals):
  31. if not isinstance(vals, collections.abc.Sequence):
  32. vals = (vals,)
  33. for val in vals:
  34. node = NodeMixin.get(val, None)
  35. if isinstance(node, (TensorNode, ModuleNode)):
  36. self.inputs.append(node)
  37. node.users.append(self)
  38. else:
  39. assert node is None
  40. idx = len(self.inputs) + len(self.const_val)
  41. self.const_val.append((idx, val))
  42. def add_outputs(self, outputs, check_inplace=True):
  43. self.outputs = []
  44. if outputs is not None:
  45. if not isinstance(outputs, collections.Sequence):
  46. outputs = (outputs,)
  47. for i in outputs:
  48. assert isinstance(i, RawTensor)
  49. node = NodeMixin.get(i, None) if check_inplace else None
  50. self.outputs.append(
  51. node if node else NodeMixin.get_wrapped_type(i)(self)
  52. )
  53. for i, node in zip(outputs, self.outputs,):
  54. NodeMixin.wrap_safe(i, node)
  55. def unflatten_args(self, inputs):
  56. if self.arg_def is not None:
  57. inputs = list(inputs)
  58. for idx, val in self.const_val:
  59. inputs.insert(idx, val)
  60. args, kwargs = self.arg_def.unflatten(inputs)
  61. return args, kwargs
  62. else:
  63. return inputs, {}
  64. @property
  65. def kwargs(self):
  66. _, kwargs = self.unflatten_args(self.inputs)
  67. return kwargs
  68. @property
  69. def args(self):
  70. args, _ = self.unflatten_args(self.inputs)
  71. return args
  72. # expr: None (i.e. fake expression which is used to mark input)
  73. class Input(Expr):
  74. name = None
  75. def __init__(self, name=None, type=None):
  76. self.inputs = []
  77. node_cls = type if type else Node
  78. self.outputs = [
  79. node_cls(self, name=name),
  80. ]
  81. self.name = name
  82. @classmethod
  83. def make(cls, *args, **kwargs):
  84. expr = cls(*args, **kwargs)
  85. active_module_tracer().current_scope().add_input(expr.outputs[0])
  86. return expr.outputs[0]
  87. def __repr__(self):
  88. return "{} = Input({})".format(self.outputs[0], self.name)
  89. # expr: outputs = getattr(inputs[0], self.name)
  90. class GetAttr(Expr):
  91. name = None
  92. def __init__(self, module, name, type=None):
  93. assert isinstance(module, ModuleNode)
  94. self.inputs = [
  95. module,
  96. ]
  97. module.users.append(self)
  98. self.name = name
  99. node_cls = type if type else Node
  100. self.outputs = [
  101. node_cls(self),
  102. ]
  103. @classmethod
  104. def make(cls, *args, **kwargs):
  105. expr = cls(*args, **kwargs)
  106. active_module_tracer().current_scope().insert(expr)
  107. expr.outputs[0]._name = expr.name
  108. return expr.outputs[0]
  109. def interpret(self, *inputs):
  110. return (getattr(inputs[0], self.name),)
  111. def __repr__(self):
  112. return '{} = GetAttr({}, "{}")'.format(
  113. self.outputs[0], self.inputs[0], self.name
  114. )
  115. # expr: outputs = inputs[0].__call__(*inputs[1:])
  116. class CallMethod(Expr):
  117. def __init__(self, node, method="__call__"):
  118. if isinstance(node, type):
  119. assert issubclass(node, Tensor)
  120. cls = Parameter if issubclass(node, Parameter) else Tensor
  121. self.inputs = []
  122. self.const_val = [(0, cls)]
  123. else:
  124. assert isinstance(node, (TensorNode, ModuleNode))
  125. node.users.append(self)
  126. self.inputs = [
  127. node,
  128. ]
  129. self.const_val = []
  130. self.method = method
  131. @classmethod
  132. def make(cls, *args, **kwargs):
  133. expr = cls(*args, **kwargs)
  134. active_module_tracer().current_scope().insert(expr)
  135. return expr
  136. @property
  137. def graph(self):
  138. if isinstance(self.inputs[0], ModuleNode):
  139. m_node = self.inputs[0]
  140. if m_node.argdef_graph_map:
  141. assert self.arg_def in m_node.argdef_graph_map
  142. return m_node.argdef_graph_map[self.arg_def]
  143. return None
  144. def interpret(self, *inputs):
  145. args, kwargs = self.unflatten_args(inputs)
  146. obj = args[0]
  147. meth = getattr(obj, self.method)
  148. if inspect.ismethod(meth):
  149. args = args[1:]
  150. outputs = getattr(obj, self.method)(*args, **kwargs)
  151. if outputs is None:
  152. return outputs
  153. outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
  154. return outputs
  155. def __repr__(self):
  156. args = ", ".join(str(i) for i in self.args[1:])
  157. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  158. return "{} = {}.{}({})".format(
  159. ", ".join(str(i) for i in self.outputs),
  160. self.args[0],
  161. self.method,
  162. ", ".join([args, kwargs]),
  163. )
  164. # expr: outputs = apply(self.opdef, *inputs)
  165. class Apply(Expr):
  166. opdef = None
  167. def __init__(self, opdef):
  168. assert isinstance(opdef, OpDef)
  169. self.opdef = opdef
  170. self.inputs = []
  171. @classmethod
  172. def make(cls, *args, **kwargs):
  173. expr = cls(*args, **kwargs)
  174. active_module_tracer().current_scope().insert(expr)
  175. return expr
  176. def interpret(self, *inputs):
  177. return apply(self.opdef, *inputs)
  178. def __repr__(self):
  179. return "{} = {}({})".format(
  180. ", ".join(str(i) for i in self.outputs),
  181. self.opdef,
  182. ", ".join(str(i) for i in self.inputs),
  183. )
  184. @classmethod
  185. def apply_module_trace_hook(cls, opdef, *inputs):
  186. for i in inputs:
  187. node = NodeMixin.get(i, None)
  188. if node is None: # capture as constant
  189. NodeMixin.wrap_safe(i, Constant.make(i))
  190. apply_node = cls.make(opdef)
  191. apply_node.add_inputs(inputs)
  192. assert not apply_node.const_val
  193. unset_module_tracing()
  194. outputs = apply(opdef, *inputs)
  195. set_module_tracing()
  196. apply_node.add_outputs(outputs)
  197. for n, v in zip(apply_node.outputs, outputs):
  198. NodeMixin.wrap_safe(v, n)
  199. return list(outputs)
  200. class CallFunction(Expr):
  201. def __init__(self, func):
  202. assert isinstance(func, Callable)
  203. self.func = func
  204. self.const_val = []
  205. self.inputs = []
  206. @classmethod
  207. def make(cls, *args, **kwargs):
  208. expr = cls(*args, **kwargs)
  209. active_module_tracer().current_scope().insert(expr)
  210. return expr
  211. def interpret(self, *inputs):
  212. args, kwargs = self.unflatten_args(inputs)
  213. outputs = self.func(*args, **kwargs)
  214. outputs = (
  215. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  216. )
  217. return outputs
  218. def __repr__(self):
  219. args = ", ".join(str(i) for i in self.args)
  220. kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
  221. return "{} = {}({})".format(
  222. ", ".join(str(i) for i in self.outputs),
  223. self.func.__module__ + "." + self.func.__name__,
  224. ", ".join([args, kwargs]),
  225. )
  226. # expr outputs = self.value
  227. class Constant(Expr):
  228. value = None
  229. # TODO: constant cache to reduce the size of dumped model
  230. _constant_cache = {}
  231. def __init__(self, c):
  232. assert isinstance(c, (RawTensor, Module))
  233. if isinstance(c, Module):
  234. assert module_tracer.is_builtin(c)
  235. self.value = c
  236. self.inputs = []
  237. node_cls = NodeMixin.get_wrapped_type(c)
  238. self.outputs = [
  239. node_cls(self),
  240. ]
  241. @classmethod
  242. def make(cls, *args, **kwargs):
  243. expr = cls(*args, **kwargs)
  244. active_module_tracer().current_scope().insert(expr)
  245. return expr.outputs[0]
  246. def interpret(self, *inputs):
  247. if isinstance(self.value, RawTensor):
  248. return Const(self.value.numpy())()
  249. return (self.value,)
  250. def __repr__(self):
  251. return "{} = Constant({})".format(self.outputs[0], type(self.value))
  252. def __getstate__(self):
  253. state = self.__dict__.copy()
  254. if isinstance(self.value, RawTensor):
  255. state["value"] = Tensor(self.value)
  256. return state

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台