
network_node.py 20 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import json
import sys
from typing import Sequence

import numpy as np

from ..core import _imperative_rt as rt
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.megbrain_graph import OutputNode
from .comp_graph_tools import replace_vars
from .module_stats import (
    preprocess_receptive_field,
    register_flops,
    register_receptive_field,
)

class NetworkNode:
    pass


class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
    pass


class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
    def __init__(self, var=None, *, owner_opr=None, name=None):
        SymbolVar.__init__(self, var)
        self.owner = owner_opr
        self.name = name
        self.id = id(self)

    @classmethod
    def load(cls, sym_var, owner_opr):
        obj = cls()
        obj.var = sym_var  # mgb varnode
        obj.name = sym_var.name
        obj.owner = owner_opr
        return obj

    def _get_var_shape(self, axis=None):
        opdef = (
            builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
        )
        return apply(opdef, self)[0]

    @property
    def partial_shape(self):
        """Return the inferred shape of this VarNode as a tuple."""
        return tuple(self._get_var_shape().numpy())

    def shapeof(self, axis):
        """Return the symbolic shape along the given axis."""
        return self._get_var_shape(axis=axis) if self.var else None

    @property
    def _tuple_shape(self):
        return self.partial_shape

    @property
    def shape(self):
        """Return the symbolic shape if set_symbolic_shape(True) is in effect,
        otherwise the inferred shape.
        """
        rst = None
        if self.var:
            try:
                rst = self.var.shape
            except:
                rst = None
        if not use_symbolic_shape():
            return rst
        return self._get_var_shape() if self.var else None

    @property
    def dtype(self):
        return self.var.dtype if self.var else None

    @property
    def ndim(self):
        return super().ndim

    def __bool__(self):
        return False

    __index__ = None
    __int__ = None
    __float__ = None
    __complex__ = None

    def __hash__(self):
        return id(self)

    def numpy(self):
        # compile a throwaway output node so the value can be fetched eagerly
        o = OutputNode(self.var)
        self.graph.compile(o.outputs).execute()
        return o.get_value().numpy()

    def _reset(self, other):
        if not isinstance(other, VarNode):
            assert self.graph, "VarNode _reset must have graph"
            node = ImmutableTensor(other, graph=self.graph)
            node.compile(self.graph)
            other = node.outputs[0]
        if self.owner is not None:
            idx = self.owner.outputs.index(self)
            self.owner.outputs[idx] = VarNode(
                self.var, owner_opr=self.owner, name=self.var.name
            )
        self.var = other.var
        self.owner = None

    def set_owner_opr(self, owner_opr):
        self.owner = owner_opr


class OpNode(NetworkNode):
    opdef = None
    type = None

    def __init__(self):
        self.inputs = []
        self.outputs = []
        self.params = {}
        self._opr = None  # mgb opnode

    @property
    def id(self):
        if self._opr is not None:
            return self._opr.id
        return id(self)

    @property
    def priority(self):
        if self._opr is not None:
            return self._opr.priority
        return 0

    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.params = json.loads(opr.params)
        obj.name = opr.name
        obj._opr = opr
        return obj

    def compile(self):
        # rebuild the underlying mgb operator only if it is missing or its
        # recorded inputs no longer match this node's inputs
        if (
            self._opr is None
            or len(self._opr.inputs) != len(self.inputs)
            or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
        ):
            op = self.opdef(**self.params)
            args = [i.var for i in self.inputs]
            outputs = rt.invoke_op(op, args)
            assert len(outputs) == len(self.outputs)
            self._opr = outputs[0].owner
            for i in range(len(self.outputs)):
                self.outputs[i].var = outputs[i]
                self.outputs[i].var.name = self.outputs[i].name
                assert self.outputs[i].owner is self

    def add_inp_var(self, x):
        self.inputs.append(x)

    def add_out_var(self, x):
        self.outputs.append(x)

    def __repr__(self):
        return "%s{%s}" % (self.name, self.type)


def str_to_mge_class(classname):
    # TODO: use megbrain C++ RTTI to replace type string
    if classname == "RNGOpr<MegDNNOpr>":
        classname = "RNGOpr"
    oprcls = getattr(sys.modules[__name__], classname, None)
    return oprcls if oprcls else ReadOnlyOpNode
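
# str_to_mge_class resolves a megbrain type string to the OpNode subclass of
# the same name in this module; operators without a dedicated wrapper fall
# back to ReadOnlyOpNode, whose compile() only re-wires replaced input vars
# instead of rebuilding the operator from its params.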


class Host2DeviceCopy(OpNode):
    type = "Host2DeviceCopy"

    def __init__(self, shape=None, dtype=None, name=None, device=None):
        super().__init__()
        self.shape = shape
        self.dtype = dtype
        self.name = name
        self.device = Device(device).to_c() if device else Device("xpux").to_c()
        self.outputs = []

    @classmethod
    def load(cls, opr):
        self = cls()
        self.outputs = []
        assert len(opr.outputs) == 1, "wrong number of outputs"
        self.shape = opr.outputs[0].shape
        self.dtype = opr.outputs[0].dtype
        self.name = opr.outputs[0].name
        self.device = opr.outputs[0].comp_node
        self._opr = opr
        return self

    def compile(self, graph):
        if (
            self._opr is None
            or self._opr.outputs[0].comp_node != self.device
            or self._opr.outputs[0].shape != self.shape
            or self._opr.outputs[0].dtype != self.dtype
        ):
            outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
            self._opr = outputs.owner
        if len(self.outputs) == 0:
            self.outputs.append(VarNode(owner_opr=self, name=self.name))
        # read the var from self._opr so it is defined whether or not the
        # operator was rebuilt above
        self.outputs[0].var = self._opr.outputs[0]
        assert self.outputs[0].owner is self


class ImmutableTensor(OpNode):
    type = "ImmutableTensor"

    def __init__(self, data=None, name=None, device=None, graph=None):
        super().__init__()
        self.name = name
        self.outputs = []
        self.graph = graph
        if data is not None:
            self.set_value(data, device)

    @property
    def device(self):
        return self._opr.outputs[0].comp_node if self._opr else None

    @device.setter
    def device(self, device):
        self.set_value(self.numpy(), device)

    @property
    def shape(self):
        return self.outputs[0].shape

    @property
    def dtype(self):
        return self._opr.outputs[0].dtype if self._opr else None

    def numpy(self):
        return self._opr.outputs[0].value if self._opr else None

    def set_value(self, data, device=None):
        assert self.graph is not None
        cn = device if device else self.device
        assert isinstance(data, (int, float, Sequence, np.ndarray))
        if not isinstance(data, np.ndarray):
            data = np.array(data)
        if data.dtype == np.float64:
            data = data.astype(np.float32)
        elif data.dtype == np.int64:
            data = data.astype(np.int32)
        varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
        if len(self.outputs) == 0:
            self.outputs.append(VarNode(owner_opr=self, name=self.name))
        self.outputs[0].var = varnode
        self._opr = varnode.owner

    @classmethod
    def load(cls, opr):
        self = cls()
        self.outputs = []
        self._opr = opr
        self.name = opr.outputs[0].name
        self.graph = opr.graph
        return self

    def compile(self, graph):
        assert self.outputs[0].var is self._opr.outputs[0]
        assert self.outputs[0].owner is self
        if self.graph != graph:
            self.graph = graph
            self.set_value(self.numpy())
        if self.name is not None:
            self.outputs[0].var.name = self.name


class ReadOnlyOpNode(OpNode):
    @classmethod
    def load(cls, opr):
        obj = super(ReadOnlyOpNode, cls).load(opr)
        obj.type = opr.type
        return obj

    def compile(self):
        assert self._opr is not None
        assert len(self.inputs) == len(self._opr.inputs)
        assert len(self.outputs) == len(self._opr.outputs)
        repl_dict = {}
        for ind, i in enumerate(self.inputs):
            if i.var != self._opr.inputs[ind]:
                repl_dict[self._opr.inputs[ind]] = i.var
        if bool(repl_dict):
            out_vars = replace_vars(self._opr.outputs, repl_dict)
            for ind, o in enumerate(self.outputs):
                o.var = out_vars[ind]


class Elemwise(OpNode):
    type = "Elemwise"
    opdef = builtin.Elemwise

    def __repr__(self):
        return "%s{Elemwise:%s}" % (self.name, self.params["mode"])


class ElemwiseMultiType(OpNode):
    type = "ElemwiseMultiType"
    opdef = builtin.ElemwiseMultiType

    def __repr__(self):
        return "%s{ElemwiseMultiType:%s}" % (self.name, self.params["mode"])

    @classmethod
    def load(cls, opr):
        obj = super(ElemwiseMultiType, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


@register_flops(Elemwise, ElemwiseMultiType)
def flops_elemwise(opnode: Elemwise, inputs, outputs):
    return np.prod(outputs[0].shape)
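
# Worked example: an elementwise op whose output has shape (32, 64, 56, 56)
# is counted as 32 * 64 * 56 * 56 = 6,422,528 FLOPs -- one op per output
# element, however many inputs are broadcast together.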


class Reduce(OpNode):
    type = "Reduce"
    opdef = builtin.Reduce


class TypeCvt(OpNode):
    type = "TypeCvt"
    opdef = builtin.TypeCvt

    @classmethod
    def load(cls, opr):
        obj = super(TypeCvt, cls).load(opr)
        t_dtype = opr.outputs[0].dtype
        obj.params["dtype"] = t_dtype
        return obj


class MatrixInverse(OpNode):
    type = "MatrixInverse"
    opdef = builtin.MatrixInverse


class MatrixMul(OpNode):
    type = "MatrixMul"
    opdef = builtin.MatrixMul


@register_flops(MatrixMul)
def flops_matmul(opnode: MatrixMul, inputs, outputs):
    assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
    mid_shape = inputs[0].shape[1]
    return np.prod(outputs[0].shape) * mid_shape
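
# Worked example (non-transposed layout): (M, K) @ (K, N) costs M * N * K
# multiply-accumulates, e.g. (64, 128) @ (128, 256) gives
# 64 * 256 * 128 = 2,097,152 FLOPs.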


class BatchedMatrixMul(OpNode):
    type = "BatchedMatmul"
    opdef = builtin.BatchedMatrixMul


@register_flops(BatchedMatrixMul)
def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
    assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
    mid_shape = inputs[0].shape[2]
    return np.prod(outputs[0].shape) * mid_shape


class Dot(OpNode):
    type = "Dot"
    opdef = builtin.Dot


class SVD(OpNode):
    type = "SVD"
    opdef = builtin.SVD


class ConvolutionForward(OpNode):
    type = "Convolution"
    opdef = builtin.Convolution


class ConvolutionBackwardData(OpNode):
    type = "ConvTranspose"
    opdef = builtin.ConvolutionBackwardData


class DeformableConvForward(OpNode):
    type = "DeformableConv"
    opdef = builtin.DeformableConv


class GroupLocalForward(OpNode):
    type = "GroupLocal"
    opdef = builtin.GroupLocal


class PoolingForward(OpNode):
    type = "Pooling"
    opdef = builtin.Pooling


class AdaptivePoolingForward(OpNode):
    type = "AdaptivePooling"
    opdef = builtin.AdaptivePooling


class ROIPoolingForward(OpNode):
    type = "ROIPooling"
    opdef = builtin.ROIPooling


class DeformablePSROIPoolingForward(OpNode):
    type = "DeformablePSROIPooling"
    opdef = builtin.DeformablePSROIPooling


class ConvBiasForward(OpNode):
    type = "ConvBias"
    opdef = builtin.ConvBias

    @classmethod
    def load(cls, opr):
        obj = super(ConvBiasForward, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


@register_flops(
    ConvolutionForward, ConvBiasForward,
)
def flops_conv(opnode: ConvolutionForward, inputs, outputs):
    param_W_shape = inputs[1].shape
    kh = param_W_shape[-2]
    kw = param_W_shape[-1]
    if len(param_W_shape) == 5:
        num_input = param_W_shape[2]
    else:
        num_input = param_W_shape[1]
    NCHW = np.prod(outputs[0].shape)
    bias = 1 if isinstance(opnode, ConvBiasForward) else 0
    # N x Cout x H x W x (Cin x Kw x Kh)
    return NCHW * (num_input * kw * kh + bias)
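
# Worked example: a 3x3 convolution with 64 input channels and output shape
# (1, 128, 56, 56) costs 1 * 128 * 56 * 56 * (64 * 3 * 3) = 231,211,008
# FLOPs; ConvBias adds one more op per output element. The 5-d weight branch
# above covers grouped convolutions, whose weights are assumed to be laid
# out as (groups, Cout/g, Cin/g, Kh, Kw).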


@register_receptive_field(ConvolutionForward, ConvBiasForward)
def receptive_field(opnode: ConvolutionForward, inputs, outputs):
    pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
    param_W_shape = inputs[1].shape
    kh = param_W_shape[-2]
    kw = param_W_shape[-1]
    rf = (
        kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
        kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
    )
    stride = (
        opnode.params["stride_h"] * pre_stride[0],
        opnode.params["stride_w"] * pre_stride[1],
    )
    opnode._rf = rf
    opnode._stride = stride
    return rf, stride
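
# Worked example of the recurrence rf = k * s_pre + rf_pre - s_pre: stacking
# two 3x3, stride-1 convolutions yields 3 * 1 + 3 - 1 = 5, i.e. a 5x5
# receptive field, while each layer's stride multiplies the cumulative
# stride seen by the layers after it.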


class BatchConvBiasForward(OpNode):
    type = "BatchConvBias"
    opdef = builtin.BatchConvBias

    @classmethod
    def load(cls, opr):
        obj = super(BatchConvBiasForward, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


class BatchNormForward(OpNode):
    type = "BatchNorm"
    opdef = builtin.BatchNorm
    output_idx = -1


class ROIAlignForward(OpNode):
    type = "ROIAlign"
    opdef = builtin.ROIAlign


class WarpPerspectiveForward(OpNode):
    type = "WarpPerspective"
    opdef = builtin.WarpPerspective


class WarpAffineForward(OpNode):
    type = "WarpAffine"
    opdef = builtin.WarpAffine


class RemapForward(OpNode):
    type = "Remap"
    opdef = builtin.Remap


class ResizeForward(OpNode):
    type = "Resize"
    opdef = builtin.Resize


class IndexingOneHot(OpNode):
    type = "IndexingOneHot"
    opdef = builtin.IndexingOneHot


class IndexingSetOneHot(OpNode):
    type = "IndexingSetOneHot"
    opdef = builtin.IndexingSetOneHot


class Copy(OpNode):
    type = "Copy"
    opdef = builtin.Copy

    @classmethod
    def load(cls, opr):
        obj = super(Copy, cls).load(opr)
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class ArgsortForward(OpNode):
    type = "Argsort"
    opdef = builtin.Argsort


class Argmax(OpNode):
    type = "Argmax"
    opdef = builtin.Argmax


class Argmin(OpNode):
    type = "Argmin"
    opdef = builtin.Argmin


class CondTake(OpNode):
    type = "CondTake"
    opdef = builtin.CondTake


class TopK(OpNode):
    type = "TopK"
    opdef = builtin.TopK


class NvOf(OpNode):
    type = "NvOf"
    opdef = builtin.NvOf


class RNGOpr(OpNode):
    @classmethod
    def load(cls, opr):
        obj = super(RNGOpr, cls).load(opr)
        if len(obj.params) == 3:
            obj.opdef = builtin.GaussianRNG
            obj.type = "GaussianRNG"
        else:
            obj.opdef = builtin.UniformRNG
            obj.type = "UniformRNG"
        return obj


class Linspace(OpNode):
    type = "Linspace"
    opdef = builtin.Linspace

    @classmethod
    def load(cls, opr):
        obj = super(Linspace, cls).load(opr)
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class Eye(OpNode):
    type = "Eye"
    opdef = builtin.Eye

    @classmethod
    def load(cls, opr):
        obj = super(Eye, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class GetVarShape(OpNode):
    type = "GetVarShape"
    opdef = builtin.GetVarShape


class Concat(OpNode):
    type = "Concat"
    opdef = builtin.Concat

    @classmethod
    def load(cls, opr):
        obj = super(Concat, cls).load(opr)
        obj.params["comp_node"] = Device("xpux").to_c()
        return obj


class Broadcast(OpNode):
    type = "Broadcast"
    opdef = builtin.Broadcast


class Identity(OpNode):
    type = "Identity"
    opdef = builtin.Identity


class NMSKeep(OpNode):
    type = "NMSKeep"
    opdef = builtin.NMSKeep


# class ParamPackSplit
# class ParamPackConcat


class Dimshuffle(OpNode):
    type = "Dimshuffle"
    opdef = builtin.Dimshuffle

    @classmethod
    def load(cls, opr):
        obj = super(Dimshuffle, cls).load(opr)
        del obj.params["ndim"]
        return obj


class Reshape(OpNode):
    type = "Reshape"
    opdef = builtin.Reshape


class AxisAddRemove(OpNode):
    type = "AxisAddRemove"

    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.name = opr.name
        obj._opr = opr
        params = json.loads(opr.params)
        desc = params["desc"]
        method = None
        axis = []
        for i in desc:
            if method is None:
                method = i["method"]
            assert method == i["method"]
            axis.append(i["axisnum"])
        obj.params = {"axis": axis}
        obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
        return obj
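
# Each serialized "desc" entry pairs a method code with an axis index; the
# assertion above means a single AxisAddRemove opr either only adds axes
# (method 0 -> builtin.AddAxis) or only removes them (builtin.RemoveAxis).
# Mixed add/remove descriptors are not representable here.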


class IndexingBase(OpNode):
    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.name = opr.name
        obj._opr = opr
        params = json.loads(opr.params)
        items = [
            [
                p["axis"],
                bool(p["begin"]),
                bool(p["end"]),
                bool(p["step"]),
                bool(p["idx"]),
            ]
            for p in params
        ]
        obj.params["items"] = items
        return obj


class Subtensor(IndexingBase):
    type = "Subtensor"
    opdef = builtin.Subtensor


class SetSubtensor(IndexingBase):
    type = "SetSubtensor"
    opdef = builtin.SetSubtensor


class IncrSubtensor(IndexingBase):
    type = "IncrSubtensor"
    opdef = builtin.IncrSubtensor


class IndexingMultiAxisVec(IndexingBase):
    type = "IndexingMultiAxisVec"
    opdef = builtin.IndexingMultiAxisVec


class IndexingSetMultiAxisVec(IndexingBase):
    type = "IndexingSetMultiAxisVec"
    opdef = builtin.IndexingSetMultiAxisVec


class IndexingIncrMultiAxisVec(IndexingBase):
    type = "IndexingIncrMultiAxisVec"
    opdef = builtin.IndexingIncrMultiAxisVec


class MeshIndexing(IndexingBase):
    type = "MeshIndexing"
    opdef = builtin.MeshIndexing


class SetMeshIndexing(IndexingBase):
    type = "SetMeshIndexing"
    opdef = builtin.SetMeshIndexing


class IncrMeshIndexing(IndexingBase):
    type = "IncrMeshIndexing"
    opdef = builtin.IncrMeshIndexing


class BatchedMeshIndexing(IndexingBase):
    type = "BatchedMeshIndexing"
    opdef = builtin.BatchedMeshIndexing


class BatchedSetMeshIndexing(IndexingBase):
    type = "BatchedSetMeshIndexing"
    opdef = builtin.BatchedSetMeshIndexing


class BatchedIncrMeshIndexing(IndexingBase):
    type = "BatchedIncrMeshIndexing"
    opdef = builtin.BatchedIncrMeshIndexing


# class CollectiveComm
# class RemoteSend
# class RemoteRecv
# class TQT
# class FakeQuant
# class InplaceAdd


class AssertEqual(OpNode):
    type = "AssertEqual"
    opdef = builtin.AssertEqual


class CvtColorForward(OpNode):
    type = "CvtColor"
    opdef = builtin.CvtColor
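
These node classes are rarely instantiated by hand; they are the objects handed back by MegEngine's graph-surgery utilities. A minimal sketch of how they typically surface, assuming a model serialized with MegEngine's dump tooling and that the Network helper in megengine.utils.network exposes load and all_oprs as in MegEngine 1.x (check against your version; the path below is hypothetical):

    from megengine.utils.network import Network

    net = Network.load("model.mge")   # parse the dumped graph into OpNode/VarNode objects
    for opr in net.all_oprs:          # each item is an OpNode subclass from this file
        print(opr)                    # e.g. "conv0{Convolution}" via OpNode.__repr__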
