You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

network_node.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import json
  10. import sys
  11. from typing import Callable
  12. import numpy as np
  13. from ..core import _imperative_rt as rt
  14. from ..core._wrap import Device
  15. from ..core.ops import builtin
  16. from ..core.tensor.megbrain_graph import InputNode
  17. from ..tensor import Tensor
  18. from .comp_graph_tools import replace_vars
  19. from .module_stats import (
  20. preprocess_receptive_field,
  21. register_flops,
  22. register_receptive_field,
  23. )
  24. class NetworkNode:
  25. pass
  26. class VarNode(NetworkNode):
  27. def __init__(self, owner_opr=None, name=None):
  28. self.var = None
  29. self.owner = owner_opr
  30. self.name = name
  31. self.id = id(self)
  32. @classmethod
  33. def load(cls, sym_var, owner_opr):
  34. obj = cls()
  35. obj.var = sym_var # mgb varnode
  36. obj.name = sym_var.name
  37. obj.owner = owner_opr
  38. return obj
  39. @property
  40. def shape(self):
  41. rst = None
  42. if self.var:
  43. try:
  44. rst = self.var.shape
  45. except:
  46. rst = None
  47. return rst
  48. @property
  49. def dtype(self):
  50. return self.var.dtype if self.var else None
  51. def set_owner_opr(self, owner_opr):
  52. self.owner = owner_opr
  53. class OpNode(NetworkNode):
  54. opdef = None
  55. type = None
  56. def __init__(self):
  57. self.inputs = []
  58. self.outputs = []
  59. self.params = {}
  60. self._opr = None # mgb opnode
  61. self.id = id(self)
  62. @classmethod
  63. def load(cls, opr):
  64. obj = cls()
  65. obj.params = json.loads(opr.params)
  66. obj.name = opr.name
  67. obj._opr = opr
  68. return obj
  69. def compile(self, graph=None):
  70. op = self.opdef(**self.params)
  71. args = [i.var for i in self.inputs]
  72. outputs = rt.invoke_op(op, args)
  73. assert len(outputs) == len(self.outputs)
  74. self._opr = outputs[0].owner
  75. for i in range(len(self.outputs)):
  76. self.outputs[i].var = outputs[i]
  77. self.outputs[i].var.name = self.outputs[i].name
  78. assert self.outputs[i].owner is self
  79. def add_inp_var(self, x):
  80. self.inputs.append(x)
  81. def add_out_var(self, x):
  82. self.outputs.append(x)
  83. def __repr__(self):
  84. return "%s{%s}" % (self.name, self.type)
  85. def str_to_mge_class(classname):
  86. # TODO: use megbrain C++ RTTI to replace type string
  87. if classname == "RNGOpr<MegDNNOpr>":
  88. classname = "RNGOpr"
  89. oprcls = getattr(sys.modules[__name__], classname, None)
  90. return oprcls if oprcls else ReadOnlyOpNode
  91. class Host2DeviceCopy(OpNode):
  92. type = "Host2DeviceCopy"
  93. def __init__(self, shape=None, dtype=None, name=None, device=None):
  94. super().__init__()
  95. self.shape = shape
  96. self.dtype = dtype
  97. self.name = name
  98. self.device = Device(device).to_c() if device else Device("xpux").to_c()
  99. self.outputs = []
  100. @classmethod
  101. def load(cls, opr):
  102. self = cls()
  103. self.outputs = []
  104. assert len(opr.outputs) == 1, "wrong number of outputs"
  105. self.shape = opr.outputs[0].shape
  106. self.dtype = opr.outputs[0].dtype
  107. self.name = opr.outputs[0].name
  108. self.device = opr.outputs[0].comp_node
  109. self._opr = opr
  110. return self
  111. def compile(self, graph):
  112. outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
  113. self._opr = outputs.owner
  114. if len(self.outputs) == 0:
  115. self.outputs.append(VarNode(self, self.name))
  116. self.outputs[0].var = outputs
  117. assert self.outputs[0].owner is self
  118. class ImmutableTensor(OpNode):
  119. type = "ImmutableTensor"
  120. def __init__(self, data=None, name=None, device=None, graph=None):
  121. super().__init__()
  122. self.name = name
  123. self.outputs = []
  124. self.graph = graph
  125. if data is not None:
  126. self.set_value(data, device)
  127. @property
  128. def device(self):
  129. return self._opr.outputs[0].comp_node if self._opr else None
  130. @device.setter
  131. def device(self, device):
  132. self.set_value(self.numpy(), device)
  133. @property
  134. def shape(self):
  135. return self.outputs[0].shape
  136. @property
  137. def dtype(self):
  138. return self._opr.outputs[0].dtype if self._opr else None
  139. def numpy(self):
  140. return self._opr.outputs[0].value if self._opr else None
  141. def set_value(self, data, device=None):
  142. assert self.graph is not None
  143. cn = device if device else self.device
  144. assert isinstance(data, (int, float, np.ndarray))
  145. if isinstance(data, (int, float)):
  146. data = np.array(data)
  147. if data.dtype == np.float64:
  148. data = data.astype(np.float32)
  149. elif data.dtype == np.int64:
  150. data = data.astype(np.int32)
  151. varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
  152. if len(self.outputs) == 0:
  153. self.outputs.append(VarNode(self, self.name))
  154. self.outputs[0].var = varnode
  155. self._opr = varnode.owner
  156. @classmethod
  157. def load(cls, opr):
  158. self = cls()
  159. self.outputs = []
  160. self._opr = opr
  161. self.name = opr.outputs[0].name
  162. self.graph = opr.graph
  163. return self
  164. def compile(self, graph):
  165. assert self.outputs[0].var is self._opr.outputs[0]
  166. assert self.outputs[0].owner is self
  167. if self.graph != graph:
  168. self.graph = graph
  169. self.set_value(self.numpy())
  170. if self.name is not None:
  171. self.outputs[0].var.name = self.name
  172. class ReadOnlyOpNode(OpNode):
  173. @classmethod
  174. def load(cls, opr):
  175. obj = super(ReadOnlyOpNode, cls).load(opr)
  176. obj.type = opr.type
  177. return obj
  178. def compile(self):
  179. assert self._opr is not None
  180. assert len(self.inputs) == len(self._opr.inputs)
  181. assert len(self.outputs) == len(self._opr.outputs)
  182. repl_dict = {}
  183. for ind, i in enumerate(self.inputs):
  184. if i.var != self._opr.inputs[ind]:
  185. repl_dict[self._opr.inputs[ind]] = i.var
  186. if bool(repl_dict):
  187. out_vars = replace_vars(self._opr.outputs, repl_dict)
  188. for ind, o in enumerate(self.outputs):
  189. o.var = out_vars[ind]
  190. class Elemwise(OpNode):
  191. type = "Elemwise"
  192. opdef = builtin.Elemwise
  193. def __repr__(self):
  194. return "%s{Elemwise:%s}" % (self.name, self.params["mode"])
  195. class ElemwiseMultiType(OpNode):
  196. type = "ElemwiseMultiType"
  197. opdef = builtin.ElemwiseMultiType
  198. def __repr__(self):
  199. return "%s{ElemwiseMultiType:%s}" % (self.name, self.params["mode"])
  200. @classmethod
  201. def load(cls, opr):
  202. obj = super(ElemwiseMultiType, cls).load(opr)
  203. obj.params["dtype"] = opr.outputs[0].dtype
  204. return obj
  205. @register_flops(Elemwise, ElemwiseMultiType)
  206. def flops_elemwise(opnode: Elemwise, inputs, outputs):
  207. return np.prod(outputs[0].shape)
  208. class Reduce(OpNode):
  209. type = "Reduce"
  210. opdef = builtin.Reduce
  211. class TypeCvt(OpNode):
  212. type = "TypeCvt"
  213. opdef = builtin.TypeCvt
  214. @classmethod
  215. def load(cls, opr):
  216. obj = super(TypeCvt, cls).load(opr)
  217. t_dtype = opr.outputs[0].dtype
  218. obj.params["dtype"] = t_dtype
  219. return obj
  220. class MatrixInverse(OpNode):
  221. type = "MatrixInverse"
  222. opdef = builtin.MatrixInverse
  223. class MatrixMul(OpNode):
  224. type = "MatrixMul"
  225. opdef = builtin.MatrixMul
  226. @register_flops(MatrixMul)
  227. def flops_matmul(opnode: MatrixMul, inputs, outputs):
  228. assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
  229. mid_shape = inputs[0].shape[1]
  230. return np.prod(outputs[0].shape) * mid_shape
  231. class BatchedMatrixMul(OpNode):
  232. type = "BatchedMatmul"
  233. opdef = builtin.BatchedMatrixMul
  234. @register_flops(BatchedMatrixMul)
  235. def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
  236. assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
  237. mid_shape = inputs[0].shape[2]
  238. return np.prod(outputs[0].shape) * mid_shape
  239. class Dot(OpNode):
  240. type = "Dot"
  241. opdef = builtin.Dot
  242. class SVD(OpNode):
  243. type = "SVD"
  244. opdef = builtin.SVD
  245. class ConvolutionForward(OpNode):
  246. type = "Convolution"
  247. opdef = builtin.Convolution
  248. class ConvolutionBackwardData(OpNode):
  249. type = "ConvTranspose"
  250. opdef = builtin.ConvolutionBackwardData
  251. class DeformableConvForward(OpNode):
  252. type = "DeformableConv"
  253. opdef = builtin.DeformableConv
  254. class GroupLocalForward(OpNode):
  255. type = "GroupLocal"
  256. opdef = builtin.GroupLocal
  257. class PoolingForward(OpNode):
  258. type = "Pooling"
  259. opdef = builtin.Pooling
  260. class AdaptivePoolingForward(OpNode):
  261. type = "AdaptivePooling"
  262. opdef = builtin.AdaptivePooling
  263. class ROIPoolingForward(OpNode):
  264. type = "ROIPooling"
  265. opdef = builtin.ROIPooling
  266. class DeformablePSROIPoolingForward(OpNode):
  267. type = "DeformablePSROIPooling"
  268. opdef = builtin.DeformablePSROIPooling
  269. class ConvBiasForward(OpNode):
  270. type = "ConvBias"
  271. opdef = builtin.ConvBias
  272. @classmethod
  273. def load(cls, opr):
  274. obj = super(ConvBiasForward, cls).load(opr)
  275. obj.params["dtype"] = opr.outputs[0].dtype
  276. return obj
  277. @register_flops(
  278. ConvolutionForward, ConvBiasForward,
  279. )
  280. def flops_conv(opnode: ConvolutionForward, inputs, outputs):
  281. param_W_shape = inputs[1].shape
  282. kh = param_W_shape[-2]
  283. kw = param_W_shape[-1]
  284. if len(param_W_shape) == 5:
  285. num_input = param_W_shape[2]
  286. else:
  287. num_input = param_W_shape[1]
  288. NCHW = np.prod(outputs[0].shape)
  289. bias = 1 if isinstance(opnode, ConvBiasForward) else 0
  290. # N x Cout x H x W x (Cin x Kw x Kh)
  291. return NCHW * (num_input * kw * kh + bias)
  292. @register_receptive_field(ConvolutionForward, ConvBiasForward)
  293. def receptive_field(opnode: ConvolutionForward, inputs, outputs):
  294. pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
  295. param_W_shape = inputs[1].shape
  296. kh = param_W_shape[-2]
  297. kw = param_W_shape[-1]
  298. rf = (
  299. kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
  300. kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
  301. )
  302. stride = (
  303. opnode.params["stride_h"] * pre_stride[0],
  304. opnode.params["stride_w"] * pre_stride[1],
  305. )
  306. opnode._rf = rf
  307. opnode._stride = stride
  308. return rf, stride
  309. class BatchConvBiasForward(OpNode):
  310. type = "BatchConvBias"
  311. opdef = builtin.BatchConvBias
  312. @classmethod
  313. def load(cls, opr):
  314. obj = super(BatchConvBiasForward, cls).load(opr)
  315. obj.params["dtype"] = opr.outputs[0].dtype
  316. return obj
  317. class BatchNormForward(OpNode):
  318. type = "BatchNorm"
  319. opdef = builtin.BatchNorm
  320. output_idx = -1
  321. class ROIAlignForward(OpNode):
  322. type = "ROIAlign"
  323. opdef = builtin.ROIAlign
  324. class WarpPerspectiveForward(OpNode):
  325. type = "WarpPerspective"
  326. opdef = builtin.WarpPerspective
  327. class WarpAffineForward(OpNode):
  328. type = "WarpAffine"
  329. opdef = builtin.WarpAffine
  330. class RemapForward(OpNode):
  331. type = "Remap"
  332. opdef = builtin.Remap
  333. class ResizeForward(OpNode):
  334. type = "Resize"
  335. opdef = builtin.Resize
  336. class IndexingOneHot(OpNode):
  337. type = "IndexingOneHot"
  338. opdef = builtin.IndexingOneHot
  339. class IndexingSetOneHot(OpNode):
  340. type = "IndexingSetOneHot"
  341. opdef = builtin.IndexingSetOneHot
  342. class Copy(OpNode):
  343. type = "Copy"
  344. opdef = builtin.Copy
  345. @classmethod
  346. def load(cls, opr):
  347. obj = super(Copy, cls).load(opr)
  348. obj.params["comp_node"] = opr.outputs[0].comp_node
  349. return obj
  350. class ArgsortForward(OpNode):
  351. type = "Argsort"
  352. opdef = builtin.Argsort
  353. class Argmax(OpNode):
  354. type = "Argmax"
  355. opdef = builtin.Argmax
  356. class Argmin(OpNode):
  357. type = "Argmin"
  358. opdef = builtin.Argmin
  359. class CondTake(OpNode):
  360. type = "CondTake"
  361. opdef = builtin.CondTake
  362. class TopK(OpNode):
  363. type = "TopK"
  364. opdef = builtin.TopK
  365. class NvOf(OpNode):
  366. type = "NvOf"
  367. opdef = builtin.NvOf
  368. class RNGOpr(OpNode):
  369. @classmethod
  370. def load(cls, opr):
  371. obj = super(RNGOpr, cls).load(opr)
  372. if len(obj.params) == 3:
  373. obj.opdef = builtin.GaussianRNG
  374. obj.type = "GaussianRNG"
  375. else:
  376. obj.opdef = builtin.UniformRNG
  377. obj.type = "UniformRNG"
  378. return obj
  379. class Linspace(OpNode):
  380. type = "Linspace"
  381. opdef = builtin.Linspace
  382. @classmethod
  383. def load(cls, opr):
  384. obj = super(Linspace, cls).load(opr)
  385. obj.params["comp_node"] = opr.outputs[0].comp_node
  386. return obj
  387. class Eye(OpNode):
  388. type = "Eye"
  389. opdef = builtin.Eye
  390. @classmethod
  391. def load(cls, opr):
  392. obj = super(Eye, cls).load(opr)
  393. obj.params["dtype"] = opr.outputs[0].dtype
  394. obj.params["comp_node"] = opr.outputs[0].comp_node
  395. return obj
  396. class GetVarShape(OpNode):
  397. type = "GetVarShape"
  398. opdef = builtin.GetVarShape
  399. class Concat(OpNode):
  400. type = "Concat"
  401. opdef = builtin.Concat
  402. @classmethod
  403. def load(cls, opr):
  404. obj = super(Concat, cls).load(opr)
  405. obj.params["comp_node"] = Device("xpux").to_c()
  406. return obj
  407. class Broadcast(OpNode):
  408. type = "Broadcast"
  409. opdef = builtin.Broadcast
  410. class Identity(OpNode):
  411. type = "Identity"
  412. opdef = builtin.Identity
  413. class NMSKeep(OpNode):
  414. type = "NMSKeep"
  415. opdef = builtin.NMSKeep
  416. # class ParamPackSplit
  417. # class ParamPackConcat
  418. class Dimshuffle(OpNode):
  419. type = "Dimshuffle"
  420. opdef = builtin.Dimshuffle
  421. @classmethod
  422. def load(cls, opr):
  423. obj = super(Dimshuffle, cls).load(opr)
  424. del obj.params["ndim"]
  425. return obj
  426. class Reshape(OpNode):
  427. type = "Reshape"
  428. opdef = builtin.Reshape
  429. class AxisAddRemove(OpNode):
  430. type = "AxisAddRemove"
  431. @classmethod
  432. def load(cls, opr):
  433. obj = cls()
  434. obj.name = opr.name
  435. obj._opr = opr
  436. params = json.loads(opr.params)
  437. desc = params["desc"]
  438. method = None
  439. axis = []
  440. for i in desc:
  441. if method is None:
  442. method = i["method"]
  443. assert method == i["method"]
  444. axis.append(i["axisnum"])
  445. obj.params = {"axis": axis}
  446. obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
  447. return obj
  448. class IndexingBase(OpNode):
  449. @classmethod
  450. def load(cls, opr):
  451. obj = cls()
  452. obj.name = opr.name
  453. obj._opr = opr
  454. params = json.loads(opr.params)
  455. items = [
  456. [
  457. p["axis"],
  458. bool(p["begin"]),
  459. bool(p["end"]),
  460. bool(p["step"]),
  461. bool(p["idx"]),
  462. ]
  463. for p in params
  464. ]
  465. obj.params["items"] = items
  466. return obj
  467. class Subtensor(IndexingBase):
  468. type = "Subtensor"
  469. opdef = builtin.Subtensor
  470. class SetSubtensor(IndexingBase):
  471. type = "SetSubtensor"
  472. opdef = builtin.SetSubtensor
  473. class IncrSubtensor(IndexingBase):
  474. type = "IncrSubtensor"
  475. opdef = builtin.IncrSubtensor
  476. class IndexingMultiAxisVec(IndexingBase):
  477. type = "IndexingMultiAxisVec"
  478. opdef = builtin.IndexingMultiAxisVec
  479. class IndexingSetMultiAxisVec(IndexingBase):
  480. type = "IndexingSetMultiAxisVec"
  481. opdef = builtin.IndexingSetMultiAxisVec
  482. class IndexingIncrMultiAxisVec(IndexingBase):
  483. type = "IndexingIncrMultiAxisVec"
  484. opdef = builtin.IndexingIncrMultiAxisVec
  485. class MeshIndexing(IndexingBase):
  486. type = "MeshIndexing"
  487. opdef = builtin.MeshIndexing
  488. class SetMeshIndexing(IndexingBase):
  489. type = "SetMeshIndexing"
  490. opdef = builtin.SetMeshIndexing
  491. class IncrMeshIndexing(IndexingBase):
  492. type = "IncrMeshIndexing"
  493. opdef = builtin.IncrMeshIndexing
  494. class BatchedMeshIndexing(IndexingBase):
  495. type = "BatchedMeshIndexing"
  496. opdef = builtin.BatchedMeshIndexing
  497. class BatchedSetMeshIndexing(IndexingBase):
  498. type = "BatchedSetMeshIndexing"
  499. opdef = builtin.BatchedSetMeshIndexing
  500. class BatchedIncrMeshIndexing(IndexingBase):
  501. type = "BatchedIncrMeshIndexing"
  502. opdef = builtin.BatchedIncrMeshIndexing
  503. # class CollectiveComm
  504. # class RemoteSend
  505. # class RemoteRecv
  506. # class TQT
  507. # class FakeQuant
  508. # class InplaceAdd
  509. class AssertEqual(OpNode):
  510. type = "AssertEqual"
  511. opdef = builtin.AssertEqual
  512. class CvtColorForward(OpNode):
  513. type = "CvtColor"
  514. opdef = builtin.CvtColor

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。