GitOrigin-RevId: 6a9d5beba2
tags/v1.3.0
| @@ -11,6 +11,7 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from .._imperative_rt import VarNode | |||
| from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| @@ -59,7 +60,7 @@ def astype(x, dtype): | |||
| def convert_single_value(v, *, dtype=None, device=None): | |||
| if isinstance(v, Tensor): | |||
| if isinstance(v, (Tensor, VarNode)): | |||
| if not is_quantize(v.dtype): | |||
| v = astype(v, dtype) | |||
| else: | |||
| @@ -12,11 +12,12 @@ import functools | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._imperative_rt.graph import VarNode | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Elemwise | |||
| from ..core.tensor import utils | |||
| from ..core.tensor.array_method import _elwise_apply | |||
| from ..core.tensor.utils import isscalar, setscalar | |||
| from ..core.tensor.utils import astype, isscalar, setscalar | |||
| from ..device import get_default_device | |||
| from ..jit.tracing import is_tracing | |||
| from ..tensor import Tensor | |||
| @@ -77,7 +78,7 @@ __all__ = [ | |||
| def _elwise(*args, mode): | |||
| tensor_args = list(filter(lambda x: isinstance(x, Tensor), args)) | |||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) | |||
| if len(tensor_args) == 0: | |||
| dtype = utils.dtype_promotion(args) | |||
| first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | |||
| @@ -109,7 +110,7 @@ def _elwise(*args, mode): | |||
| Elemwise.Mode.ROUND, | |||
| ) and np.issubdtype(args[0].dtype, np.integer): | |||
| return args[0] | |||
| args = tuple(map(lambda x: x.astype("float32"), args)) | |||
| args = tuple(map(lambda x: astype(x, "float32"), args)) | |||
| return _elwise_apply(args, mode) | |||
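Note: with VarNode now accepted alongside Tensor in `_elwise`, the elementwise functionals can be driven directly by varnodes taken from a loaded graph. A minimal usage sketch, assuming a dumped model whose inputs are named "a" and "b" (the file name is hypothetical; this mirrors the tests added later in this change):

    import megengine.functional as F
    from megengine.utils.network import Network as Net

    net = Net.load("model.mge")  # hypothetical model file
    vara = net.var_filter.name("a").as_unique()
    varb = net.var_filter.name("b").as_unique()
    out = F.relu(F.mul(vara.var, varb.var))  # _elwise now dispatches on VarNode inputs
    new_vars = net.add_dep_oprs(out)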
| @@ -65,7 +65,6 @@ def get_owner_opr_inputs(var: VarNode) -> List[VarNode]: | |||
| """ | |||
| Gets the inputs of owner opr of a variable. | |||
| """ | |||
| assert isinstance(var, VarNode) | |||
| return var.owner.inputs | |||
| @@ -74,7 +73,6 @@ def get_owner_opr_type(var: VarNode) -> str: | |||
| Gets the type of owner opr of a variable. | |||
| """ | |||
| assert isinstance(var, VarNode) | |||
| return var.owner.type | |||
| @@ -109,7 +107,7 @@ def graph_traversal(outputs: VarNode): | |||
| var2oprs = collections.defaultdict(list) | |||
| opr2receivers = collections.defaultdict(list) | |||
| queue = list(map(lambda x: x.owner, outputs)) | |||
| queue = list(set(map(lambda x: x.owner, outputs))) | |||
| visited = set(map(lambda x: x.id, queue)) | |||
| # iterate through whole comp_graph, fill in meta information | |||
| @@ -143,12 +141,15 @@ def graph_traversal(outputs: VarNode): | |||
| return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree | |||
| def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]: | |||
| def get_oprs_seq( | |||
| outputs: List[VarNode], prune_reshape=False, prune_immtensor=True | |||
| ) -> List[OperatorNode]: | |||
| """ | |||
| Gets oprs in some topological order for a dumped model. | |||
| :param outputs: model outputs. | |||
| :param prune_reshape: whether to prune the useless operators during inference. | |||
| :param prune_reshape: whether to prune the useless operators used by Reshape opr during inference. | |||
| :param prune_immtensor: whether to prune the ImmutableTensor opr. | |||
| :return: opr list with some correct execution order. | |||
| """ | |||
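Note: a brief sketch of the new flag, assuming ``outputs`` is a list of VarNode from a loaded graph; keeping ImmutableTensor oprs is what allows the Network utility (added below) to reconstruct parameters:

    from megengine.utils.comp_graph_tools import get_oprs_seq

    oprs = get_oprs_seq(outputs)  # default: ImmutableTensor oprs are pruned
    all_oprs = get_oprs_seq(outputs, prune_reshape=False, prune_immtensor=False)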
| @@ -160,9 +161,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo | |||
| opr_id = indegree2opr[0].pop() | |||
| opr = map_oprs[opr_id] | |||
| nr_remain -= 1 | |||
| # skip const value generation operator | |||
| if get_opr_type(opr) != "ImmutableTensor": | |||
| if opr.type != "ImmutableTensor" or not prune_immtensor: | |||
| oprs_seq.append(opr) | |||
| for post_id in opr2receivers[opr_id]: | |||
| @@ -0,0 +1,682 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import collections.abc | |||
| import fnmatch | |||
| import itertools | |||
| import re | |||
| from collections import OrderedDict | |||
| from typing import Dict, List | |||
| import numpy as np | |||
| from ..core._imperative_rt import ComputingGraph | |||
| from ..core.tensor import megbrain_graph as G | |||
| from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | |||
| from .network_node import ( | |||
| NetworkNode, | |||
| Host2DeviceCopy, | |||
| ImmutableTensor, | |||
| OpNode, | |||
| VarNode, | |||
| str_to_mge_class, | |||
| ) | |||
| class Network: | |||
| def __init__(self): | |||
| self.input_vars = [] # input var of graph | |||
| self._orig_inputs = [] | |||
| self.output_vars = [] # output var of graph | |||
| self._orig_outputs = [] | |||
| self.all_oprs_map = OrderedDict() | |||
| self.all_vars_map = OrderedDict() | |||
| self.graph = ComputingGraph() | |||
| @classmethod | |||
| def load(cls, model_path: str, outspec: List[str] = None): | |||
| """ | |||
| Loads a computing graph as a Network object. | |||
| :param model_path: file path of mge model. | |||
| :param outspec: only load the subgraph with outspec as its endpoints. | |||
| """ | |||
| self = cls() | |||
| _, _, outputs = G.load_graph(model_path) | |||
| if outspec is not None: | |||
| output_spec = outspec.copy() | |||
| all_vars = get_dep_vars(outputs) + outputs | |||
| new_outputs = {} | |||
| for i in all_vars: | |||
| if i.name in output_spec: | |||
| new_outputs[i.name] = i | |||
| output_spec.remove(i.name) | |||
| assert len(output_spec) == 0, "Can not find {} in this model".format( | |||
| output_spec | |||
| ) | |||
| outputs = [new_outputs[i] for i in outspec] | |||
| self._orig_outputs = outputs | |||
| self.add_dep_oprs(*outputs) | |||
| for x in self._orig_inputs: | |||
| self.input_vars.append(self._get_var(x)) | |||
| for x in self._orig_outputs: | |||
| self.output_vars.append(self._get_var(x)) | |||
| self.graph = self._orig_outputs[0].graph | |||
| return self | |||
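Note: a usage sketch for ``load``; the file name is hypothetical, and ``outspec`` restricts loading to the subgraph ending at the named vars:

    from megengine.utils.network import Network as Net

    net = Net.load("model.mge")
    sub_net = Net.load("model.mge", outspec=["o"])  # only the subgraph ending at var "o"
    print([v.name for v in net.output_vars])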
| def _compile(self): | |||
| self.all_oprs_map = {} | |||
| self.all_vars_map = {} | |||
| for opr in self.all_oprs: | |||
| if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)): | |||
| opr.compile(self.graph) | |||
| else: | |||
| opr.compile() | |||
| if opr.name is not None: | |||
| opr._opr.name = opr.name | |||
| self.all_oprs_map[opr._opr.id] = opr | |||
| for o in opr.outputs: | |||
| self.all_vars_map[o.var.id] = o | |||
| def dump( | |||
| self, | |||
| file, | |||
| *, | |||
| keep_var_name: int = 1, | |||
| keep_opr_name: bool = False, | |||
| keep_param_name: bool = False, | |||
| keep_opr_priority: bool = False, | |||
| strip_info_file=None, | |||
| append_json=False, | |||
| optimize_for_inference=True, | |||
| append=False, | |||
| **kwargs | |||
| ): | |||
| """ | |||
| Serializes graph to file. | |||
| :param file: output file, could be file object or filename. | |||
| :param append: whether output is appended to ``file``. | |||
| Only works when ``file`` is str. | |||
| :param keep_var_name: level for keeping variable names: | |||
| * 0: none of the names are kept | |||
| * 1: (default) keep names of output vars | |||
| * 2: keep names of all (output and internal) vars | |||
| :param keep_opr_name: whether to keep operator names. | |||
| :param keep_param_name: whether to keep param names, so param values can be | |||
| easily manipulated after loading model | |||
| :param keep_opr_priority: whether to keep priority setting for operators | |||
| :param strip_info_file: a path string or a file handler. If not None, | |||
| the dump information for code strip will be written to ``strip_info_file`` | |||
| :param append_json: only takes effect when ``strip_info_file`` is not None. If set | |||
| True, the information for code strip will be appended to ``strip_info_file``; | |||
| if set False, ``strip_info_file`` will be overwritten | |||
| :param optimize_for_inference: enable optimizations; | |||
| all optimize options will be skipped if this is False. Default: True | |||
| :Keyword Arguments: | |||
| * enable_io16xc32 -- | |||
| whether to use float16 for I/O between oprs and use | |||
| float32 as internal computation precision. Note the output var would be | |||
| changed to float16. | |||
| * enable_ioc16 -- | |||
| whether to use float16 for both I/O and computation | |||
| precision. | |||
| * enable_hwcd4 -- | |||
| whether to use NHWCD4 data layout. This is faster on some | |||
| OpenCL backend. | |||
| * enable_nchw88 -- | |||
| whether to use NCHW88 data layout, currently | |||
| used in X86 AVX backend. | |||
| * enable_nchw44 -- | |||
| whether to use NCHW44 data layout, currently | |||
| used in arm backend. | |||
| * enable_nchw44_dot -- | |||
| whether to use NCHW44_dot data layout, currently | |||
| used in armv8.2+dotprod backend. | |||
| * enable_nchw4 -- | |||
| whether to use NCHW4 data layout, currently | |||
| used in nvidia backend (based on cudnn). | |||
| * enable_nchw32 -- | |||
| whether to use NCHW32 data layout, currently | |||
| used in nvidia backend with tensorcore (based on cudnn). | |||
| * enable_chwn4 -- | |||
| whether to use CHWN4 data layout, currently | |||
| used in nvidia backend with tensorcore. | |||
| * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity | |||
| into one opr. | |||
| * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||
| input for inference on nvidia backend (this optimization pass will | |||
| result in a mismatch between the output precision of training and | |||
| inference) | |||
| """ | |||
| self._compile() | |||
| out = [G.VarNode(var.var) for var in self.output_vars] | |||
| if optimize_for_inference: | |||
| out = G.optimize_for_inference(out, **kwargs) | |||
| dump_content, _ = G.dump_graph( | |||
| out, | |||
| keep_var_name=keep_var_name, | |||
| keep_opr_name=keep_opr_name, | |||
| keep_param_name=keep_param_name, | |||
| keep_opr_priority=keep_opr_priority, | |||
| strip_info_file=strip_info_file, | |||
| append_json=append_json, | |||
| ) | |||
| if isinstance(file, str): | |||
| permission = "ab" if append else "wb" | |||
| file = open(file, permission) | |||
| file.write(dump_content) | |||
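Note: a rough sketch of re-dumping a loaded network with an inference optimization enabled (mirroring test_optimize_for_inference below); ``net`` is assumed to be a loaded Network:

    import io

    buf = io.BytesIO()
    net.dump(buf, optimize_for_inference=True, enable_io16xc32=True)
    buf.seek(0)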
| def make_const(self, data, name=None, device=None): | |||
| """Makes an ImmutableTensor OpNode to provide a parameter for the network. | |||
| """ | |||
| node = ImmutableTensor(data, name, device, self.graph) | |||
| node.compile(self.graph) | |||
| return node.outputs[0] | |||
| def make_input_node(self, shape, dtype, name=None, device=None): | |||
| """Makes a Host2DeviceCopy OpNode to provide an input varnode for the network. | |||
| """ | |||
| node = Host2DeviceCopy(shape, dtype, name, device) | |||
| node.compile(self.graph) | |||
| return node.outputs[0] | |||
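Note: a short sketch combining both constructors, as exercised by test_make_const and test_add_input below; ``net`` is assumed to be a loaded Network:

    import numpy as np

    const_b = net.make_const(np.array([0.0, 0.0]), name="b")
    inp_c = net.make_input_node((2,), np.int32, name="c")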
| def add_output(self, *vars: VarNode): | |||
| """Adds vars into the network output node list | |||
| """ | |||
| for var in vars: | |||
| if var not in self.output_vars: | |||
| self.output_vars.append(var) | |||
| def remove_output(self, *vars: VarNode): | |||
| """Removes vars from the network output node list. | |||
| """ | |||
| for var in vars: | |||
| if var in self.output_vars: | |||
| self.output_vars.remove(var) | |||
| def add_dep_oprs(self, *vars): | |||
| """Adds dependent opnodes and varnodes of vars into network | |||
| """ | |||
| oprs = get_oprs_seq(vars, False, False) | |||
| for mge_opr in oprs: | |||
| if get_opr_type(mge_opr) == "Host2DeviceCopy": | |||
| self._orig_inputs.extend(mge_opr.outputs) | |||
| opr = self._add_opr(mge_opr) | |||
| if opr is not None: | |||
| for x in mge_opr.inputs: | |||
| opr.add_inp_var(self._get_var(x)) | |||
| # set out var | |||
| for x in mge_opr.outputs: | |||
| opr.add_out_var(self._get_var(x)) | |||
| return [self.all_vars_map[var.id] for var in vars] | |||
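Note: a sketch of extending a loaded graph with a new output via add_dep_oprs and add_output (mirroring test_add_output below); ``net`` and the var names are assumptions:

    import megengine.functional as F

    var_a = net.var_filter.name("a").as_unique()
    var_b = net.var_filter.name("b").as_unique()
    y = F.sigmoid(F.add(var_a.var, var_b.var))
    new_out = net.add_dep_oprs(y)[0]
    new_out.name = "o1"
    net.add_output(new_out)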
| def modify_opr_names(self, modifier): | |||
| """Modifies names of operators **inplace**; useful for merging a loaded | |||
| network into another network | |||
| :param modifier: a string to be prepended to the name, or a function | |||
| that maps from name to name | |||
| :type modifier: str or callable | |||
| """ | |||
| if isinstance(modifier, str): | |||
| om = modifier | |||
| modifier = lambda v: "{}.{}".format(om, v) | |||
| assert callable(modifier) | |||
| for i in self.all_oprs: | |||
| v0 = i.name | |||
| v1 = modifier(v0) | |||
| assert isinstance(v1, str) | |||
| i.name = v1 | |||
| def reset_batch_size(self, batchsize, *, blacklist=()): | |||
| """Helper for resetting the batch size; the first dimension of every data provider | |||
| not in ``blacklist`` is assumed to be the batch size | |||
| :param blacklist: data provider names whose first dimension is not the | |||
| batch size | |||
| """ | |||
| blacklist = set(blacklist) | |||
| prev_batchsize = None | |||
| for i in self.data_providers_filter: | |||
| if i.name in blacklist: | |||
| blacklist.remove(i.name) | |||
| else: | |||
| shp = list(i.shape) | |||
| if prev_batchsize is None: | |||
| prev_batchsize = shp[0] | |||
| else: | |||
| assert prev_batchsize == shp[0], ( | |||
| "batchsize mismatch: batchsize={} " | |||
| "shape={} dp={}".format(prev_batchsize, shp, i.name) | |||
| ) | |||
| shp[0] = batchsize | |||
| i.shape = tuple(shp) | |||
| assert prev_batchsize is not None, "no data provider found" | |||
| assert not blacklist, "unused items in blacklist: {}".format(blacklist) | |||
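Note: a sketch of resetting the batch size; the blacklisted name must refer to an existing data provider and is only illustrative here:

    net.reset_batch_size(1)
    net.reset_batch_size(8, blacklist=("anchor_table",))  # "anchor_table": hypothetical provider to exempt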
| def replace_vars(self, repl_dict: Dict[VarNode, VarNode]): | |||
| """ | |||
| Replaces vars in the graph. | |||
| :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | |||
| """ | |||
| for var in self.all_vars: | |||
| if var in repl_dict: | |||
| repl_var = repl_dict[var] | |||
| owner = repl_var.owner | |||
| idx = owner.outputs.index(repl_var) | |||
| owner.outputs[idx] = var | |||
| var.__dict__.update(repl_var.__dict__) | |||
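Note: a sketch of swapping an input var for a constant via replace_vars (mirroring test_make_const below); ``net`` and the var name are assumptions:

    import numpy as np

    varb = net.var_filter.name("b").as_unique()
    const_b = net.make_const(np.array([0.0, 0.0]), name="b")
    net.replace_vars({varb: const_b})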
| def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | |||
| """ | |||
| Replaces operators in the graph. | |||
| :param repl_dict: the map {old_opr: new_opr} that specifies how to replace the operators. | |||
| """ | |||
| for opr in self.all_oprs: | |||
| if opr in repl_dict: | |||
| assert len(opr.outputs) == len( | |||
| repl_dict[opr].outputs | |||
| ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | |||
| repl_dict[opr].outputs = opr.outputs | |||
| for ind, var in enumerate(opr.outputs): | |||
| var.owner = repl_dict[opr] | |||
| var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||
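Note: a sketch of replacing an operator via replace_oprs (mirroring test_replace_opr below); ``net`` and the var names are assumptions, and as_oprnode is defined later in this module:

    import megengine.functional as F
    from megengine.utils.network import as_oprnode

    vara = net.var_filter.name("a").as_unique()
    varb = net.var_filter.name("b").as_unique()
    new_vars = net.add_dep_oprs(F.relu(F.sub(vara.var, varb.var)))
    old_opr = net.opr_filter.has_input(vara).as_unique()
    net.replace_oprs({old_opr: as_oprnode(new_vars)})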
| def get_opr_by_type(self, oprcls, unique=True): | |||
| assert issubclass(oprcls, OpNode) | |||
| rst = self.opr_filter.type(oprcls).as_list() | |||
| if unique: | |||
| assert len(rst) == 1, "{} operators of type {} found".format( | |||
| len(rst), oprcls | |||
| ) | |||
| (rst,) = rst | |||
| return rst | |||
| def get_opr_by_name(self, name, unique=True): | |||
| rst = self.opr_filter.name(name).as_list() | |||
| if unique: | |||
| assert len(rst) == 1, "{} operators with name {} found".format(len(rst), name) | |||
| (rst,) = rst | |||
| return rst | |||
| def get_var_by_name(self, name, unique=True): | |||
| rst = self.var_filter.name(name).as_list() | |||
| if unique: | |||
| assert len(rst) == 1, "{} variables with name {} found".format(len(rst), name) | |||
| (rst,) = rst | |||
| return rst | |||
| def get_var_receive_oprs(self, var): | |||
| """ Gets all oprs which use var as input | |||
| """ | |||
| return self.opr_filter.has_input(var).as_list() | |||
| def get_dep_oprs(self, var): | |||
| """Gets dependent oprs of var | |||
| """ | |||
| return get_oprs_seq(var, False, False) | |||
| @property | |||
| def opr_filter(self): | |||
| """Filter on all opnodes of the Network. | |||
| """ | |||
| oprs = self.all_oprs | |||
| return NodeFilter(itertools.islice(oprs, len(oprs))) | |||
| @property | |||
| def var_filter(self): | |||
| """Filter on all varnodes of the Network. | |||
| """ | |||
| vars = self.all_vars | |||
| return NodeFilter(itertools.islice(vars, len(vars))) | |||
| @property | |||
| def params_filter(self): # all immutable tensor | |||
| """Filter on all parameters (ImmutableTensor Opr) of the Network | |||
| """ | |||
| return self.opr_filter.param_provider() | |||
| @property | |||
| def data_providers_filter(self): # all host2devicecopy | |||
| """Filter on all input nodes (Host2DeviceCopy Opr) of the Network | |||
| """ | |||
| return self.opr_filter.data_provider() | |||
| @property | |||
| def dest_vars(self): | |||
| """Output varnodes of the Network. | |||
| """ | |||
| return self.output_vars | |||
| @property | |||
| def all_oprs(self): | |||
| return get_oprs_seq(self.output_vars, False, False) | |||
| @property | |||
| def all_vars(self): | |||
| return get_dep_vars(self.output_vars) | |||
| @property | |||
| def all_vars_dict(self): | |||
| return self.var_filter.as_dict() | |||
| @property | |||
| def all_oprs_dict(self): | |||
| return self.opr_filter.as_dict() | |||
| # used for loading and building graph | |||
| def _add_opr(self, x): | |||
| # TODO: use megbrain C++ RTTI to replace type string | |||
| if x.id not in self.all_oprs_map: | |||
| self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) | |||
| return self.all_oprs_map[x.id] | |||
| else: | |||
| return None | |||
| def _get_opr(self, x): | |||
| if x.id in self.all_oprs_map: | |||
| return self.all_oprs_map[x.id] | |||
| else: | |||
| return None | |||
| def _get_var(self, x): | |||
| # auto convert to VarNode of Network | |||
| if x.id not in self.all_vars_map: | |||
| self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | |||
| return self.all_vars_map[x.id] | |||
| def as_varnode(obj): | |||
| """convert a :class:`.VarNode` compatible object to :class:`.VarNode`. | |||
| :param obj: it must be one of the following: | |||
| 1. a :class:`.VarNode` object | |||
| 2. a :class:`.OpNode` object that has unique output | |||
| 3. an iterable that produces either type 1 or 2, with length 1 | |||
| :rtype: :class:`.VarNode` | |||
| """ | |||
| if type(obj) is VarNode: | |||
| return obj | |||
| if isinstance(obj, OpNode): | |||
| assert len(obj.outputs) == 1, ( | |||
| "operator {} must have one output to be converted to VarNode; " | |||
| "got {} actually".format(obj, len(obj.outputs)) | |||
| ) | |||
| ret = obj.outputs[0] | |||
| assert type(ret) is VarNode | |||
| return ret | |||
| assert isinstance( | |||
| obj, collections.abc.Iterable | |||
| ), "{} is not compatible with VarNode".format(obj) | |||
| val = list(obj) | |||
| assert ( | |||
| len(val) == 1 | |||
| ), "can not convert sequence of length {} to VarNode ({})".format( | |||
| len(val), (lambda s: s if len(s) < 50 else s[:50] + " ...")(str(val)) | |||
| ) | |||
| return as_varnode(val[0]) | |||
| def as_oprnode(obj): | |||
| """convert a :class:`.OpNode` compatible object to | |||
| :class:`.OpNode`; it works like :func:`as_varnode`.""" | |||
| if type(obj) is VarNode: | |||
| return obj.owner | |||
| if isinstance(obj, OpNode): | |||
| return obj | |||
| assert isinstance( | |||
| obj, collections.abc.Iterable | |||
| ), "{} is not compatible with OpNode".format(obj) | |||
| val = list(obj) | |||
| assert ( | |||
| len(val) == 1 | |||
| ), "can not convert sequence of length {} to " "OpNode({})".format(len(val), val) | |||
| return as_oprnode(val[0]) | |||
| class NodeFilter: | |||
| """Filter on node iterator. This class is an iterator of | |||
| :class:`.NetworkNode` objects and multiple filtering conditions and | |||
| mappers can be chained. | |||
| Example:: | |||
| # find all :class:`.ImmutableTensor` nodes | |||
| for i in NodeFilter(node_iter).param_provider(): | |||
| print(i) | |||
| # find all :class:`.ImmutableTensor` nodes that end with ':W' | |||
| for i in NodeFilter(node_iter).param_provider().name('*:W'): | |||
| print(i) | |||
| # number of inputs | |||
| nr_input = NodeFilter(node_iter).data_provider().as_count() | |||
| """ | |||
| _iter = None | |||
| def __init__(self, node_iter): | |||
| """ | |||
| :param node_iter: iterator to :class:`.NetworkNode`, or a | |||
| :class:`.VarNode`-compatible object; in the latter case, its | |||
| dependent oprs would be used | |||
| """ | |||
| if isinstance(node_iter, VarNode): | |||
| oprs = get_oprs_seq(node_iter, False, False) | |||
| node_iter = itertools.islice(oprs, len(oprs) - 1) | |||
| if isinstance(node_iter, OpNode): | |||
| oprs = get_oprs_seq(node_iter.inputs, False, False) | |||
| node_iter = itertools.islice(oprs, len(oprs) - 1) | |||
| assert isinstance(node_iter, collections.abc.Iterable) | |||
| if (not isinstance(node_iter, NodeFilter)) and type( | |||
| self | |||
| ) is not NodeFilterCheckType: | |||
| node_iter = NodeFilterCheckType(node_iter, NetworkNode) | |||
| self._iter = node_iter | |||
| @classmethod | |||
| def make_all_deps(cls, *dest_vars): | |||
| """make a :class:`NodeFilter` that contains all deps of given vars""" | |||
| return cls(list(get_oprs_seq(dest_vars, False, False))) | |||
| def __iter__(self): | |||
| """to be overwritten by subclass to implement filters""" | |||
| return iter(self._iter) | |||
| def type(self, node_type): | |||
| """filter by specific node type | |||
| :param node_type: node type class | |||
| :return: a new :class:`NodeFilter` object | |||
| """ | |||
| return NodeFilterType(self, node_type) | |||
| def check_type(self, node_type): | |||
| """assert that all oprs produced by this iterator are instances of | |||
| certain type | |||
| :param node_type: node type class | |||
| :return: a new :class:`NodeFilter` object | |||
| :raises TypeError: if type check failed | |||
| """ | |||
| return NodeFilterCheckType(self, node_type) | |||
| def not_type(self, node_type): | |||
| """remove oprs of specific type | |||
| :param node_type: node type class | |||
| :return: a new :class:`NodeFilter` object | |||
| """ | |||
| return NodeFilterNotType(self, node_type) | |||
| def param_provider(self): | |||
| """get :class:`.ImmutableTensor` oprs; shorthand for | |||
| ``.type(ImmutableTensor)``""" | |||
| return self.type(ImmutableTensor) | |||
| def data_provider(self): | |||
| """get :class:`.Host2DeviceCopy` oprs; shorthand for | |||
| ``.type(Host2DeviceCopy)``""" | |||
| return self.type(Host2DeviceCopy) | |||
| def name(self, pattern, ignorecase=True): | |||
| """filter by node name | |||
| :param pattern: a string in glob syntax that can contain ``?`` and | |||
| ``*`` to match a single or arbitrary characters. | |||
| :type pattern: :class:`str` | |||
| :param ignorecase: whether to ignore case | |||
| :type ignorecase: bool | |||
| :return: a new :class:`NodeFilter` object | |||
| """ | |||
| return NodeFilterName(self, pattern, ignorecase) | |||
| def has_input(self, var): | |||
| """an opr is kept if it has the given var as one of its inputs | |||
| :param var: var node to be checked | |||
| :return: a new :class:`NodeFilter` object | |||
| """ | |||
| return NodeFilterHasInput(self, var) | |||
| def as_list(self): | |||
| """consume this iterator and return its content as a list | |||
| :rtype: [:class:`.GraphNodeBase`] | |||
| """ | |||
| return list(self) | |||
| def as_unique(self): | |||
| """assert that this iterator yields only one node and return it | |||
| :return: the unique node | |||
| :rtype: :class:`.GraphNodeBase` | |||
| :raises ValueError: if this iterator does not yield a unique node | |||
| """ | |||
| (opr,) = self | |||
| return opr | |||
| def as_dict(self): | |||
| """construct an ordered dict to map from node names to objects in | |||
| this iterator | |||
| :rtype: :class:`OrderedDict` | |||
| """ | |||
| return collections.OrderedDict((i.name, i) for i in self) | |||
| def as_count(self): | |||
| """consume this iterator and get the number of elements | |||
| :rtype: int | |||
| """ | |||
| return sum(1 for _ in self) | |||
| class NodeFilterType(NodeFilter): | |||
| """see :meth:`NodeFilter.type`""" | |||
| _node_type = None | |||
| def __init__(self, node_iter, node_type): | |||
| assert issubclass(node_type, NetworkNode), "bad opr type: {}".format( | |||
| node_type | |||
| ) | |||
| super().__init__(node_iter) | |||
| self._node_type = node_type | |||
| def __iter__(self): | |||
| for i in self._iter: | |||
| if isinstance(i, self._node_type): | |||
| yield i | |||
| class NodeFilterNotType(NodeFilterType): | |||
| """see :meth:`NodeFilter.not_type`""" | |||
| def __iter__(self): | |||
| for i in self._iter: | |||
| if not isinstance(i, self._node_type): | |||
| yield i | |||
| class NodeFilterCheckType(NodeFilterType): | |||
| """see :meth:`NodeFilter.check_type`""" | |||
| def __iter__(self): | |||
| for i in self._iter: | |||
| if not isinstance(i, self._node_type): | |||
| raise TypeError( | |||
| "all nodes should be {}; got {!r}".format(self._node_type, i) | |||
| ) | |||
| yield i | |||
| class NodeFilterHasInput(NodeFilter): | |||
| """see :meth:`NodeFilter.has_input`""" | |||
| _var = None | |||
| def __init__(self, node_iter, var): | |||
| var = as_varnode(var) | |||
| super().__init__(node_iter) | |||
| self.var = var | |||
| def __iter__(self): | |||
| for i in self._iter: | |||
| assert isinstance( | |||
| i, OpNode | |||
| ), "has_input() must be used with OpNode; " "got {!r}".format(i) | |||
| if self.var in i.inputs: | |||
| yield i | |||
| class NodeFilterName(NodeFilter): | |||
| """see :meth:`NodeFilter.name`""" | |||
| _re = None | |||
| def __init__(self, node_iter, pattern, ignorecase): | |||
| super().__init__(node_iter) | |||
| self._re = self.make_re(pattern, ignorecase) | |||
| @classmethod | |||
| def make_re(cls, pattern, ignorecase=True): | |||
| assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern) | |||
| assert isinstance(ignorecase, bool) | |||
| flags = 0 | |||
| if ignorecase: | |||
| flags |= re.IGNORECASE | |||
| return re.compile(fnmatch.translate(pattern), flags=flags) | |||
| def __iter__(self): | |||
| for i in self._iter: | |||
| if self._re.match(i.name): | |||
| yield i | |||
| @@ -0,0 +1,628 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import json | |||
| import sys | |||
| from typing import Callable | |||
| import numpy as np | |||
| from ..core import _imperative_rt as rt | |||
| from ..core._wrap import Device | |||
| from ..core.ops import builtin | |||
| from ..core.tensor.megbrain_graph import InputNode | |||
| from ..tensor import Tensor | |||
| from .comp_graph_tools import replace_vars | |||
| class NetworkNode: | |||
| pass | |||
| class VarNode(NetworkNode): | |||
| def __init__(self, owner_opr=None, name=None): | |||
| self.var = None | |||
| self.owner = owner_opr | |||
| self.name = name | |||
| self.id = id(self) | |||
| @classmethod | |||
| def load(cls, sym_var, owner_opr): | |||
| obj = cls() | |||
| obj.var = sym_var # mgb varnode | |||
| obj.name = sym_var.name | |||
| obj.owner = owner_opr | |||
| return obj | |||
| @property | |||
| def shape(self): | |||
| rst = None | |||
| if self.var: | |||
| try: | |||
| rst = self.var.shape | |||
| except Exception: | |||
| rst = None | |||
| return rst | |||
| @property | |||
| def dtype(self): | |||
| return self.var.dtype if self.var else None | |||
| def set_owner_opr(self, owner_opr): | |||
| self.owner = owner_opr | |||
| class OpNode(NetworkNode): | |||
| opdef = None | |||
| type = None | |||
| def __init__(self): | |||
| self.inputs = [] | |||
| self.outputs = [] | |||
| self.params = {} | |||
| self._opr = None # mgb opnode | |||
| self.id = id(self) | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = cls() | |||
| obj.params = json.loads(opr.params) | |||
| obj.name = opr.name | |||
| obj._opr = opr | |||
| return obj | |||
| def compile(self, graph=None): | |||
| op = self.opdef(**self.params) | |||
| args = [i.var for i in self.inputs] | |||
| outputs = rt.invoke_op(op, args) | |||
| assert len(outputs) == len(self.outputs) | |||
| self._opr = outputs[0].owner | |||
| for i in range(len(self.outputs)): | |||
| self.outputs[i].var = outputs[i] | |||
| self.outputs[i].var.name = self.outputs[i].name | |||
| assert self.outputs[i].owner is self | |||
| def add_inp_var(self, x): | |||
| self.inputs.append(x) | |||
| def add_out_var(self, x): | |||
| self.outputs.append(x) | |||
| def str_to_mge_class(classname): | |||
| # TODO: use megbrain C++ RTTI to replace type string | |||
| if classname == "RNGOpr<MegDNNOpr>": | |||
| classname = "RNGOpr" | |||
| oprcls = getattr(sys.modules[__name__], classname, None) | |||
| return oprcls if oprcls else ReadOnlyOpNode | |||
| class Host2DeviceCopy(OpNode): | |||
| type = "Host2DeviceCopy" | |||
| def __init__(self, shape=None, dtype=None, name=None, device=None): | |||
| super().__init__() | |||
| self.shape = shape | |||
| self.dtype = dtype | |||
| self.name = name | |||
| self.device = Device(device).to_c() if device else Device("xpux").to_c() | |||
| self.outputs = [] | |||
| @classmethod | |||
| def load(cls, opr): | |||
| self = cls() | |||
| self.outputs = [] | |||
| assert len(opr.outputs) == 1, "wrong number of outputs" | |||
| self.shape = opr.outputs[0].shape | |||
| self.dtype = opr.outputs[0].dtype | |||
| self.name = opr.outputs[0].name | |||
| self.device = opr.outputs[0].comp_node | |||
| self._opr = opr | |||
| return self | |||
| def compile(self, graph): | |||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
| self._opr = outputs.owner | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(self, self.name)) | |||
| self.outputs[0].var = outputs | |||
| assert self.outputs[0].owner is self | |||
| class ImmutableTensor(OpNode): | |||
| type = "ImmutableTensor" | |||
| def __init__(self, data=None, name=None, device=None, graph=None): | |||
| super().__init__() | |||
| self.name = name | |||
| self.outputs = [] | |||
| self.graph = graph | |||
| if data is not None: | |||
| self.set_value(data, device) | |||
| @property | |||
| def device(self): | |||
| return self._opr.outputs[0].comp_node if self._opr else None | |||
| @device.setter | |||
| def device(self, device): | |||
| self.set_value(self.numpy(), device) | |||
| @property | |||
| def shape(self): | |||
| return self.outputs[0].shape | |||
| @property | |||
| def dtype(self): | |||
| return self._opr.outputs[0].dtype if self._opr else None | |||
| def numpy(self): | |||
| return self._opr.outputs[0].value if self._opr else None | |||
| def set_value(self, data, device=None): | |||
| assert self.graph is not None | |||
| cn = device if device else self.device | |||
| assert isinstance(data, (int, float, np.ndarray)) | |||
| if isinstance(data, (int, float)): | |||
| data = np.array(data) | |||
| if data.dtype == np.float64: | |||
| data = data.astype(np.float32) | |||
| elif data.dtype == np.int64: | |||
| data = data.astype(np.int32) | |||
| varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(self, self.name)) | |||
| self.outputs[0].var = varnode | |||
| self._opr = varnode.owner | |||
| @classmethod | |||
| def load(cls, opr): | |||
| self = cls() | |||
| self.outputs = [] | |||
| self._opr = opr | |||
| self.name = opr.outputs[0].name | |||
| self.graph = opr.graph | |||
| return self | |||
| def compile(self, graph): | |||
| assert self.outputs[0].var is self._opr.outputs[0] | |||
| assert self.outputs[0].owner is self | |||
| if self.graph != graph: | |||
| self.graph = graph | |||
| self.set_value(self.numpy()) | |||
| if self.name is not None: | |||
| self.outputs[0].var.name = self.name | |||
| class ReadOnlyOpNode(OpNode): | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(ReadOnlyOpNode, cls).load(opr) | |||
| obj.type = opr.type | |||
| return obj | |||
| def compile(self): | |||
| assert self._opr is not None | |||
| assert len(self.inputs) == len(self._opr.inputs) | |||
| assert len(self.outputs) == len(self._opr.outputs) | |||
| repl_dict = {} | |||
| for ind, i in enumerate(self.inputs): | |||
| if i.var != self._opr.inputs[ind]: | |||
| repl_dict[self._opr.inputs[ind]] = i.var | |||
| if bool(repl_dict): | |||
| out_vars = replace_vars(self._opr.outputs, repl_dict) | |||
| for ind, o in enumerate(self.outputs): | |||
| o.var = out_vars[ind] | |||
| class Elemwise(OpNode): | |||
| type = "Elemwise" | |||
| opdef = builtin.Elemwise | |||
| class Reduce(OpNode): | |||
| type = "Reduce" | |||
| opdef = builtin.Reduce | |||
| class TypeCvt(OpNode): | |||
| type = "TypeCvt" | |||
| opdef = builtin.TypeCvt | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(TypeCvt, cls).load(opr) | |||
| t_dtype = opr.outputs[0].dtype | |||
| obj.params["dtype"] = t_dtype | |||
| return obj | |||
| class MatrixInverse(OpNode): | |||
| type = "MatrixInverse" | |||
| opdef = builtin.MatrixInverse | |||
| class MatrixMul(OpNode): | |||
| type = "MatrixMul" | |||
| opdef = builtin.MatrixMul | |||
| class BatchedMatrixMul(OpNode): | |||
| type = "BatchedMatmul" | |||
| opdef = builtin.BatchedMatrixMul | |||
| class Dot(OpNode): | |||
| type = "Dot" | |||
| opdef = builtin.Dot | |||
| class SVD(OpNode): | |||
| type = "SVD" | |||
| opdef = builtin.SVD | |||
| class ConvolutionForward(OpNode): | |||
| type = "Convolution" | |||
| opdef = builtin.Convolution | |||
| class ConvolutionBackwardData(OpNode): | |||
| type = "ConvTranspose" | |||
| opdef = builtin.ConvolutionBackwardData | |||
| class DeformableConvForward(OpNode): | |||
| type = "DeformableConv" | |||
| opdef = builtin.DeformableConv | |||
| class GroupLocalForward(OpNode): | |||
| type = "GroupLocal" | |||
| opdef = builtin.GroupLocal | |||
| class PoolingForward(OpNode): | |||
| type = "Pooling" | |||
| opdef = builtin.Pooling | |||
| class AdaptivePoolingForward(OpNode): | |||
| type = "AdaptivePooling" | |||
| opdef = builtin.AdaptivePooling | |||
| class ROIPoolingForward(OpNode): | |||
| type = "ROIPooling" | |||
| opdef = builtin.ROIPooling | |||
| class DeformablePSROIPoolingForward(OpNode): | |||
| type = "DeformablePSROIPooling" | |||
| opdef = builtin.DeformablePSROIPooling | |||
| class ConvBiasForward(OpNode): | |||
| type = "ConvBias" | |||
| opdef = builtin.ConvBias | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(ConvBiasForward, cls).load(opr) | |||
| obj.params["dtype"] = opr.outputs[0].dtype | |||
| return obj | |||
| class BatchConvBiasForward(OpNode): | |||
| type = "BatchConvBias" | |||
| opdef = builtin.BatchConvBias | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(BatchConvBiasForward, cls).load(opr) | |||
| obj.params["dtype"] = opr.outputs[0].dtype | |||
| return obj | |||
| class BatchNormForward(OpNode): | |||
| type = "BatchNorm" | |||
| opdef = builtin.BatchNorm | |||
| class ROIAlignForward(OpNode): | |||
| type = "ROIAlign" | |||
| opdef = builtin.ROIAlign | |||
| class WarpPerspectiveForward(OpNode): | |||
| type = "WarpPerspective" | |||
| opdef = builtin.WarpPerspective | |||
| class WarpAffineForward(OpNode): | |||
| type = "WarpAffine" | |||
| opdef = builtin.WarpAffine | |||
| class RemapForward(OpNode): | |||
| type = "Remap" | |||
| opdef = builtin.Remap | |||
| class ResizeForward(OpNode): | |||
| type = "Resize" | |||
| opdef = builtin.Resize | |||
| class IndexingOneHot(OpNode): | |||
| type = "IndexingOneHot" | |||
| opdef = builtin.IndexingOneHot | |||
| class IndexingSetOneHot(OpNode): | |||
| type = "IndexingSetOneHot" | |||
| opdef = builtin.IndexingSetOneHot | |||
| class Copy(OpNode): | |||
| type = "Copy" | |||
| opdef = builtin.Copy | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(Copy, cls).load(opr) | |||
| obj.params["comp_node"] = opr.outputs[0].comp_node | |||
| return obj | |||
| class ArgsortForward(OpNode): | |||
| type = "Argsort" | |||
| opdef = builtin.Argsort | |||
| class Argmax(OpNode): | |||
| type = "Argmax" | |||
| opdef = builtin.Argmax | |||
| class Argmin(OpNode): | |||
| type = "Argmin" | |||
| opdef = builtin.Argmin | |||
| class CondTake(OpNode): | |||
| type = "CondTake" | |||
| opdef = builtin.CondTake | |||
| class TopK(OpNode): | |||
| type = "TopK" | |||
| opdef = builtin.TopK | |||
| class NvOf(OpNode): | |||
| type = "NvOf" | |||
| opdef = builtin.NvOf | |||
| class RNGOpr(OpNode): | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(RNGOpr, cls).load(opr) | |||
| if len(obj.params) == 3: | |||
| obj.opdef = builtin.GaussianRNG | |||
| obj.type = "GaussianRNG" | |||
| else: | |||
| obj.opdef = builtin.UniformRNG | |||
| obj.type = "UniformRNG" | |||
| return obj | |||
| class Linspace(OpNode): | |||
| type = "Linspace" | |||
| opdef = builtin.Linspace | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(Linspace, cls).load(opr) | |||
| obj.params["comp_node"] = opr.outputs[0].comp_node | |||
| return obj | |||
| class Eye(OpNode): | |||
| type = "Eye" | |||
| opdef = builtin.Eye | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(Eye, cls).load(opr) | |||
| obj.params["dtype"] = opr.outputs[0].dtype | |||
| obj.params["comp_node"] = opr.outputs[0].comp_node | |||
| return obj | |||
| class GetVarShape(OpNode): | |||
| type = "GetVarShape" | |||
| opdef = builtin.GetVarShape | |||
| class Concat(OpNode): | |||
| type = "Concat" | |||
| opdef = builtin.Concat | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(Concat, cls).load(opr) | |||
| obj.params["comp_node"] = Device("xpux").to_c() | |||
| return obj | |||
| class Broadcast(OpNode): | |||
| type = "Broadcast" | |||
| opdef = builtin.Broadcast | |||
| class Identity(OpNode): | |||
| type = "Identity" | |||
| opdef = builtin.Identity | |||
| class NMSKeep(OpNode): | |||
| type = "NMSKeep" | |||
| opdef = builtin.NMSKeep | |||
| # class ParamPackSplit | |||
| # class ParamPackConcat | |||
| class Dimshuffle(OpNode): | |||
| type = "Dimshuffle" | |||
| opdef = builtin.Dimshuffle | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(Dimshuffle, cls).load(opr) | |||
| del obj.params["ndim"] | |||
| return obj | |||
| class Reshape(OpNode): | |||
| type = "Reshape" | |||
| opdef = builtin.Reshape | |||
| class AxisAddRemove(OpNode): | |||
| type = "AxisAddRemove" | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = cls() | |||
| obj.name = opr.name | |||
| obj._opr = opr | |||
| params = json.loads(opr.params) | |||
| desc = params["desc"] | |||
| method = None | |||
| axis = [] | |||
| for i in desc: | |||
| if method is None: | |||
| method = i["method"] | |||
| assert method == i["method"] | |||
| axis.append(i["axisnum"]) | |||
| obj.params = {"axis": axis} | |||
| obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis | |||
| return obj | |||
| class IndexingBase(OpNode): | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = cls() | |||
| obj.name = opr.name | |||
| obj._opr = opr | |||
| params = json.loads(opr.params) | |||
| items = [ | |||
| [ | |||
| p["axis"], | |||
| bool(p["begin"]), | |||
| bool(p["end"]), | |||
| bool(p["step"]), | |||
| bool(p["idx"]), | |||
| ] | |||
| for p in params | |||
| ] | |||
| obj.params["items"] = items | |||
| return obj | |||
| class Subtensor(IndexingBase): | |||
| type = "Subtensor" | |||
| opdef = builtin.Subtensor | |||
| class SetSubtensor(IndexingBase): | |||
| type = "SetSubtensor" | |||
| opdef = builtin.SetSubtensor | |||
| class IncrSubtensor(IndexingBase): | |||
| type = "IncrSubtensor" | |||
| opdef = builtin.IncrSubtensor | |||
| class IndexingMultiAxisVec(IndexingBase): | |||
| type = "IndexingMultiAxisVec" | |||
| opdef = builtin.IndexingMultiAxisVec | |||
| class IndexingSetMultiAxisVec(IndexingBase): | |||
| type = "IndexingSetMultiAxisVec" | |||
| opdef = builtin.IndexingSetMultiAxisVec | |||
| class IndexingIncrMultiAxisVec(IndexingBase): | |||
| type = "IndexingIncrMultiAxisVec" | |||
| opdef = builtin.IndexingIncrMultiAxisVec | |||
| class MeshIndexing(IndexingBase): | |||
| type = "MeshIndexing" | |||
| opdef = builtin.MeshIndexing | |||
| class SetMeshIndexing(IndexingBase): | |||
| type = "SetMeshIndexing" | |||
| opdef = builtin.SetMeshIndexing | |||
| class IncrMeshIndexing(IndexingBase): | |||
| type = "IncrMeshIndexing" | |||
| opdef = builtin.IncrMeshIndexing | |||
| class BatchedMeshIndexing(IndexingBase): | |||
| type = "BatchedMeshIndexing" | |||
| opdef = builtin.BatchedMeshIndexing | |||
| class BatchedSetMeshIndexing(IndexingBase): | |||
| type = "BatchedSetMeshIndexing" | |||
| opdef = builtin.BatchedSetMeshIndexing | |||
| class BatchedIncrMeshIndexing(IndexingBase): | |||
| type = "BatchedIncrMeshIndexing" | |||
| opdef = builtin.BatchedIncrMeshIndexing | |||
| # class CollectiveComm | |||
| # class RemoteSend | |||
| # class RemoteRecv | |||
| # class TQT | |||
| # class FakeQuant | |||
| # class InplaceAdd | |||
| class AssertEqual(OpNode): | |||
| type = "AssertEqual" | |||
| opdef = builtin.AssertEqual | |||
| class ElemwiseMultiType(OpNode): | |||
| type = "ElemwiseMultiType" | |||
| opdef = builtin.ElemwiseMultiType | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(ElemwiseMultiType, cls).load(opr) | |||
| obj.params["dtype"] = opr.outputs[0].dtype | |||
| return obj | |||
| class CvtColorForward(OpNode): | |||
| type = "CvtColor" | |||
| opdef = builtin.CvtColor | |||
| @@ -160,6 +160,16 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
| if (ctx.op->same_type<BackwardGraph>()) { | |||
| ctx.backward = true; | |||
| } | |||
| if (py::isinstance<cg::VarNode>(py::handle(args[0]))){ | |||
| SmallVector<cg::VarNode*> vinputs(nargs); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>(); | |||
| } | |||
| auto op = ctx.op.get(); | |||
| return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr(); | |||
| } | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | |||
| @@ -675,6 +685,16 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||
| tensors.emplace_back(descr); | |||
| continue; | |||
| } | |||
| if (py::isinstance<cg::VarNode>(py::handle(handle))){ | |||
| auto var = py::handle(handle).cast<cg::VarNode *>(); | |||
| mgb::DType type = var->dtype(); | |||
| auto && descr = npy::dtype_mgb2np_descr(type); | |||
| Py_INCREF(descr.get()); | |||
| tensors.emplace_back(descr.get()); | |||
| continue; | |||
| } | |||
| PyArray_Descr* descr = scalar2dtype(handle); | |||
| if (descr) { | |||
| scalars.emplace_back(descr); | |||
| @@ -719,12 +739,14 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
| TensorWrapper* tw = TensorWrapper::try_cast(handle); | |||
| if (tw) { | |||
| bool is_var = py::isinstance<cg::VarNode>(py::handle(handle)); | |||
| if (tw || is_var) { | |||
| if (!valid) { | |||
| cn = tw->m_tensor->comp_node(); | |||
| cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
| valid = true; | |||
| } else { | |||
| CompNode cn1 = tw->m_tensor->comp_node(); | |||
| CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
| if (cn1 != cn) { | |||
| throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||
| cn.to_string().c_str(), cn1.to_string().c_str())); | |||
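Note: with these interpreter-level hooks, apply(), dtype_promotion and get_device accept cg::VarNode handles as well as Tensors, so one dispatch path serves eager tensors and graph varnodes alike. A rough Python-side sketch, assuming x_var and y_var are varnodes taken from a loaded graph (e.g. some Network VarNode's ``.var``):

    from megengine.core._imperative_rt.core2 import apply
    from megengine.core.ops import builtin

    op = builtin.Elemwise(mode=builtin.Elemwise.Mode.ADD)
    (z_var,) = apply(op, x_var, y_var)  # dispatched to OpDef::apply_on_var_node, returns VarNodes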
| @@ -0,0 +1,351 @@ | |||
| import io | |||
| import numpy as np | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.utils.network_node as N | |||
| from megengine.jit.tracing import trace | |||
| from megengine.tensor import Tensor | |||
| from megengine.utils.comp_graph_tools import GraphInference | |||
| from megengine.utils.network import Network as Net | |||
| from megengine.utils.network import as_oprnode | |||
| from megengine.utils.network_node import Host2DeviceCopy, VarNode | |||
| def test_replace_var(): | |||
| a = Tensor([1, 2]) | |||
| b = Tensor([3, 4]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(a, b): | |||
| return (a + b) * 2 | |||
| fwd(a, b) | |||
| orig_model = io.BytesIO() | |||
| fwd.dump( | |||
| orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
| ) | |||
| orig_model.seek(0) | |||
| graph = Net.load(orig_model) | |||
| vara = graph.var_filter.name("a").as_unique() | |||
| varb = graph.var_filter.name("b").as_unique() | |||
| out = F.mul(vara.var, varb.var) | |||
| out = F.relu(out) | |||
| var_list = graph.add_dep_oprs(out) | |||
| opnode = list(graph.opr_filter.has_input(vara)) | |||
| repl_dict = {opnode[0].outputs[0]: var_list[0]} | |||
| graph.replace_vars(repl_dict) | |||
| modified_model = io.BytesIO() | |||
| graph.dump(modified_model) | |||
| modified_model.seek(0) | |||
| load_graph = GraphInference(modified_model) | |||
| out = load_graph.run(a, b) | |||
| np.testing.assert_equal(out["o"], [6, 16]) | |||
| def test_replace_opr(): | |||
| a = Tensor([1, 2]) | |||
| b = Tensor([3, 4]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(a, b): | |||
| return (a + b) * 2 | |||
| fwd(a, b) | |||
| orig_model = io.BytesIO() | |||
| fwd.dump( | |||
| orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
| ) | |||
| orig_model.seek(0) | |||
| graph = Net.load(orig_model) | |||
| vara = graph.var_filter.name("a").as_unique() | |||
| varb = graph.var_filter.name("b").as_unique() | |||
| out1 = F.sub(vara.var, varb.var) | |||
| out1 = F.relu(out1) | |||
| var_list = graph.add_dep_oprs(out1) | |||
| repl_opr = as_oprnode(var_list) | |||
| orig_opr = graph.opr_filter.has_input(vara).as_unique() | |||
| repl_dict = {orig_opr: repl_opr} | |||
| graph.replace_oprs(repl_dict) | |||
| modified_model1 = io.BytesIO() | |||
| graph.dump(modified_model1) | |||
| modified_model1.seek(0) | |||
| load_graph = GraphInference(modified_model1) | |||
| out = load_graph.run(a, b) | |||
| np.testing.assert_equal(out["o"], [0, 0]) | |||
| def test_modify_params(): | |||
| a = Tensor([1, 2]) | |||
| b = Tensor([3, 4]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(a, b): | |||
| return (a + b) * 2 | |||
| fwd(a, b) | |||
| orig_model = io.BytesIO() | |||
| fwd.dump( | |||
| orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
| ) | |||
| orig_model.seek(0) | |||
| graph = Net.load(orig_model) | |||
| param_const = graph.params_filter.as_unique() | |||
| param_const.set_value(3) | |||
| modified_model = io.BytesIO() | |||
| graph.dump(modified_model) | |||
| modified_model.seek(0) | |||
| load_graph = GraphInference(modified_model) | |||
| out = load_graph.run(a, b) | |||
| np.testing.assert_equal(out["o"], [12, 18]) | |||
| def test_make_const(): | |||
| a = Tensor([1, 2]) | |||
| b = Tensor([3, 4]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(a, b): | |||
| return (a + b) * 2 | |||
| fwd(a, b) | |||
| orig_model = io.BytesIO() | |||
| fwd.dump( | |||
| orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
| ) | |||
| orig_model.seek(0) | |||
| graph = Net.load(orig_model) | |||
| const_b = graph.make_const(np.array([0.0, 0.0]), name="b") | |||
| varb = graph.var_filter.name("b").as_unique() | |||
| repl_dict = {varb: const_b} | |||
| graph.replace_vars(repl_dict) | |||
| modified_model = io.BytesIO() | |||
| graph.dump(modified_model) | |||
| modified_model.seek(0) | |||
| load_graph = GraphInference(modified_model) | |||
| out = load_graph.run(a) | |||
| np.testing.assert_equal(out["o"], [2, 4]) | |||
| def test_add_input(): | |||
| a = Tensor([1, 2]) | |||
| b = Tensor([3, 4]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(a, b): | |||
| return (a + b) * 2 | |||
| fwd(a, b) | |||
| orig_model = io.BytesIO() | |||
| fwd.dump( | |||
| orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
| ) | |||
| orig_model.seek(0) | |||
| graph = Net.load(orig_model) | |||
| inp_c = graph.make_input_node((2,), np.int32, name="c") | |||
| varo = graph.var_filter.name("o").as_unique() | |||
| out = F.add(varo.var, inp_c.var) | |||
| out = graph.add_dep_oprs(out)[0] | |||
| out.name = "o1" | |||
| graph.remove_output(varo) | |||
| graph.add_output(out) | |||
| modified_model = io.BytesIO() | |||
| graph.dump(modified_model) | |||
| modified_model.seek(0) | |||
| load_graph = GraphInference(modified_model) | |||
| out = load_graph.run(a, b, a) | |||
| np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy()) | |||
| def test_add_output(): | |||
| a = Tensor([1.0, 2.0]) | |||
| b = Tensor([3.0, 4.0]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(a, b): | |||
| return (a + b) * 2 | |||
| fwd(a, b) | |||
| orig_model = io.BytesIO() | |||
| fwd.dump( | |||
| orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
| ) | |||
| orig_model.seek(0) | |||
| net = Net.load(orig_model) | |||
| var_a = net.var_filter.name("a").as_unique() | |||
| var_b = net.var_filter.name("b").as_unique() | |||
| y = F.add(var_a.var, var_b.var) | |||
| y = F.sigmoid(y) | |||
| new_vars = net.add_dep_oprs(y)[0] | |||
| new_vars.name = "o1" | |||
| net.add_output(new_vars) | |||
| modified_model = io.BytesIO() | |||
| net.dump(modified_model) | |||
| modified_model.seek(0) | |||
| g = GraphInference(modified_model) | |||
| out = g.run(a.numpy(), b.numpy()) | |||
| np.testing.assert_equal(out["o"], ((a + b) * 2).numpy()) | |||
| np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy()) | |||
| def test_query(): | |||
| class Model(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.conv1 = M.Conv2d(3, 32, 3) | |||
| self.conv2 = M.Conv2d(32, 32, 3) | |||
| self.conv3 = M.Conv2d(32, 32, 3) | |||
| def forward(self, data): | |||
| x = self.conv1(data) | |||
| x = self.conv2(x) | |||
| x = self.conv3(x) | |||
| return x | |||
| n = Model() | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return n(data) | |||
| fwd(Tensor(np.random.random((1, 3, 224, 224)))) | |||
| orig_model = io.BytesIO() | |||
| fwd.dump( | |||
| orig_model, | |||
| arg_names=["data"], | |||
| output_names="o", | |||
| keep_opr_name=True, | |||
| keep_var_name=True, | |||
| optimize_for_inference=False, | |||
| ) | |||
| orig_model.seek(0) | |||
| graph = Net.load(orig_model) | |||
| r = graph.data_providers_filter.as_count() | |||
| assert r == 1 | |||
| opr = graph.get_opr_by_type(Host2DeviceCopy) | |||
| assert isinstance(opr, Host2DeviceCopy) | |||
| r1 = graph.params_filter.as_count() | |||
| assert r1 == 6 | |||
| r2 = graph.opr_filter.type(N.ConvolutionForward).as_count() | |||
| assert r2 == 3 | |||
| r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count() | |||
| assert r3 == len(graph.all_oprs) - r2 | |||
| var = graph.var_filter.name("data").as_unique() | |||
| r4 = graph.opr_filter.has_input(var).as_count() | |||
| assert r4 == 1 | |||
| r5 = graph.opr_filter.name("data").as_count() | |||
| assert r5 == 1 | |||
| opr = graph.get_opr_by_name("data") | |||
| assert isinstance(opr, Host2DeviceCopy) | |||
| var = graph.get_var_by_name("data") | |||
| assert isinstance(var, VarNode) | |||
| r6 = graph.var_filter.name("*bias").as_count() | |||
| assert r6 == 3 | |||
| def test_optimize_for_inference(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def f(x): | |||
| return F.exp(x) | |||
| orig_model = io.BytesIO() | |||
| f(Tensor(5.0)) | |||
| f.dump(orig_model, optimize_for_inference=False) | |||
| orig_model.seek(0) | |||
| optimize_model = io.BytesIO() | |||
| net = Net.load(orig_model) | |||
| net.dump(optimize_model, enable_io16xc32=True) | |||
| optimize_model.seek(0) | |||
| res = G.load_graph(optimize_model) | |||
| computing_input = res.output_vars_list[0].owner.inputs[0] | |||
| assert computing_input.dtype == np.float16 | |||
| def test_reset_batchsize(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def f(x): | |||
| return F.exp(x) | |||
| orig_model = io.BytesIO() | |||
| f(Tensor(np.random.random((3, 3, 224, 224)))) | |||
| f.dump(orig_model, optimize_for_inference=False) | |||
| orig_model.seek(0) | |||
| modified_model = io.BytesIO() | |||
| net = Net.load(orig_model) | |||
| net.reset_batch_size(1) | |||
| net.dump(modified_model, optimize_for_inference=False) | |||
| modified_model.seek(0) | |||
| net1 = Net.load(modified_model) | |||
| assert net1.data_providers_filter.as_unique().shape[0] == 1 | |||
| def test_modify_opr_name(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def f(x): | |||
| return F.exp(x) | |||
| orig_model = io.BytesIO() | |||
| f(Tensor(np.random.random((3, 3, 224, 224)))) | |||
| f.dump(orig_model, arg_names=["a"], optimize_for_inference=False) | |||
| orig_model.seek(0) | |||
| modified_model = io.BytesIO() | |||
| net = Net.load(orig_model) | |||
| net.modify_opr_names("net") | |||
| net.modify_opr_names(lambda x: "net1." + x) | |||
| net.dump(modified_model, optimize_for_inference=False) | |||
| modified_model.seek(0) | |||
| net1 = Net.load(modified_model) | |||
| assert net1.data_providers_filter.as_unique().name == "net1.net.a" | |||
| @@ -0,0 +1,712 @@ | |||
| import io | |||
| import os | |||
| import platform | |||
| import numpy as np | |||
| import pytest | |||
| import megengine.core.tensor.dtype as dtype | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.random as rand | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core._wrap import Device | |||
| from megengine.core.ops import builtin | |||
| from megengine.device import is_cuda_available | |||
| from megengine.functional.external import tensorrt_runtime_opr | |||
| from megengine.jit.tracing import trace | |||
| from megengine.tensor import Tensor | |||
| from megengine.utils.comp_graph_tools import GraphInference | |||
| from megengine.utils.network import Network as Net | |||
| def check_pygraph_dump(trace_func, inp_data, expect_results): | |||
| orig_model = io.BytesIO() | |||
| inp_size = len(inp_data) | |||
| out_size = len(expect_results) | |||
| arg_names = ["arg_{}".format(i) for i in range(inp_size)] | |||
| output_names = ["out_{}".format(i) for i in range(out_size)] | |||
| trace_func.dump( | |||
| orig_model, | |||
| arg_names=arg_names, | |||
| output_names=output_names, | |||
| optimize_for_inference=False, | |||
| ) | |||
| orig_model.seek(0) | |||
| net = Net.load(orig_model) | |||
| file = io.BytesIO() | |||
| net.dump(file, optimize_for_inference=False) | |||
| file.seek(0) | |||
| graph = GraphInference(file) | |||
| inp_dict = dict([(arg_names[i], inp_data[i].numpy()) for i in range(inp_size)]) | |||
| results = graph.run(inp_dict=inp_dict) | |||
| for ind, tensor in enumerate(expect_results): | |||
| np.testing.assert_equal(tensor.numpy(), results[output_names[ind]]) | |||
| assert tensor.dtype == results[output_names[ind]].dtype | |||
| def test_elemwise(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x, y): | |||
| z1 = x * y | |||
| z2 = x + y | |||
| z3 = z1 / z2 | |||
| z3 = z3 ** 3 | |||
| return z3 | |||
| x = Tensor([1.0, 2.0]) | |||
| y = Tensor([3.0, 5.0]) | |||
| result = fwd(x, y) | |||
| check_pygraph_dump(fwd, [x, y], [result]) | |||
| def test_reduce(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| x = data.sum(axis=2) | |||
| x = x.mean(axis=1) | |||
| return x | |||
| data = Tensor(np.random.random((1, 32, 32))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| def test_typecvt(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return data.astype(dtype.qint8(0.8)) | |||
| x = Tensor(np.random.random((2, 3)) * 255) | |||
| result = fwd(x) | |||
| check_pygraph_dump(fwd, [x], [result]) | |||
| def test_matinv(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return F.matinv(data) | |||
| data = Tensor(np.random.random((5, 5))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| def test_matmul(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data1, data2): | |||
| return F.matmul(data1, data2) | |||
| data1 = Tensor(np.random.random((32, 64))) | |||
| data2 = Tensor(np.random.random((64, 16))) | |||
| result = fwd(data1, data2) | |||
| check_pygraph_dump(fwd, [data1, data2], [result]) | |||
| def test_batchmatmul(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x, y): | |||
| return F.matmul(x, y) | |||
| x = Tensor(np.random.random((3, 3, 5))) | |||
| y = Tensor(np.random.random((3, 5, 3))) | |||
| result = fwd(x, y) | |||
| check_pygraph_dump(fwd, [x, y], [result]) | |||
| def test_dot(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x, y): | |||
| return F.dot(x, y) | |||
| x = Tensor([1.0, 2.0, 3.0]) | |||
| y = Tensor([3.0, 4.0, 5.0]) | |||
| result = fwd(x, y) | |||
| check_pygraph_dump(fwd, [x, y], [result]) | |||
| def test_svd(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| _, out, _ = F.svd(data) | |||
| return out | |||
| input = Tensor(np.random.random((1, 1, 3, 3))) | |||
| result = fwd(input) | |||
| check_pygraph_dump(fwd, [input], [result]) | |||
| def test_conv(): | |||
| conv = M.Conv2d(3, 32, 3) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return conv(data) | |||
| data = Tensor(np.random.random((1, 3, 32, 32))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| def test_deformable_conv(): | |||
| if not is_cuda_available(): | |||
| return | |||
| conv = M.DeformableConv2d(3, 32, 3) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data, offset, mask): | |||
| return conv(data, offset, mask) | |||
| data = Tensor(np.random.random((1, 3, 32, 32))) | |||
| offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5) | |||
| mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32")) | |||
| out = fwd(data, offset, mask) | |||
| check_pygraph_dump(fwd, [data, offset, mask], [out]) | |||
| def test_convtranspose(): | |||
| deconv = M.ConvTranspose2d(32, 32, 3) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return deconv(data) | |||
| data = Tensor(np.random.random((1, 32, 32, 32))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| @pytest.mark.skip(reason="pytest aborted") | |||
| def test_grouplocal(): | |||
| n = M.LocalConv2d(3, 32, 32, 32, 3) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return n(data) | |||
| input = Tensor(np.random.random((1, 3, 32, 32))) | |||
| result = fwd(input) | |||
| check_pygraph_dump(fwd, [input], [result]) | |||
| def test_pooling(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| out = F.max_pool2d(data, 2, 2) | |||
| out = F.avg_pool2d(out, 2, 2) | |||
| return out | |||
| data = Tensor(np.random.random((1, 3, 64, 64))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| def test_adaptivepooling(): | |||
| pool1 = M.AdaptiveMaxPool2d((2, 2)) | |||
| pool2 = M.AdaptiveAvgPool2d((2, 2)) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| out = pool1(data) | |||
| out = pool2(out) | |||
| return out | |||
| input = Tensor(np.random.random((1, 3, 32, 32))) | |||
| result = fwd(input) | |||
| check_pygraph_dump(fwd, [input], [result]) | |||
| def test_roipooling(): | |||
| inp = Tensor(np.random.random((1, 1, 128, 128))) | |||
| rois = Tensor(np.random.random((4, 5))) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp, rois): | |||
| return F.nn.roi_pooling(inp, rois, (2, 2), scale=2.0) | |||
| output = fwd(inp, rois) | |||
| check_pygraph_dump(fwd, [inp, rois], [output]) | |||
| def test_deformable_ps_roi_pooling(): | |||
| inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32")) | |||
| rois = Tensor(np.random.random((1, 5)).astype("float32")) | |||
| trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32")) | |||
| pooled_h = 7 | |||
| pooled_w = 7 | |||
| sample_per_part = 4 | |||
| no_trans = False | |||
| part_size = 7 | |||
| spatial_scale = 1.0 / 64 | |||
| trans_std = 0.1 | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp, rois, trans): | |||
| y = F.deformable_psroi_pooling( | |||
| inp, | |||
| rois, | |||
| trans, | |||
| no_trans, | |||
| part_size, | |||
| pooled_h, | |||
| pooled_w, | |||
| sample_per_part, | |||
| spatial_scale, | |||
| trans_std, | |||
| ) | |||
| return y | |||
| result = fwd(inp, rois, trans) | |||
| check_pygraph_dump(fwd, [inp, rois, trans], [result]) | |||
| def test_convbias(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp, weight, bias): | |||
| return F.quantized.conv_bias_activation( | |||
| inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" | |||
| ) | |||
| inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | |||
| weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0)) | |||
| bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0)) | |||
| result = fwd(inp, weight, bias) | |||
| check_pygraph_dump(fwd, [inp, weight, bias], [result]) | |||
| def test_batch_convbias(): | |||
| if is_cuda_available(): | |||
| return | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp, weight, bias): | |||
| return F.quantized.batch_conv_bias_activation( | |||
| inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" | |||
| ) | |||
| inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | |||
| weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0)) | |||
| bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0)) | |||
| result = fwd(inp, weight, bias) | |||
| check_pygraph_dump(fwd, [inp, weight, bias], [result]) | |||
| def test_batchnorm(): | |||
| bn = M.BatchNorm2d(32) | |||
| bn.eval() | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return bn(data) | |||
| data = Tensor(np.random.random((1, 32, 32, 32))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| def test_roialign(): | |||
| inp = Tensor(np.random.randn(1, 1, 128, 128)) | |||
| rois = Tensor(np.random.random((4, 5))) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp, rois): | |||
| return F.nn.roi_align(inp, rois, (2, 2)) | |||
| output = fwd(inp, rois) | |||
| check_pygraph_dump(fwd, [inp, rois], [output]) | |||
| def test_warpperspective(): | |||
| inp_shape = (1, 1, 4, 4) | |||
| x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||
| M_shape = (1, 3, 3) | |||
| # M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1) | |||
| M = Tensor( | |||
| np.array( | |||
| [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 | |||
| ).reshape(M_shape) | |||
| ) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x, M): | |||
| return F.warp_perspective(x, M, (2, 2)) | |||
| result = fwd(x, M) | |||
| check_pygraph_dump(fwd, [x, M], [result]) | |||
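Given the translation encoded in M (each dst(h, w) samples src(h+1, w+1)), the (2, 2) result for the arange input should roughly be the following, assuming the usual dst-to-src convention of warp_perspective:

    # expected fwd(x, M) (illustrative):
    # [[ 5.,  6.],
    #  [ 9., 10.]]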
| def test_warpaffine(): | |||
| inp_shape = (1, 3, 3, 3) | |||
| x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape)) | |||
| weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x, weightv): | |||
| return F.warp_affine(x, weightv, (2, 2), border_mode="WRAP") | |||
| outp = fwd(x, weightv) | |||
| check_pygraph_dump(fwd, [x, weightv], [outp]) | |||
| def test_remap(): | |||
| inp_shape = (1, 1, 4, 4) | |||
| inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||
| map_xy_shape = (1, 2, 2, 2) | |||
| map_xy = Tensor( | |||
| np.array( | |||
| [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32 | |||
| ).reshape(map_xy_shape) | |||
| ) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp, map_xy): | |||
| return F.remap(inp, map_xy) | |||
| out = fwd(inp, map_xy) | |||
| check_pygraph_dump(fwd, [inp, map_xy], [out]) | |||
| def test_resize(): | |||
| x = Tensor(np.random.randn(10, 3, 32, 32)) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x): | |||
| return F.nn.interpolate(x, size=(16, 16), mode="BILINEAR") | |||
| out = fwd(x) | |||
| check_pygraph_dump(fwd, [x], [out]) | |||
| def test_index_onehot(): | |||
| src = Tensor([[1.0, 2.0]]) | |||
| index = Tensor([0]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(src, index): | |||
| return F.indexing_one_hot(src, index) | |||
| out = fwd(src, index) | |||
| check_pygraph_dump(fwd, [src, index], [out]) | |||
| def test_set_onehot(): | |||
| x = Tensor(np.arange(1, 4, dtype=np.int32)) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x): | |||
| return F.one_hot(x, num_classes=4) | |||
| out = fwd(x) | |||
| check_pygraph_dump(fwd, [x], [out]) | |||
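For reference, one_hot on [1, 2, 3] with num_classes=4 should produce one row per element with a single 1 at that element's index, roughly:

    # expected fwd(x) (illustrative):
    # [[0, 1, 0, 0],
    #  [0, 0, 1, 0],
    #  [0, 0, 0, 1]]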
| def test_copy(): | |||
| x = Tensor([1, 2, 3]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x): | |||
| return x.to("cpu0:0") | |||
| o = fwd(x) | |||
| check_pygraph_dump(fwd, [x], [o]) | |||
| def test_argsort(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return F.argsort(data, True) | |||
| data = Tensor([1.0, 2.0, 3.0, 5.0]) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| def test_argmax_min(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return F.argmax(data), F.argmin(data) | |||
| data = Tensor(np.random.random((10, 10))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], result) | |||
| def test_condtake(): | |||
| mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||
| x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(mask, x): | |||
| v, index = F.cond_take(mask, x) | |||
| return v, index | |||
| v, index = fwd(mask, x) | |||
| check_pygraph_dump(fwd, [mask, x], [v, index]) | |||
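cond_take returns the selected values together with the positions of the True entries in the flattened mask, so for this mask and x the trace should roughly yield:

    # expected (illustrative, assuming flattened indices):
    # v     -> [1.0, 4.0]
    # index -> [0, 3]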
| def test_topk(): | |||
| x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x): | |||
| top, indices = F.topk(x, 5) | |||
| return top, indices | |||
| top, indices = fwd(x) | |||
| check_pygraph_dump(fwd, [x], [top, indices]) | |||
| def test_random(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(): | |||
| x = rand.uniform(size=(2, 2)) | |||
| y = rand.normal(size=(1, 3, 3, 3)) | |||
| return x, y | |||
| x, y = fwd() | |||
| check_pygraph_dump(fwd, [], [x, y]) | |||
| def test_tensor_gen(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(): | |||
| a = F.linspace(3, 10, 3, device=Device("xpux").to_c()) | |||
| b = F.eye(3, device=Device("xpux").to_c()) | |||
| return a, b | |||
| a, b = fwd() | |||
| check_pygraph_dump(fwd, [], [a, b]) | |||
| def test_getvarshape(): | |||
| op = builtin.GetVarShape(axis=1) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return apply(op, data)[0] | |||
| data = Tensor(np.random.random((1, 2, 3, 4))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| def test_concat(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data1, data2): | |||
| return F.concat([data1, data2], axis=1) | |||
| x = Tensor(np.random.random((2, 3))) | |||
| y = Tensor(np.random.random((2, 5))) | |||
| result = fwd(x, y) | |||
| check_pygraph_dump(fwd, [x, y], [result]) | |||
| def test_broadcast(): | |||
| inp = Tensor([[1], [2], [3], [4]]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp): | |||
| return F.broadcast_to(inp, (4, 4)) | |||
| out = fwd(inp) | |||
| check_pygraph_dump(fwd, [inp], [out]) | |||
| def test_identity(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return F.copy(data) | |||
| data = Tensor([1.0, 2.0]) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| @pytest.mark.skip(reason="advance indexing trace error") | |||
| def test_nms(): | |||
| x = np.zeros((100, 4)) | |||
| np.random.seed(42) | |||
| x[:, :2] = np.random.rand(100, 2) * 20 | |||
| x[:, 2:] = np.random.rand(100, 2) * 20 + 100 | |||
| scores = Tensor(np.random.rand(100)) | |||
| inp = Tensor(x) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp, scores): | |||
| return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3) | |||
| result = fwd(inp, scores) | |||
| check_pygraph_dump(fwd, [inp, scores], [result]) | |||
| def test_dimshuffle(): | |||
| inp = Tensor([1, 2, 3, 4]) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp): | |||
| return inp.T | |||
| out = fwd(inp) | |||
| check_pygraph_dump(fwd, [inp], [out]) | |||
| def test_reshape(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| return data.reshape((1, 8)) | |||
| data = Tensor(np.random.random((1, 2, 2, 2))) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
| def test_add_remove_axis(): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data): | |||
| x = F.expand_dims(data, [0, 0]) | |||
| y = F.squeeze(x, 0) | |||
| return y | |||
| data = Tensor([1.0, 2.0]) | |||
| result = fwd(data) | |||
| check_pygraph_dump(fwd, [data], [result]) | |||
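The shape flow in this test, for reference:

    # shape trace (illustrative):
    # (2,) --expand_dims [0, 0]--> (1, 1, 2) --squeeze axis 0--> (1, 2)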
| @pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||
| def test_subtensor(mode): | |||
| items = [[0, True, True, True, False], [1, False, False, False, True]] | |||
| data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))] | |||
| if mode == "get": | |||
| op = builtin.Subtensor(items) | |||
| data = data[:1] | |||
| if mode == "set": | |||
| op = builtin.SetSubtensor(items) | |||
| if mode == "inc": | |||
| op = builtin.IncrSubtensor(items) | |||
| tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)] | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(*tensors): | |||
| return apply(op, *tensors)[0] | |||
| result = fwd(*data, *tensors) | |||
| check_pygraph_dump(fwd, data + tensors, [result]) | |||
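In these indexing tests each entry of items reads as [axis, begin, end, step, idx] flags: the booleans mark which slice components are supplied as extra tensor inputs, consumed in order after the data tensor(s), and in the "set"/"inc" modes the second element of data is the value tensor written into (or added to) the selected region. A sketch of how this test's descriptor decodes (one reading of the flags, not an official spec):

    items = [
        [0, True, True, True, False],    # axis 0: begin, end, step tensors follow
        [1, False, False, False, True],  # axis 1: a single index tensor follows
    ]
    # tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)]
    # so the "get" case selects roughly data[0][0:4:2, 3]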
| @pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||
| def test_advance_indexing(mode): | |||
| items = [[0, False, False, False, True]] | |||
| tensors = [Tensor([0, 4, 2])] | |||
| data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))] | |||
| if mode == "get": | |||
| op = builtin.IndexingMultiAxisVec(items) | |||
| data = data[:1] | |||
| if mode == "set": | |||
| op = builtin.IndexingSetMultiAxisVec(items) | |||
| if mode == "inc": | |||
| op = builtin.IndexingIncrMultiAxisVec(items) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(*tensors): | |||
| return apply(op, *tensors)[0] | |||
| result = fwd(*data, *tensors) | |||
| check_pygraph_dump(fwd, data + tensors, [result]) | |||
| @pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||
| def test_mesh_indexing(mode): | |||
| items = [[0, True, True, True, False], [1, False, False, False, True]] | |||
| tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])] | |||
| data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))] | |||
| if mode == "get": | |||
| op = builtin.IndexingMultiAxisVec(items) | |||
| data = data[:1] | |||
| if mode == "set": | |||
| op = builtin.IndexingSetMultiAxisVec(items) | |||
| if mode == "inc": | |||
| op = builtin.IndexingIncrMultiAxisVec(items) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(*tensors): | |||
| return apply(op, *tensors)[0] | |||
| result = fwd(*data, *tensors) | |||
| check_pygraph_dump(fwd, data + tensors, [result]) | |||
| @pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||
| def test_batch_mesh_indexing(mode): | |||
| items = [[1, False, False, False, True], [2, False, False, False, True]] | |||
| tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])] | |||
| data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))] | |||
| if mode == "get": | |||
| op = builtin.BatchedMeshIndexing(items) | |||
| data = data[:1] | |||
| if mode == "set": | |||
| op = builtin.BatchedSetMeshIndexing(items) | |||
| if mode == "inc": | |||
| op = builtin.BatchedIncrMeshIndexing(items) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(*tensors): | |||
| return apply(op, *tensors)[0] | |||
| result = fwd(*data, *tensors) | |||
| check_pygraph_dump(fwd, data + tensors, [result]) | |||
| @pytest.mark.skip(reason="tmp skip") | |||
| def test_assert_equal(): | |||
| g = G.Graph() | |||
| inp1 = g.make_h2d(dtype=np.float32, device="xpux") | |||
| inp2 = g.make_h2d(dtype=np.float32, device="xpux") | |||
| op = builtin.AssertEqual(maxerr=1e-5) | |||
| out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0] | |||
| print(out) | |||
| g.compile(out) | |||
| file = io.BytesIO() | |||
| out_model = G.dump_graph([out]) | |||
| file.write(out_model[0]) | |||
| file.seek(0) | |||
| net = Net.load(file) | |||
| dump_file = io.BytesIO() | |||
| net.dump(dump_file) | |||
| dump_file.seek(0) | |||
| g = GraphInference(dump_file) | |||
| g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0])) | |||
| def test_elemwise_multitype(): | |||
| op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0)) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(x, y): | |||
| return apply(op, x, y)[0] | |||
| x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0)) | |||
| y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0)) | |||
| result = fwd(x, y) | |||
| check_pygraph_dump(fwd, [x, y], [result]) | |||
| def test_cvtcolor(): | |||
| inp = np.random.randn(3, 3, 3, 3).astype(np.float32) | |||
| x = Tensor(inp) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(inp): | |||
| return F.img_proc.cvt_color(inp, mode="RGB2GRAY") | |||
| result = fwd(x) | |||
| check_pygraph_dump(fwd, [x], [result]) | |||
| @@ -17,9 +17,20 @@ | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/lrn.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||
| #include "megbrain/opr/dnn/roi_pooling.h" | |||
| #include "megbrain/opr/dnn/roi_align.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/opr/standalone/nms_opr.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/rand.h" | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/opr/misc.h" | |||
| #include "megbrain/opr/indexing.h" | |||
| #include "megbrain/opr/internal/indexing_helper.h" | |||
| #include "megbrain/opr/nn_int.h" | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| #if MGB_ENABLE_JSON | |||
| #include "megdnn/opr_param_json.h" | |||
| #endif | |||
| @@ -354,7 +365,7 @@ uint64_t opr_footprint_func<opr::DeformableConvForward>( | |||
| auto&& out_shape = opr->output()[0]->shape(); | |||
| auto&& filter_shape = opr->input()[1]->shape(); | |||
| using Param = opr::DeformableConvForward::Param; | |||
| auto&& param = opr->cast_final_safe<opr::Convolution>().param(); | |||
| auto&& param = opr->cast_final_safe<opr::DeformableConvForward>().param(); | |||
| size_t fh, fw, icpg; | |||
| mgb_assert(param.format == Param::Format::NCHW); | |||
| if (param.sparse == Param::Sparse::GROUP) { | |||
| @@ -425,9 +436,11 @@ uint64_t opr_footprint_func<opr::BatchConvBiasForward>( | |||
| auto&& filter_shape = opr->input()[1]->shape(); | |||
| using Param = opr::BatchConvBiasForward::Param; | |||
| auto&& param = opr->cast_final_safe<opr::BatchConvBiasForward>().param(); | |||
| mgb_assert(param.format == Param::Format::NCHW4); | |||
| size_t packed_channels = 4; | |||
| size_t packed_channels = 1; | |||
| size_t kern_spatial_pos = 3; | |||
| if (param.format == Param::Format::NCHW4) { | |||
| packed_channels = 4; | |||
| } | |||
| size_t fh = filter_shape[kern_spatial_pos], | |||
| fw = filter_shape[kern_spatial_pos + 1]; | |||
| return out_shape.total_nr_elems() * fh * fw * src_shape[1] * | |||
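The footprint fix above replaces the hard NCHW4 assertion with a format-dependent packed_channels: 4 only for NCHW4 and 1 otherwise, so non-NCHW4 batch-conv-bias graphs no longer trip the assert. A hypothetical helper restating that selection, as a sketch in Python:

    def packed_channels_for(fmt: str) -> int:
        # mirrors the C++ change: channels are packed by 4 only in NCHW4
        return 4 if fmt == "NCHW4" else 1

    assert packed_channels_for("NCHW4") == 4
    assert packed_channels_for("NCHW") == 1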
| @@ -508,7 +521,29 @@ REGISTE_PARAM_JSON_FUNC(LocalShareBackwardFilter) | |||
| REGISTE_PARAM_JSON_FUNC(DeformableConvForward) | |||
| REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter) | |||
| REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData) | |||
| REGISTE_PARAM_JSON_FUNC(DeformablePSROIPoolingForward) | |||
| REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward) | |||
| REGISTE_PARAM_JSON_FUNC(BatchNormForward) | |||
| REGISTE_PARAM_JSON_FUNC(ElemwiseMultiType) | |||
| REGISTE_PARAM_JSON_FUNC(Argsort) | |||
| REGISTE_PARAM_JSON_FUNC(Argmax) | |||
| REGISTE_PARAM_JSON_FUNC(Argmin) | |||
| REGISTE_PARAM_JSON_FUNC(AdaptivePooling) | |||
| REGISTE_PARAM_JSON_FUNC(ROIPooling) | |||
| REGISTE_PARAM_JSON_FUNC(ROIAlign) | |||
| REGISTE_PARAM_JSON_FUNC(WarpPerspective) | |||
| REGISTE_PARAM_JSON_FUNC(WarpAffine) | |||
| REGISTE_PARAM_JSON_FUNC(Remap) | |||
| REGISTE_PARAM_JSON_FUNC(Resize) | |||
| REGISTE_PARAM_JSON_FUNC(IndexingOneHot) | |||
| REGISTE_PARAM_JSON_FUNC(IndexingSetOneHot) | |||
| REGISTE_PARAM_JSON_FUNC(TopK) | |||
| REGISTE_PARAM_JSON_FUNC(UniformRNG) | |||
| REGISTE_PARAM_JSON_FUNC(GaussianRNG) | |||
| REGISTE_PARAM_JSON_FUNC(Linspace) | |||
| REGISTE_PARAM_JSON_FUNC(Eye) | |||
| REGISTE_PARAM_JSON_FUNC(CvtColor) | |||
| template <> | |||
| std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>( | |||
| @@ -547,24 +582,83 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>( | |||
| }); | |||
| } | |||
| std::shared_ptr<json::Value> indexing_param_to_json( | |||
| const std::vector<opr::indexing::AxisIndexer>& indices) { | |||
| auto desc = json::Array::make(); | |||
| for (auto& index : indices) { | |||
| desc->add(json::Object::make({ | |||
| {"axis", json::NumberInt::make(index.axis.get_raw())}, | |||
| {"begin", | |||
| json::NumberInt::make(index.begin.node() != nullptr)}, | |||
| {"end", json::NumberInt::make(index.end.node() != nullptr)}, | |||
| {"step", | |||
| json::NumberInt::make(index.step.node() != nullptr)}, | |||
| {"idx", json::NumberInt::make(index.idx.node() != nullptr)}, | |||
| })); | |||
| } | |||
| return desc; | |||
| } | |||
| #define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ | |||
| template <> \ | |||
| std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \ | |||
| cg::OperatorNodeBase * opr) { \ | |||
| auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \ | |||
| return indexing_param_to_json(indices); \ | |||
| } | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(SetSubtensor); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(IncrSubtensor); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingMultiAxisVec); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingSetMultiAxisVec); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingIncrMultiAxisVec); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(MeshIndexing); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(IncrMeshIndexing); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(SetMeshIndexing); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing); | |||
| REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing); | |||
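indexing_param_to_json records, for every indexed axis, the axis number plus 0/1 flags for whether begin/end/step/idx inputs are attached, and the macro stamps out that serializer for each subtensor and (batched) mesh-indexing opr. For the Subtensor built in test_subtensor above, the emitted param would presumably look like this (shown as a Python literal for readability):

    [
        {"axis": 0, "begin": 1, "end": 1, "step": 1, "idx": 0},
        {"axis": 1, "begin": 0, "end": 0, "step": 0, "idx": 1},
    ]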
| template <> | |||
| std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>( | |||
| std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>( | |||
| cg::OperatorNodeBase * opr) { | |||
| auto desc = json::Array::make(); | |||
| auto indices = opr->cast_final_safe<opr::Subtensor>().index_desc(); | |||
| for (auto &index : indices){ | |||
| desc->add( | |||
| json::Object::make({ | |||
| {"axis", json::NumberInt::make(index.axis.get_raw())}, | |||
| {"begin", json::NumberInt::make(index.begin.node() != nullptr)}, | |||
| {"end", json::NumberInt::make(index.end.node() != nullptr)}, | |||
| {"step", json::NumberInt::make(index.step.node() != nullptr)}, | |||
| {"idx", json::NumberInt::make(index.idx.node() != nullptr)}, | |||
| })); | |||
| auto axis_param = opr->cast_final_safe<opr::Reshape>().param(); | |||
| if (axis_param.axis != axis_param.MAX_NDIM){ | |||
| return json::Object::make({ | |||
| {"axis", json::NumberInt::make(axis_param.axis)}, | |||
| }); | |||
| } else { | |||
| return json::Object::make(); | |||
| } | |||
| } | |||
| return desc; | |||
| template <> | |||
| std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>( | |||
| cg::OperatorNodeBase * opr) { | |||
| auto desc = json::Array::make(); | |||
| auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param(); | |||
| if (axis_param.axis != axis_param.MAX_NDIM){ | |||
| return json::Object::make({ | |||
| {"axis", json::NumberInt::make(axis_param.axis)}, | |||
| }); | |||
| } else { | |||
| return json::Object::make(); | |||
| } | |||
| } | |||
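Both the Reshape and GetVarShape serializers only expose the optional axis: when param.axis equals the MAX_NDIM sentinel (no axis given) they emit an empty object, otherwise a single-field object. For the GetVarShape(axis=1) traced in test_getvarshape, the output would presumably be

    {"axis": 1}

while a plain data.reshape((1, 8)) carries no axis and would serialize to {}.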
| template <> | |||
| std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>( | |||
| cg::OperatorNodeBase * opr) { | |||
| auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param(); | |||
| return json::Object::make({ | |||
| {"iou_thresh", json::Number::make(nms_param.iou_thresh)}, | |||
| {"max_output", json::Number::make(nms_param.max_output)}, | |||
| }); | |||
| } | |||
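For the NMSKeep opr traced in test_nms (iou_thresh=0.7, max_output=3), this serializer would presumably emit:

    {"iou_thresh": 0.7, "max_output": 3}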
| #endif // MGB_ENABLE_JSON | |||
| } // namespace | |||
| @@ -632,6 +726,17 @@ void OprFootprint::init_all_footprints() { | |||
| add_single_param_json<opr::Dimshuffle>(); | |||
| add_single_param_json<opr::AxisAddRemove>(); | |||
| add_single_param_json<opr::Subtensor>(); | |||
| add_single_param_json<opr::SetSubtensor>(); | |||
| add_single_param_json<opr::IncrSubtensor>(); | |||
| add_single_param_json<opr::IndexingMultiAxisVec>(); | |||
| add_single_param_json<opr::IndexingSetMultiAxisVec>(); | |||
| add_single_param_json<opr::IndexingIncrMultiAxisVec>(); | |||
| add_single_param_json<opr::MeshIndexing>(); | |||
| add_single_param_json<opr::SetMeshIndexing>(); | |||
| add_single_param_json<opr::IncrMeshIndexing>(); | |||
| add_single_param_json<opr::BatchedMeshIndexing>(); | |||
| add_single_param_json<opr::BatchedSetMeshIndexing>(); | |||
| add_single_param_json<opr::BatchedIncrMeshIndexing>(); | |||
| add_single_param_json<opr::Reduce>(); | |||
| add_single_param_json<opr::LocalShareForward>(); | |||
| add_single_param_json<opr::LocalShareBackwardData>(); | |||
| @@ -639,7 +744,31 @@ void OprFootprint::init_all_footprints() { | |||
| add_single_param_json<opr::DeformableConvForward>(); | |||
| add_single_param_json<opr::DeformableConvBackwardFilter>(); | |||
| add_single_param_json<opr::DeformableConvBackwardData>(); | |||
| add_single_param_json<opr::DeformablePSROIPoolingForward>(); | |||
| add_single_param_json<opr::BatchConvBiasForward>(); | |||
| add_single_param_json<opr::BatchNormForward>(); | |||
| add_single_param_json<opr::Reshape>(); | |||
| add_single_param_json<opr::GetVarShape>(); | |||
| add_single_param_json<opr::Argsort>(); | |||
| add_single_param_json<opr::Argmin>(); | |||
| add_single_param_json<opr::Argmax>(); | |||
| add_single_param_json<opr::ElemwiseMultiType>(); | |||
| add_single_param_json<opr::AdaptivePooling>(); | |||
| add_single_param_json<opr::ROIPooling>(); | |||
| add_single_param_json<opr::ROIAlign>(); | |||
| add_single_param_json<opr::WarpPerspective>(); | |||
| add_single_param_json<opr::Remap>(); | |||
| add_single_param_json<opr::Resize>(); | |||
| add_single_param_json<opr::IndexingOneHot>(); | |||
| add_single_param_json<opr::IndexingSetOneHot>(); | |||
| add_single_param_json<opr::WarpAffine>(); | |||
| add_single_param_json<opr::TopK>(); | |||
| add_single_param_json<opr::UniformRNG>(); | |||
| add_single_param_json<opr::GaussianRNG>(); | |||
| add_single_param_json<opr::Linspace>(); | |||
| add_single_param_json<opr::Eye>(); | |||
| add_single_param_json<opr::standalone::NMSKeep>(); | |||
| add_single_param_json<opr::CvtColor>(); | |||
| #endif | |||
| } | |||