You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network_node.py 19 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787
  1. # -*- coding: utf-8 -*-
  2. import json
  3. import sys
  4. from typing import Sequence
  5. import numpy as np
  6. from ..core import _imperative_rt as rt
  7. from ..core._imperative_rt.core2 import SymbolVar, apply
  8. from ..core._trace_option import use_symbolic_shape
  9. from ..core._wrap import Device
  10. from ..core.ops import builtin
  11. from ..core.tensor.array_method import ArrayMethodMixin
  12. from .comp_graph_tools import replace_vars
  13. from .module_stats import (
  14. preprocess_receptive_field,
  15. register_flops,
  16. register_receptive_field,
  17. )
  class NetworkNode:
      """Common base class for every node (variable or operator) of a network graph."""


  class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
      """Metaclass combining the metaclasses of ``SymbolVar`` and ``ArrayMethodMixin``.

      ``VarNode`` derives from both classes, and Python requires the derived
      class's metaclass to be a subclass of every base's metaclass.
      """
  class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
      """A variable (edge) of a network graph.

      Wraps an mgb var stored in ``self.var`` and records the ``OpNode`` that
      produces it (``owner``) and the ``OpNode`` objects that consume it
      (``users``).
      """

      def __init__(self, var=None, *, owner_opr=None, name=None):
          SymbolVar.__init__(self, var)
          self.users = []  # List[OpNode] consuming this var
          self.owner = owner_opr  # OpNode producing this var, if any
          self.name = name
          self.id = id(self)

      @classmethod
      def load(cls, sym_var, owner_opr):
          """Wrap an existing mgb var ``sym_var`` produced by ``owner_opr``."""
          obj = cls()
          obj.var = sym_var  # mgb varnode
          obj.name = sym_var.name
          obj.owner = owner_opr
          return obj

      def _get_var_shape(self, axis=None):
          """Apply ``builtin.GetVarShape`` to this var, optionally for one ``axis``."""
          opdef = (
              builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
          )
          return apply(opdef, self)[0]

      @property
      def partial_shape(self):
          r"""Return the tuple type inferred shape of VarNode"""
          return tuple(self._get_var_shape().numpy())

      def shapeof(self, axis):
          r"""Return the symbolic shape of axis"""
          return self._get_var_shape(axis=axis) if self.var else None

      @property
      def _tuple_shape(self):
          return self.partial_shape

      @property
      def shape(self):
          r"""Return the symbolic shape if using set_symbolic_shape(True)
          else inferred shape

          Returns ``None`` when no mgb var is attached, or (in non-symbolic
          mode) when reading the inferred shape raises.
          """
          if use_symbolic_shape():
              return self._get_var_shape() if self.var else None
          if not self.var:
              return None
          try:
              return self.var.shape
          except Exception:
              # NOTE(review): presumably raised when the static shape cannot be
              # inferred; report "unknown" instead of failing.  Only ordinary
              # exceptions are swallowed -- KeyboardInterrupt/SystemExit propagate.
              return None

      @property
      def dtype(self):
          return self.var.dtype if self.var else None

      @property
      def ndim(self):
          return super().ndim

      def __bool__(self):
          # Always falsy.  NOTE(review): presumably so that container lookups
          # such as ``list.index``/``in`` (which fall back to ``==``, overloaded
          # for array-like objects) never treat a symbolic comparison result as
          # a match; identity still matches.
          return False

      # Disable implicit conversion to Python numbers (calling None raises TypeError).
      __index__ = None
      __int__ = None
      __float__ = None
      __complex__ = None

      def __hash__(self):
          # Identity-based hashing, consistent with ``self.id``.
          return id(self)

      def numpy(self):
          return super().numpy()

      def _reset(self, other):
          """Rebind this node to ``other``'s var and detach it from its owner.

          A non-``VarNode`` ``other`` is first materialized as an
          ``ImmutableTensor`` in ``self.graph``.  If this node had an owner, the
          owner's output slot is refilled with a fresh ``VarNode`` wrapping the
          old var, so the owner keeps its original output.
          """
          if not isinstance(other, VarNode):
              assert self.graph, "VarNode _reset must have graph"
              node = ImmutableTensor(other, graph=self.graph)
              node.compile(self.graph)
              other = node.outputs[0]
          if self.owner is not None:
              idx = self.owner.outputs.index(self)
              self.owner.outputs[idx] = VarNode(
                  self.var, owner_opr=self.owner, name=self.var.name
              )
          self.var = other.var
          self.owner = None

      def set_owner_opr(self, owner_opr):
          """Set the ``OpNode`` that produces this var."""
          self.owner = owner_opr
  class OpNode(NetworkNode):
      """Base class for operator nodes of a network graph.

      Subclasses set ``type`` (the operator type name) and ``opdef`` (the
      builtin op definition).  ``params`` holds the keyword arguments that
      ``compile`` passes to ``opdef``.
      """

      opdef = None
      type = None

      def __init__(self):
          self.inputs = []  # List[VarNode]
          self.outputs = []  # List[VarNode]
          self.params = {}
          # Initialised so __repr__ works on nodes not created through load().
          self.name = None
          self._opr = None  # mgb opnode

      @property
      def id(self):
          return id(self)

      @property
      def priority(self):
          """``(priority, id)`` of the underlying mgb operator, or ``(0, 0)`` if none."""
          if self._opr is not None:
              return (self._opr.priority, self._opr.id)
          return (0, 0)

      @classmethod
      def load(cls, opr):
          """Build a node from an existing mgb operator, parsing its JSON params."""
          obj = cls()
          obj.params = json.loads(opr.params)
          obj.name = opr.name
          obj._opr = opr
          return obj

      def compile(self):
          """(Re)build the mgb operator when it is missing or its inputs changed.

          The op is recreated from ``opdef(**params)`` applied to the current
          input vars, and the resulting vars are bound (and named) onto
          ``self.outputs``.
          """
          stale = (
              self._opr is None
              or len(self._opr.inputs) != len(self.inputs)
              or any(i != j.var for i, j in zip(self._opr.inputs, self.inputs))
          )
          if stale:
              op = self.opdef(**self.params)
              args = [i.var for i in self.inputs]
              outputs = rt.invoke_op(op, args)
              assert len(outputs) == len(self.outputs)
              self._opr = outputs[0].owner
              for out_node, out_var in zip(self.outputs, outputs):
                  out_node.var = out_var
                  out_node.var.name = out_node.name
                  assert out_node.owner is self

      def add_inp_var(self, x):
          self.inputs.append(x)

      def add_out_var(self, x):
          self.outputs.append(x)

      def __repr__(self):
          return "%s{%s}" % (self.name, self.type)
  def str_to_mge_class(classname):
      """Map an mgb operator type name to the ``OpNode`` subclass wrapping it.

      Falls back to ``ReadOnlyOpNode`` when this module defines no ``OpNode``
      subclass with that name.
      """
      # TODO: use megbrain C++ RTTI to replace type string
      if classname == "RNGOpr<MegDNNOpr>":
          classname = "RNGOpr"
      oprcls = getattr(sys.modules[__name__], classname, None)
      # The module namespace also holds imports and helper functions (e.g.
      # ``json``, ``apply``, ``flops_conv``); only genuine OpNode subclasses may
      # be returned as operator wrappers.
      if isinstance(oprcls, type) and issubclass(oprcls, OpNode):
          return oprcls
      return ReadOnlyOpNode
  class Host2DeviceCopy(OpNode):
      """Wrapper for the mgb ``Host2DeviceCopy`` operator created by ``rt.make_h2d``."""

      type = "Host2DeviceCopy"

      def __init__(self, shape=None, dtype=None, name=None, device=None):
          super().__init__()
          self.shape = shape
          self.dtype = dtype
          self.name = name
          # Fall back to the generic "xpux" device when none is given.
          target_device = device if device else "xpux"
          self.device = Device(target_device).to_c()

      @classmethod
      def load(cls, opr):
          """Build a node from an existing single-output mgb operator ``opr``."""
          node = cls()
          node.outputs = []
          assert len(opr.outputs) == 1, "wrong number of outputs"
          out_var = opr.outputs[0]
          node.shape = out_var.shape
          node.dtype = out_var.dtype
          node.name = out_var.name
          node.device = out_var.comp_node
          node._opr = opr
          return node

      def compile(self, graph):
          """Recreate the h2d var in ``graph`` if graph, device, shape or dtype changed."""
          opr = self._opr
          stale = (
              opr is None
              or opr.graph != graph
              or opr.outputs[0].comp_node != self.device
              or opr.outputs[0].shape != self.shape
              or opr.outputs[0].dtype != self.dtype
          )
          if stale:
              h2d_var = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
              self._opr = h2d_var.owner
              if not self.outputs:
                  self.outputs.append(VarNode(owner_opr=self, name=self.name))
              self.outputs[0].var = h2d_var
          assert self.outputs[0].owner is self
  class ConstOpBase(OpNode):
      """Abstract base for operators whose single output holds constant data.

      Concrete subclasses define ``rt_fun``, called as
      ``rt_fun(graph, data, comp_node, dtype, name)`` to create the output var.
      """

      type = "ConstOpBase"

      def __init__(self, data=None, name=None, device=None, graph=None):
          assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
          super().__init__()
          self.name = name
          self.outputs = []
          self.graph = graph
          if data is not None:
              self.set_value(data, device)

      @property
      def device(self):
          return self._opr.outputs[0].comp_node if self._opr else None

      @device.setter
      def device(self, device):
          # Moving to another device re-creates the constant with the same data.
          self.set_value(self.numpy(), device)

      @property
      def shape(self):
          # None until an output exists, matching ``dtype`` and ``device``.
          return self.outputs[0].shape if self.outputs else None

      @property
      def dtype(self):
          return self._opr.outputs[0].dtype if self._opr else None

      def numpy(self):
          return self.outputs[0].numpy()

      def set_value(self, data, device=None):
          """Replace the node's constant value with ``data``.

          ``data`` may be a Python scalar, a sequence, a NumPy scalar or an
          ndarray.  float64 data is narrowed to float32 and int64 to int32
          before the var is created (NOTE(review): presumably because the
          runtime does not take 64-bit types -- confirm).  When ``device`` is
          falsy the node's current device is reused.
          """
          assert self.graph is not None
          cn = device if device else self.device
          assert isinstance(data, (int, float, Sequence, np.ndarray, np.generic))
          if not isinstance(data, np.ndarray):
              data = np.array(data)
          if data.dtype == np.float64:
              data = data.astype(np.float32)
          elif data.dtype == np.int64:
              data = data.astype(np.int32)
          varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
          if len(self.outputs) == 0:
              self.outputs.append(VarNode(owner_opr=self, name=self.name))
          self.outputs[0].var = varnode
          self._opr = varnode.owner

      @classmethod
      def load(cls, opr):
          """Build a node from an existing mgb constant operator ``opr``."""
          obj = cls()
          obj.outputs = []
          obj._opr = opr
          obj.name = opr.outputs[0].name
          obj.graph = opr.graph
          return obj

      def compile(self, graph):
          """Re-create the constant in ``graph`` if it lives in another graph."""
          assert self.outputs[0].var is self._opr.outputs[0]
          assert self.outputs[0].owner is self
          if self.graph != graph:
              self.graph = graph
              self.set_value(self.numpy())
          if self.name is not None:
              self.outputs[0].var.name = self.name
  class ImmutableTensor(ConstOpBase):
      """Constant node whose output var is created with ``rt.make_const``."""

      rt_fun = rt.make_const
      type = "ImmutableTensor"


  class SharedDeviceTensor(ConstOpBase):
      """Constant node whose output var is created with ``rt.make_shared``."""

      rt_fun = rt.make_shared
      type = "SharedDeviceTensor"
  class ReadOnlyOpNode(OpNode):
      """Fallback wrapper for mgb operators with no dedicated ``OpNode`` subclass.

      It has no ``opdef``.  ``compile`` therefore re-derives the outputs from
      the loaded operator by substituting changed inputs with ``replace_vars``.
      """

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.type = opr.type
          return node

      def compile(self):
          assert self._opr is not None
          assert len(self.inputs) == len(self._opr.inputs)
          assert len(self.outputs) == len(self._opr.outputs)
          # Map each original input var to its replacement, if it changed.
          replacements = {
              old_var: new_node.var
              for old_var, new_node in zip(self._opr.inputs, self.inputs)
              if new_node.var != old_var
          }
          if replacements:
              new_vars = replace_vars(self._opr.outputs, replacements)
              for idx, out_node in enumerate(self.outputs):
                  out_node.var = new_vars[idx]
  class Elemwise(OpNode):
      """``OpNode`` wrapper for ``builtin.Elemwise``; ``params["mode"]`` names the mode."""

      opdef = builtin.Elemwise
      type = "Elemwise"

      def __repr__(self):
          mode = self.params["mode"]
          return f"{self.name!s}{{Elemwise:{mode!s}}}"


  class ElemwiseMultiType(OpNode):
      """``OpNode`` wrapper for ``builtin.ElemwiseMultiType``."""

      opdef = builtin.ElemwiseMultiType
      type = "ElemwiseMultiType"

      def __repr__(self):
          mode = self.params["mode"]
          return f"{self.name!s}{{ElemwiseMultiType:{mode!s}}}"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          # Record the output dtype so compile() passes it to the rebuilt op.
          node.params["dtype"] = opr.outputs[0].dtype
          return node


  @register_flops(Elemwise, ElemwiseMultiType)
  def flops_elemwise(opnode: Elemwise, inputs, outputs):
      """FLOPs of an element-wise op: one operation per output element."""
      out_shape = outputs[0].shape
      return np.prod(out_shape)
  class Reduce(OpNode):
      """``OpNode`` wrapper for ``builtin.Reduce``."""

      opdef = builtin.Reduce
      type = "Reduce"


  class TypeCvt(OpNode):
      """``OpNode`` wrapper for ``builtin.TypeCvt``."""

      opdef = builtin.TypeCvt
      type = "TypeCvt"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          # The target dtype is taken from the loaded operator's output.
          node.params["dtype"] = opr.outputs[0].dtype
          return node
  class MatrixInverse(OpNode):
      """``OpNode`` wrapper for ``builtin.MatrixInverse``."""

      type = "MatrixInverse"
      opdef = builtin.MatrixInverse


  class MatrixMul(OpNode):
      """``OpNode`` wrapper for ``builtin.MatrixMul``."""

      type = "MatrixMul"
      opdef = builtin.MatrixMul


  @register_flops(MatrixMul)
  def flops_matmul(opnode: MatrixMul, inputs, outputs):
      """FLOPs of a 2-D matrix product: M * N * K multiply-accumulates.

      K is derived as ``size(A) / M`` (M = output rows) instead of reading a
      fixed axis of ``A``.  This stays correct when ``A`` is stored transposed.
      """
      assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
      out_elems = np.prod(outputs[0].shape)  # M * N
      rows = outputs[0].shape[0]  # M
      if rows == 0:
          return out_elems  # empty output: zero work
      mid_shape = np.prod(inputs[0].shape) // rows  # K
      return out_elems * mid_shape


  class BatchedMatrixMul(OpNode):
      """``OpNode`` wrapper for ``builtin.BatchedMatrixMul``."""

      type = "BatchedMatmul"
      opdef = builtin.BatchedMatrixMul


  @register_flops(BatchedMatrixMul)
  def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
      """FLOPs of a batched matrix product: B * M * N * K multiply-accumulates.

      K is derived as ``size(A) / (B * M)`` from the output's batch and row
      dimensions.  This stays correct when the matrices of ``A`` are transposed.
      """
      assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
      out_elems = np.prod(outputs[0].shape)  # B * M * N
      batch_rows = outputs[0].shape[0] * outputs[0].shape[1]  # B * M
      if batch_rows == 0:
          return out_elems  # empty output: zero work
      mid_shape = np.prod(inputs[0].shape) // batch_rows  # K
      return out_elems * mid_shape
  class Dot(OpNode):
      """``OpNode`` wrapper for ``builtin.Dot``."""

      opdef = builtin.Dot
      type = "Dot"


  class SVD(OpNode):
      """``OpNode`` wrapper for ``builtin.SVD``."""

      opdef = builtin.SVD
      type = "SVD"


  class ConvolutionForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Convolution``."""

      opdef = builtin.Convolution
      type = "Convolution"


  class ConvolutionBackwardData(OpNode):
      """``OpNode`` wrapper for ``builtin.ConvolutionBackwardData`` (type "ConvTranspose")."""

      opdef = builtin.ConvolutionBackwardData
      type = "ConvTranspose"


  class DeformableConvForward(OpNode):
      """``OpNode`` wrapper for ``builtin.DeformableConv``."""

      opdef = builtin.DeformableConv
      type = "DeformableConv"


  class GroupLocalForward(OpNode):
      """``OpNode`` wrapper for ``builtin.GroupLocal``."""

      opdef = builtin.GroupLocal
      type = "GroupLocal"


  class PoolingForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Pooling``."""

      opdef = builtin.Pooling
      type = "Pooling"


  class AdaptivePoolingForward(OpNode):
      """``OpNode`` wrapper for ``builtin.AdaptivePooling``."""

      opdef = builtin.AdaptivePooling
      type = "AdaptivePooling"


  class ROIPoolingForward(OpNode):
      """``OpNode`` wrapper for ``builtin.ROIPooling``."""

      opdef = builtin.ROIPooling
      type = "ROIPooling"


  class DeformablePSROIPoolingForward(OpNode):
      """``OpNode`` wrapper for ``builtin.DeformablePSROIPooling``."""

      opdef = builtin.DeformablePSROIPooling
      type = "DeformablePSROIPooling"
  class ConvBiasForward(OpNode):
      """``OpNode`` wrapper for ``builtin.ConvBias``.

      ``load`` records the output dtype in ``params`` so that ``compile``
      (which calls ``opdef(**params)``) passes it to the rebuilt op.
      """

      type = "ConvBias"
      opdef = builtin.ConvBias

      @classmethod
      def load(cls, opr):
          obj = super().load(opr)
          obj.params["dtype"] = opr.outputs[0].dtype
          return obj


  @register_flops(
      ConvolutionForward, ConvBiasForward,
  )
  def flops_conv(opnode: ConvolutionForward, inputs, outputs):
      """FLOPs of a (group) convolution, plus one add per output for ConvBias.

      Assumes the kernel's spatial dims are the last two weight axes and that
      a 5-D weight is a grouped filter whose per-group input channels are at
      axis 2 (4-D: axis 1) -- NOTE(review): layout assumption, confirm.
      """
      param_W_shape = inputs[1].shape
      kh = param_W_shape[-2]
      kw = param_W_shape[-1]
      if len(param_W_shape) == 5:
          num_input = param_W_shape[2]
      else:
          num_input = param_W_shape[1]
      NCHW = np.prod(outputs[0].shape)
      bias = 1 if isinstance(opnode, ConvBiasForward) else 0
      # N x Cout x H x W x (Cin x Kw x Kh)
      return NCHW * (num_input * kw * kh + bias)


  @register_receptive_field(ConvolutionForward, ConvBiasForward)
  def receptive_field(opnode: ConvolutionForward, inputs, outputs):
      """Propagate receptive field and cumulative stride through a convolution.

      Per spatial axis, with effective kernel ``k_eff = (k - 1) * dilate + 1``:
      ``rf = pre_rf + (k_eff - 1) * pre_stride`` and
      ``stride = conv_stride * pre_stride``.  Results are also stored on
      ``opnode._rf`` and ``opnode._stride``.
      """
      pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
      param_W_shape = inputs[1].shape
      kh = param_W_shape[-2]
      kw = param_W_shape[-1]
      # Dilation spreads kernel taps apart; absent params mean no dilation.
      kh_eff = (kh - 1) * opnode.params.get("dilate_h", 1) + 1
      kw_eff = (kw - 1) * opnode.params.get("dilate_w", 1) + 1
      rf = (
          kh_eff * pre_stride[0] + pre_rf[0] - pre_stride[0],
          kw_eff * pre_stride[1] + pre_rf[1] - pre_stride[1],
      )
      stride = (
          opnode.params["stride_h"] * pre_stride[0],
          opnode.params["stride_w"] * pre_stride[1],
      )
      opnode._rf = rf
      opnode._stride = stride
      return rf, stride
  class BatchConvBiasForward(OpNode):
      """``OpNode`` wrapper for ``builtin.BatchConvBias``."""

      opdef = builtin.BatchConvBias
      type = "BatchConvBias"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          # Keep the loaded output dtype for the rebuilt op.
          node.params["dtype"] = opr.outputs[0].dtype
          return node


  class BatchNormForward(OpNode):
      """``OpNode`` wrapper for ``builtin.BatchNorm``."""

      opdef = builtin.BatchNorm
      type = "BatchNorm"
      # NOTE(review): not read anywhere in this module; presumably selects the
      # output exposed by callers -- confirm.
      output_idx = -1


  class ROIAlignForward(OpNode):
      """``OpNode`` wrapper for ``builtin.ROIAlign``."""

      opdef = builtin.ROIAlign
      type = "ROIAlign"


  class WarpPerspectiveForward(OpNode):
      """``OpNode`` wrapper for ``builtin.WarpPerspective``."""

      opdef = builtin.WarpPerspective
      type = "WarpPerspective"


  class WarpAffineForward(OpNode):
      """``OpNode`` wrapper for ``builtin.WarpAffine``."""

      opdef = builtin.WarpAffine
      type = "WarpAffine"


  class RemapForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Remap``."""

      opdef = builtin.Remap
      type = "Remap"


  class ResizeForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Resize``."""

      opdef = builtin.Resize
      type = "Resize"
  class IndexingOneHot(OpNode):
      """``OpNode`` wrapper for ``builtin.IndexingOneHot``."""

      opdef = builtin.IndexingOneHot
      type = "IndexingOneHot"


  class IndexingSetOneHot(OpNode):
      """``OpNode`` wrapper for ``builtin.IndexingSetOneHot``."""

      opdef = builtin.IndexingSetOneHot
      type = "IndexingSetOneHot"


  class Copy(OpNode):
      """``OpNode`` wrapper for ``builtin.Copy``."""

      opdef = builtin.Copy
      type = "Copy"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          # The destination comp node is taken from the loaded output var.
          node.params["comp_node"] = opr.outputs[0].comp_node
          return node
  class ArgsortForward(OpNode):
      """``OpNode`` wrapper for ``builtin.Argsort``."""

      opdef = builtin.Argsort
      type = "Argsort"


  class Argmax(OpNode):
      """``OpNode`` wrapper for ``builtin.Argmax``."""

      opdef = builtin.Argmax
      type = "Argmax"


  class Argmin(OpNode):
      """``OpNode`` wrapper for ``builtin.Argmin``."""

      opdef = builtin.Argmin
      type = "Argmin"


  class CondTake(OpNode):
      """``OpNode`` wrapper for ``builtin.CondTake``."""

      opdef = builtin.CondTake
      type = "CondTake"


  class TopK(OpNode):
      """``OpNode`` wrapper for ``builtin.TopK``."""

      opdef = builtin.TopK
      type = "TopK"


  class NvOf(OpNode):
      """``OpNode`` wrapper for ``builtin.NvOf``."""

      opdef = builtin.NvOf
      type = "NvOf"
  class RNGOpr(OpNode):
      """Wrapper for mgb ``RNGOpr<MegDNNOpr>`` operators.

      The type string does not say which RNG it is, so ``load`` infers the
      kind from the number of params: exactly three means Gaussian, anything
      else Uniform.  NOTE(review): heuristic -- confirm against the mgb param
      layouts.
      """

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          is_gaussian = len(node.params) == 3
          node.opdef = builtin.GaussianRNG if is_gaussian else builtin.UniformRNG
          node.type = "GaussianRNG" if is_gaussian else "UniformRNG"
          return node
  class Linspace(OpNode):
      """``OpNode`` wrapper for ``builtin.Linspace``."""

      opdef = builtin.Linspace
      type = "Linspace"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          node.params["comp_node"] = opr.outputs[0].comp_node
          return node


  class Eye(OpNode):
      """``OpNode`` wrapper for ``builtin.Eye``."""

      opdef = builtin.Eye
      type = "Eye"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          out_var = opr.outputs[0]
          # dtype and comp node come from the loaded output var.
          node.params["dtype"] = out_var.dtype
          node.params["comp_node"] = out_var.comp_node
          return node
  class GetVarShape(OpNode):
      """``OpNode`` wrapper for ``builtin.GetVarShape``."""

      opdef = builtin.GetVarShape
      type = "GetVarShape"


  class Concat(OpNode):
      """``OpNode`` wrapper for ``builtin.Concat``."""

      opdef = builtin.Concat
      type = "Concat"

      @classmethod
      def load(cls, opr):
          node = super().load(opr)
          # Always use the generic "xpux" comp node, not the loaded op's device.
          node.params["comp_node"] = Device("xpux").to_c()
          return node


  class Broadcast(OpNode):
      """``OpNode`` wrapper for ``builtin.Broadcast``."""

      opdef = builtin.Broadcast
      type = "Broadcast"


  class Identity(OpNode):
      """``OpNode`` wrapper for ``builtin.Identity``."""

      opdef = builtin.Identity
      type = "Identity"


  class NMSKeep(OpNode):
      """``OpNode`` wrapper for ``builtin.NMSKeep``."""

      opdef = builtin.NMSKeep
      type = "NMSKeep"


  # No OpNode wrapper in this module yet for: ParamPackSplit, ParamPackConcat.
  class Dimshuffle(OpNode):
      """``OpNode`` wrapper for ``builtin.Dimshuffle``."""

      type = "Dimshuffle"
      opdef = builtin.Dimshuffle

      @classmethod
      def load(cls, opr):
          obj = super().load(opr)
          # Drop the serialized "ndim" entry (presumably not accepted by
          # builtin.Dimshuffle); tolerate its absence instead of raising KeyError.
          obj.params.pop("ndim", None)
          return obj


  class Reshape(OpNode):
      """``OpNode`` wrapper for ``builtin.Reshape``."""

      type = "Reshape"
      opdef = builtin.Reshape
  class AxisAddRemove(OpNode):
      """Wrapper for mgb ``AxisAddRemove``, rebuilt as ``AddAxis`` or ``RemoveAxis``.

      Every descriptor in the serialized ``desc`` list must use the same
      ``method``; method ``0`` maps to ``builtin.AddAxis``, any other value to
      ``builtin.RemoveAxis``.  The collected ``axisnum`` values become
      ``params["axis"]``.
      """

      type = "AxisAddRemove"

      @classmethod
      def load(cls, opr):
          node = cls()
          node.name = opr.name
          node._opr = opr
          descriptors = json.loads(opr.params)["desc"]
          shared_method = None
          axes = []
          for entry in descriptors:
              if shared_method is None:
                  shared_method = entry["method"]
              assert shared_method == entry["method"]
              axes.append(entry["axisnum"])
          node.params = {"axis": axes}
          is_add = descriptors[0]["method"] == 0
          node.opdef = builtin.AddAxis if is_add else builtin.RemoveAxis
          return node
  class IndexingBase(OpNode):
      """Shared ``load`` for the subtensor / indexing operator family.

      The serialized params are a list of per-axis descriptors.  Each becomes
      ``[axis, has_begin, has_end, has_step, has_idx]`` in ``params["items"]``.
      """

      @classmethod
      def load(cls, opr):
          node = cls()
          node.name = opr.name
          node._opr = opr
          descriptors = json.loads(opr.params)
          flag_keys = ("begin", "end", "step", "idx")
          node.params["items"] = [
              [desc["axis"], *(bool(desc[key]) for key in flag_keys)]
              for desc in descriptors
          ]
          return node
  class Subtensor(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.Subtensor``."""

      opdef = builtin.Subtensor
      type = "Subtensor"


  class SetSubtensor(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.SetSubtensor``."""

      opdef = builtin.SetSubtensor
      type = "SetSubtensor"


  class IncrSubtensor(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IncrSubtensor``."""

      opdef = builtin.IncrSubtensor
      type = "IncrSubtensor"


  class IndexingMultiAxisVec(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IndexingMultiAxisVec``."""

      opdef = builtin.IndexingMultiAxisVec
      type = "IndexingMultiAxisVec"


  class IndexingSetMultiAxisVec(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IndexingSetMultiAxisVec``."""

      opdef = builtin.IndexingSetMultiAxisVec
      type = "IndexingSetMultiAxisVec"


  class IndexingIncrMultiAxisVec(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IndexingIncrMultiAxisVec``."""

      opdef = builtin.IndexingIncrMultiAxisVec
      type = "IndexingIncrMultiAxisVec"


  class MeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.MeshIndexing``."""

      opdef = builtin.MeshIndexing
      type = "MeshIndexing"


  class SetMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.SetMeshIndexing``."""

      opdef = builtin.SetMeshIndexing
      type = "SetMeshIndexing"


  class IncrMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.IncrMeshIndexing``."""

      opdef = builtin.IncrMeshIndexing
      type = "IncrMeshIndexing"


  class BatchedMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.BatchedMeshIndexing``."""

      opdef = builtin.BatchedMeshIndexing
      type = "BatchedMeshIndexing"


  class BatchedSetMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.BatchedSetMeshIndexing``."""

      opdef = builtin.BatchedSetMeshIndexing
      type = "BatchedSetMeshIndexing"


  class BatchedIncrMeshIndexing(IndexingBase):
      """``IndexingBase`` wrapper for ``builtin.BatchedIncrMeshIndexing``."""

      opdef = builtin.BatchedIncrMeshIndexing
      type = "BatchedIncrMeshIndexing"
  # No OpNode wrapper in this module yet for: CollectiveComm, RemoteSend,
  # RemoteRecv, TQT, FakeQuant, InplaceAdd.


  class AssertEqual(OpNode):
      """``OpNode`` wrapper for ``builtin.AssertEqual``."""

      opdef = builtin.AssertEqual
      type = "AssertEqual"


  class CvtColorForward(OpNode):
      """``OpNode`` wrapper for ``builtin.CvtColor``."""

      opdef = builtin.CvtColor
      type = "CvtColor"