GitOrigin-RevId: ff8eb003c5
tags/v1.3.0
| @@ -96,7 +96,7 @@ class Graph(_imperative_rt.ComputingGraph): | |||
| data = data.numpy() | |||
| return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype)) | |||
| def make_const(self, data, dtype=None, device=None): | |||
| def make_const(self, data, dtype=None, device=None, name=None): | |||
| if isinstance(data, _imperative_rt.DeviceTensorND): | |||
| assert dtype is None and device is None | |||
| return self._wrap(_imperative_rt.make_shared(self, data)) | |||
| @@ -107,7 +107,9 @@ class Graph(_imperative_rt.ComputingGraph): | |||
| elif data.dtype == np.int64: | |||
| data = data.astype(np.int32) | |||
| device = as_device(device).to_c() | |||
| return self._wrap(_imperative_rt.make_const(self, data, device, dtype)) | |||
| return self._wrap( | |||
| _imperative_rt.make_const(self, data, device, dtype, name) | |||
| ) | |||
| def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None): | |||
| opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self) | |||
| @@ -305,7 +307,7 @@ def dump_graph( | |||
| output_vars: Union[Dict[str, VarNode], List[VarNode]], | |||
| *, | |||
| keep_var_name: int = 1, | |||
| keep_op_name: bool = True, | |||
| keep_opr_name: bool = False, | |||
| keep_param_name: bool = False, | |||
| keep_opr_priority: bool = False, | |||
| strip_info_file=None, | |||
| @@ -326,7 +328,7 @@ def dump_graph( | |||
| * 0: none of the names are kept | |||
| * 1: (default)keep names of output vars | |||
| * 2: keep names of all (output and internal) vars | |||
| :param keep_op_name: whether to keep operator names. | |||
| :param keep_opr_name: whether to keep operator names. | |||
| :param keep_param_name: whether to keep param names, so param values can be | |||
| easily manipulated after loading model | |||
| :param keep_opr_priority: whether to keep priority setting for operators | |||
| @@ -370,7 +372,7 @@ def dump_graph( | |||
| dump_content = _imperative_rt.dump_graph( | |||
| ov, | |||
| keep_var_name, | |||
| keep_op_name, | |||
| keep_opr_name, | |||
| keep_param_name, | |||
| keep_opr_priority, | |||
| stat, | |||
| @@ -36,6 +36,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph as G | |||
| from ..core.tensor.utils import setscalar | |||
| from ..utils.naming import auto_naming | |||
| from .sublinear_memory_config import SublinearMemoryConfig | |||
| @@ -77,6 +78,7 @@ def exclude_from_trace(): | |||
| class TensorInfo: | |||
| __slots__ = ( | |||
| # collected attributes | |||
| "name", | |||
| "external", | |||
| "data_read", | |||
| "shape_read", | |||
| @@ -96,6 +98,7 @@ class TensorInfo: | |||
| ) | |||
| def __init__(self): | |||
| self.name = None | |||
| self.exported = None | |||
| self.data_read = None | |||
| self.shape_read = None | |||
| @@ -290,12 +293,16 @@ class trace: | |||
| h = getattr(x, "_mixin_handle", -1) | |||
| if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): | |||
| h, info = self._new_handle() | |||
| name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name | |||
| info.name = name | |||
| info.external = True | |||
| info.device = x.device | |||
| info.dtype = x.dtype | |||
| info.shape = x.shape | |||
| if self._capture_as_const: | |||
| info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False) | |||
| info.bound_data = RawTensor( | |||
| x.numpy(), x.dtype, x.device, False, name | |||
| ) | |||
| ihandles.append(h) | |||
| @@ -669,6 +676,12 @@ class trace: | |||
| arg_names=None, | |||
| output_names=None, | |||
| append=False, | |||
| keep_var_name: int = 1, | |||
| keep_opr_name: bool = False, | |||
| keep_param_name: bool = False, | |||
| keep_opr_priority: bool = False, | |||
| strip_info_file=None, | |||
| append_json=False, | |||
| optimize_for_inference=True, | |||
| **kwargs | |||
| ): | |||
| @@ -681,6 +694,20 @@ class trace: | |||
| use the default name if not specified. | |||
| :param append: whether output is appended to ``file``. | |||
| Only works when ``file`` is str. | |||
| :param keep_var_name: level for keeping variable names: | |||
| * 0: none of the names are kept | |||
| * 1: (default)keep names of output vars | |||
| * 2: keep names of all (output and internal) vars | |||
| :param keep_opr_name: whether to keep operator names. | |||
| :param keep_param_name: whether to keep param names, so param values can be | |||
| easily manipulated after loading model | |||
| :param keep_opr_priority: whether to keep priority setting for operators | |||
| :param strip_info_file: a string for path or a file handler. if is not None, | |||
| then the dump information for code strip would be written to ``strip_info_file`` | |||
| :param append_json: will be check when `strip_info_file` is not None. if set | |||
| true, the information for code strip will be append to strip_info_file. | |||
| if set false, will rewrite strip_info_file | |||
| :param optimize_for_inference: enbale optmizations, | |||
| will skip all optimize options if this is False. Default: True | |||
| @@ -785,7 +812,10 @@ class trace: | |||
| assert info.external | |||
| assert info.bound_data | |||
| h2v[h] = graph.make_const( | |||
| info.bound_data.numpy(), dtype=info.dtype, device=info.device, | |||
| info.bound_data.numpy(), | |||
| dtype=info.dtype, | |||
| device=info.device, | |||
| name=info.name, | |||
| ) | |||
| continue | |||
| ivars = [] | |||
| @@ -795,13 +825,26 @@ class trace: | |||
| assert info.external | |||
| assert info.bound_data | |||
| h2v[h] = graph.make_const( | |||
| info.bound_data.numpy(), dtype=info.dtype, device=dumped_device | |||
| info.bound_data.numpy(), | |||
| dtype=info.dtype, | |||
| device=dumped_device, | |||
| name=info.name, | |||
| ) | |||
| ivars.append(h2v[h]) | |||
| ovars = G.apply_normal_varnode(op, *ivars) | |||
| auto_naming.record_opnode(ovars[0].op) | |||
| assert len(ovars) == len(ohandles) | |||
| h2v.update(zip(ohandles, ovars)) | |||
| for i in ohandles: | |||
| name = auto_naming.get_var_name(i) | |||
| if name is not None: | |||
| h2v[i].name = name | |||
| auto_naming.remove_duplicate_names() | |||
| dest_vars = [] | |||
| for i, h in enumerate(self._output_bindings): | |||
| v = h2v[h] | |||
| @@ -815,7 +858,15 @@ class trace: | |||
| if isinstance(file, str): | |||
| permission = "wb" if append == False else "ab" | |||
| file = open(file, permission) | |||
| dump_content, dump_info = G.dump_graph(dest_vars) | |||
| dump_content, dump_info = G.dump_graph( | |||
| dest_vars, | |||
| keep_var_name=keep_var_name, | |||
| keep_opr_name=keep_opr_name, | |||
| keep_param_name=keep_param_name, | |||
| keep_opr_priority=keep_opr_priority, | |||
| strip_info_file=strip_info_file, | |||
| append_json=append_json, | |||
| ) | |||
| file.write(dump_content) | |||
| return dump_info | |||
| @@ -1095,20 +1146,22 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor): | |||
| return active_trace._apply_op(op, args) | |||
| def apply_const_compiled_mode(value, dtype, device, is_const, no_cache): | |||
| def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): | |||
| if skip_tracing: | |||
| args = [ | |||
| RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
| for x in args | |||
| ] | |||
| unset_tracing() | |||
| ret = RawTensor(value, dtype, device, False) | |||
| ret = RawTensor(value, dtype, device, False, name) | |||
| set_tracing() | |||
| return ret | |||
| return active_trace._apply_const(value, dtype, device) | |||
| def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
| if hasattr(op, "scope"): | |||
| op.scope = auto_naming.get_scope() | |||
| if active_trace._symbolic: | |||
| outputs = apply_symbolic_mode(op, *args) | |||
| else: | |||
| @@ -1120,12 +1173,12 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
| return list(outputs) | |||
| def apply_const_with_tracing(value, dtype, device, is_const, no_cache): | |||
| def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name): | |||
| if active_trace._symbolic: | |||
| outputs = apply_const_symbolic_mode(value, dtype, device) | |||
| else: | |||
| unset_tracing() | |||
| outputs = (RawTensor(value, dtype, device, False),) | |||
| outputs = (RawTensor(value, dtype, device, False, name),) | |||
| set_tracing() | |||
| active_trace._record_const(outputs) | |||
| return list(outputs) | |||
| @@ -12,12 +12,12 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import pop_scope, push_scope | |||
| from ..core.tensor.utils import make_shape_tuple | |||
| from ..logger import get_logger | |||
| from ..tensor import Parameter, Tensor | |||
| from ..utils.deprecation import deprecated | |||
| from ..utils.hook import HookHandler | |||
| from ..utils.naming import auto_naming | |||
| logger = get_logger(__name__) | |||
| @@ -69,7 +69,9 @@ class Module(metaclass=ABCMeta): | |||
| Base Module class. | |||
| """ | |||
| def __init__(self): | |||
| def __init__(self, name=""): | |||
| self.name = name | |||
| # runtime attributes | |||
| self.training = True | |||
| self.quantize_disabled = False | |||
| @@ -79,6 +81,8 @@ class Module(metaclass=ABCMeta): | |||
| self._forward_hooks = OrderedDict() | |||
| self._modules = [] | |||
| # used for profiler and automatic naming | |||
| self._name = "{anonymous}" | |||
| @abstractmethod | |||
| @@ -105,7 +109,7 @@ class Module(metaclass=ABCMeta): | |||
| return HookHandler(self._forward_hooks, hook) | |||
| def __call__(self, *inputs, **kwargs): | |||
| push_scope(self._name) | |||
| auto_naming.push_scope(self.name if self.name else self._name) | |||
| for hook in self._forward_pre_hooks.values(): | |||
| modified_inputs = hook(self, inputs) | |||
| if modified_inputs is not None: | |||
| @@ -119,7 +123,7 @@ class Module(metaclass=ABCMeta): | |||
| modified_outputs = hook(self, inputs, outputs) | |||
| if modified_outputs is not None: | |||
| outputs = modified_outputs | |||
| pop_scope(self._name) | |||
| auto_naming.pop_scope() | |||
| return outputs | |||
| def _flatten( | |||
| @@ -579,7 +583,7 @@ class Module(metaclass=ABCMeta): | |||
| value = super().__getattribute__(name) | |||
| if name == "_name": | |||
| return value | |||
| if _is_module(value): | |||
| if isinstance(value, (Tensor, Module)): | |||
| value._name = name | |||
| return value | |||
| @@ -20,6 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin | |||
| from .device import _valid_device, get_default_device | |||
| from .logger import get_logger | |||
| from .utils.deprecation import deprecated | |||
| from .utils.naming import auto_naming | |||
| class Tensor(_Tensor, ArrayMethodMixin): | |||
| @@ -27,7 +28,9 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| dmap_callback = None | |||
| _q_dict = None | |||
| def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False): | |||
| def __new__( | |||
| cls, data, dtype=None, device=None, is_const=False, no_cache=False, name="" | |||
| ): | |||
| if device is None: | |||
| cn = get_default_device() | |||
| elif isinstance(device, str): | |||
| @@ -51,8 +54,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| if isinstance(data, np.ndarray): | |||
| if 0 in data.strides: | |||
| data = data.squeeze().reshape(data.shape) | |||
| obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache) | |||
| obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name) | |||
| return obj | |||
| @property | |||
| @@ -91,6 +93,15 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| piece += ", device={}".format(self.device) + ")" | |||
| return piece | |||
| @property | |||
| def name(self): | |||
| return self.c_name | |||
| @name.setter | |||
| def name(self, name): | |||
| self.c_name = name | |||
| auto_naming.record_var_name(self._mixin_handle, name) | |||
| @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | |||
| def set_value(self, value): | |||
| if not isinstance(value, _Tensor): | |||
| @@ -0,0 +1,63 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..core._imperative_rt.core2 import pop_scope, push_scope | |||
| class AutoNaming: | |||
| r""" | |||
| Name all executed operators automaticlly during tracing and record all tensors | |||
| renamed by the user. | |||
| """ | |||
| def __init__(self): | |||
| self.scopes = [] | |||
| self.c_ops = [] | |||
| self.name2ops = {} | |||
| self.handle2names = {} | |||
| def clear(self): | |||
| for var in vars(self).values(): | |||
| var.clear() | |||
| def push_scope(self, scope): | |||
| push_scope(scope) | |||
| self.scopes.append(scope) | |||
| def pop_scope(self): | |||
| scope = self.scopes.pop() | |||
| pop_scope(scope) | |||
| def get_scope(self): | |||
| return ".".join(self.scopes) | |||
| def record_var_name(self, handle, name): | |||
| self.handle2names[handle] = name | |||
| def get_var_name(self, handle): | |||
| return self.handle2names.pop(handle, None) | |||
| def record_opnode(self, op): | |||
| ops = self.name2ops.get(op.name, []) | |||
| ops.append(op) | |||
| self.name2ops[op.name] = ops | |||
| def remove_duplicate_names(self): | |||
| for key, ops in self.name2ops.items(): | |||
| if len(ops) == 1: | |||
| continue | |||
| for i, op in enumerate(ops): | |||
| op.name = key + "[%s]" % str(i) | |||
| if len(op.outputs) == 1: | |||
| continue | |||
| for var in op.outputs: | |||
| var.name = var.name.replace(key, op.name) | |||
| self.name2ops.clear() | |||
| auto_naming = AutoNaming() | |||
| @@ -294,7 +294,7 @@ void init_graph_rt(py::module m) { | |||
| m.def("dump_graph", []( | |||
| const std::vector<VarNode*>& dest_vars, | |||
| int keep_var_name, | |||
| bool keep_op_name, | |||
| bool keep_opr_name, | |||
| bool keep_param_name, | |||
| bool keep_opr_priority, | |||
| py::list& stat, | |||
| @@ -307,7 +307,7 @@ void init_graph_rt(py::module m) { | |||
| SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | |||
| ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, | |||
| keep_opr_priority, keep_op_name}; | |||
| keep_opr_priority, keep_opr_name}; | |||
| auto rst = dumper->dump(symvars, config); | |||
| for (auto i : rst.inputs) { | |||
| @@ -457,13 +457,17 @@ void init_graph_rt(py::module m) { | |||
| return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node(); | |||
| }); | |||
| m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { | |||
| m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype, std::optional<std::string> name) { | |||
| if (!cn.valid()) { | |||
| cn = CompNode::load(get_default_device()); | |||
| } | |||
| OperatorNodeConfig config(cn); | |||
| if (name) { | |||
| config.name(*name); | |||
| } | |||
| auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | |||
| return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); | |||
| }); | |||
| return opr::ImmutableTensor::make(*graph, hv, config).node(); | |||
| }, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none()); | |||
| m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) { | |||
| if (!cn.valid()) { | |||
| @@ -99,6 +99,14 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { | |||
| #define py_get_generic(name, attr) \ | |||
| py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||
| template<typename T> | |||
| PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) { | |||
| // T: PyOpXXX inst(): return XXX in opdef.h.inl | |||
| auto& op = reinterpret_cast<T*>(obj)->inst(); | |||
| return pyobj_convert_generic<std::string>::to(op.scope()); | |||
| } | |||
| #define py_get_scope(class) py_get_scope_impl<PyOp(class)> | |||
| template<typename T, typename U, U T::Ty::*attr> | |||
| int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||
| if (value == NULL) { | |||
| @@ -121,6 +129,27 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||
| #define py_set_generic(name, attr) \ | |||
| py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||
| template<typename T> | |||
| int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||
| if (value == NULL) { | |||
| PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute"); | |||
| return -1; | |||
| } | |||
| auto& op = reinterpret_cast<T*>(obj)->inst(); | |||
| try { | |||
| op.set_scope(pyobj_convert_generic<std::string>::from(value)); | |||
| return 0; | |||
| } catch(py::error_already_set& e) { | |||
| e.restore(); | |||
| } catch(py::builtin_exception& e) { | |||
| e.set_error(); | |||
| } catch(...) { | |||
| PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||
| } | |||
| return -1; | |||
| } | |||
| #define py_set_scope(class) py_set_scope_impl<PyOp(class)> | |||
| struct PyOpDef { | |||
| PyObject_HEAD | |||
| std::shared_ptr<OpDef> op; | |||
| @@ -24,6 +24,7 @@ | |||
| #include <pybind11/numpy.h> | |||
| #include <pybind11/operators.h> | |||
| #include <range/v3/all.hpp> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| @@ -222,14 +223,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| } | |||
| } else { | |||
| py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType | |||
| if (nargs != 4 && nargs != 5) { | |||
| throw py::type_error("expect 4 or 5 arguments"); | |||
| if (nargs != 5 && nargs != 6) { | |||
| throw py::type_error("expect 5 or 6 arguments"); | |||
| } | |||
| auto data = tup[0].cast<py::array>(); | |||
| DType dtype = tup[1].cast<DType>(); | |||
| CompNode cn = tup[2].cast<CompNode>(); | |||
| bool is_const = tup[3].cast<bool>(); | |||
| bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false; | |||
| bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false; | |||
| std::string name = tup[nargs - 1].cast<std::string>(); | |||
| // const op | |||
| if (is_const && is_tracing) { | |||
| @@ -259,6 +261,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| } | |||
| m_tensor = std::make_shared<Tensor>(handle); | |||
| m_tensor->user_custom_name = name; | |||
| if (data.ndim() == 0) { | |||
| m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||
| @@ -313,6 +316,19 @@ REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info) | |||
| #undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC | |||
| #define SET_GET_NAME(member) \ | |||
| PyObject* TensorWrapper::member() { \ | |||
| return py::cast(m_tensor->member).release().ptr(); \ | |||
| } \ | |||
| void TensorWrapper::set_##member(PyObject* dest) { \ | |||
| auto py_dest = py::reinterpret_borrow<py::object>(dest); \ | |||
| m_tensor->member = py_dest.cast<std::string>(); \ | |||
| } | |||
| SET_GET_NAME(user_custom_name) | |||
| SET_GET_NAME(automatic_name) | |||
| #undef SET_GET_NAME | |||
| PyObject* TensorWrapper::handle() { | |||
| return py::cast(m_tensor->m_handle).release().ptr(); | |||
| } | |||
| @@ -453,7 +469,11 @@ void TensorWrapper::reset(PyObject* tensor) { | |||
| if (!t) { | |||
| throw py::type_error("expect Tensor"); | |||
| } | |||
| std::string user_custom_name = m_tensor->user_custom_name; | |||
| std::string automatic_name = m_tensor->automatic_name; | |||
| m_tensor = t->m_tensor; | |||
| m_tensor->user_custom_name = user_custom_name; | |||
| m_tensor->automatic_name = automatic_name; | |||
| } | |||
| void TensorWrapper::reset_varnode() { | |||
| @@ -785,6 +805,8 @@ void init_tensor(py::module m) { | |||
| .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") | |||
| .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info") | |||
| .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info") | |||
| .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name") | |||
| .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name") | |||
| .finalize(); | |||
| if (!tensor_type) throw py::error_already_set(); | |||
| py::setattr(m, "Tensor", tensor_type); | |||
| @@ -15,6 +15,7 @@ | |||
| #include "megbrain/imperative/interpreter.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include <string> | |||
| #include "./pyext17.h" | |||
| @@ -70,6 +71,8 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
| GradInfo m_grad_info; | |||
| TraceInfo m_trace_info; | |||
| SharedHandle m_handle; | |||
| std::string user_custom_name; | |||
| std::string automatic_name; | |||
| cg::VarNode* m_var; | |||
| using Handle = interpreter::Interpreter::Handle; | |||
| @@ -170,6 +173,10 @@ struct TensorWrapper { | |||
| void set_compiled_info(PyObject *); | |||
| PyObject* trace_mixin_info(); | |||
| void set_trace_mixin_info(PyObject *); | |||
| PyObject* user_custom_name(); | |||
| void set_user_custom_name(PyObject *); | |||
| PyObject* automatic_name(); | |||
| void set_automatic_name(PyObject *); | |||
| PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | |||
| }; | |||
| @@ -0,0 +1,169 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import io | |||
| import numpy as np | |||
| import pytest | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine import Parameter, Tensor | |||
| from megengine.core.tensor import megbrain_graph as G | |||
| from megengine.jit.tracing import trace | |||
| from megengine.utils.naming import auto_naming | |||
| def _dump_and_load(func, symbolic, keep_opr_name=True): | |||
| auto_naming.clear() | |||
| func = trace(func, symbolic=symbolic, capture_as_const=True) | |||
| x = Tensor(np.ones(shape=(2, 3))) | |||
| func(x).numpy() | |||
| file = io.BytesIO() | |||
| func.dump( | |||
| file, | |||
| optimize_for_inference=False, | |||
| arg_names="x", | |||
| keep_opr_name=keep_opr_name, | |||
| keep_var_name=2, | |||
| ) | |||
| file.seek(0) | |||
| *_, outputs = G.load_graph(file) | |||
| op = cgtools.get_oprs_seq(outputs)[-1] | |||
| return op | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_auto_naming(symbolic): | |||
| class Simple(M.Module): | |||
| def __init__(self, name): | |||
| super().__init__() | |||
| self.name = name | |||
| def forward(self, x): | |||
| return x + x | |||
| m = Simple("simple") | |||
| op = _dump_and_load(m, symbolic) | |||
| assert op.name == "simple.ADD" | |||
| assert op.outputs[0].name == "simple.ADD" | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_user_named_tensor(symbolic): | |||
| class Simple(M.Module): | |||
| def __init__(self, name): | |||
| super().__init__() | |||
| self.name = name | |||
| self.k = Parameter(1.0, name="k") | |||
| def forward(self, x): | |||
| x = x + x | |||
| x.name = "o_x" | |||
| return x | |||
| m = Simple("simple") | |||
| op = _dump_and_load(m, symbolic) | |||
| assert op.name == "simple.ADD" | |||
| assert op.outputs[0].name == "o_x" | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_user_named_param(symbolic): | |||
| class Simple(M.Module): | |||
| def __init__(self, name): | |||
| super().__init__() | |||
| self.name = name | |||
| self.k = Parameter(2.0, name="k") | |||
| def forward(self, x): | |||
| return self.k * x | |||
| m = Simple("simple") | |||
| op = _dump_and_load(m, symbolic) | |||
| assert op.inputs[0].name == "x" | |||
| assert op.inputs[1].name == "simple.k" | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_without_module(symbolic): | |||
| def f(x): | |||
| return 2 * x | |||
| op = _dump_and_load(f, symbolic) | |||
| assert op.name == "MUL" | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_with_submodule(symbolic): | |||
| class Simple(M.Module): | |||
| def __init__(self, name): | |||
| super().__init__() | |||
| self.name = name | |||
| self.linear = M.Linear(3, 3) | |||
| def forward(self, x): | |||
| x = self.linear(x) | |||
| return x | |||
| m = Simple("simple") | |||
| op = _dump_and_load(m, symbolic) | |||
| assert op.name == "simple.linear.ADD" | |||
| assert op.inputs[0].owner.name == "simple.linear.MatrixMul" | |||
| assert op.outputs[0].name == "simple.linear.ADD" | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_named_submodule(symbolic): | |||
| class Simple(M.Module): | |||
| def __init__(self, name): | |||
| super().__init__() | |||
| self.name = name | |||
| self.linear = M.Linear(3, 3, name="x") | |||
| def forward(self, x): | |||
| x = self.linear(x) | |||
| return x | |||
| m = Simple("simple") | |||
| op = _dump_and_load(m, symbolic) | |||
| assert op.name == "simple.x.ADD" | |||
| assert op.inputs[0].owner.name == "simple.x.MatrixMul" | |||
| assert op.outputs[0].name == "simple.x.ADD" | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_with_same_operators(symbolic): | |||
| class Simple(M.Module): | |||
| def __init__(self, name): | |||
| super().__init__() | |||
| self.name = name | |||
| def forward(self, x): | |||
| x = F.relu(x) | |||
| x = F.relu(x) | |||
| return x | |||
| m = Simple("simple") | |||
| op = _dump_and_load(m, symbolic) | |||
| assert op.name == "simple.RELU[1]" | |||
| assert op.inputs[0].owner.name == "simple.RELU[0]" | |||
| def test_not_keep_opr_name(): | |||
| def f(x): | |||
| return 2 * x | |||
| op = _dump_and_load(f, True, False) | |||
| assert op.name == "MUL(x,2[2])[4]" | |||
| @@ -148,7 +148,7 @@ def test_dump(): | |||
| dump_info = f.dump(file) | |||
| assert dump_info.nr_opr == 3 | |||
| np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) | |||
| np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) | |||
| np.testing.assert_equal(dump_info.outputs, ["ADD"]) | |||
| file.seek(0) | |||
| infer_cg = cgtools.GraphInference(file) | |||
| result = list((infer_cg.run(a, b)).values())[0] | |||
| @@ -75,10 +75,6 @@ std::vector<std::pair<const char*, std::string>> OpDef::props( | |||
| return def.trait()->props(def); | |||
| } | |||
| const char* OpDef::name() const { | |||
| return trait()->name; | |||
| } | |||
| std::string OpDef::to_string() const { | |||
| std::string builder = "{"; | |||
| for (auto&& [name, value]: props(*this)) { | |||
| @@ -107,6 +103,20 @@ const OpTrait* OpDef::trait() const { | |||
| return m_trait; | |||
| } | |||
| const std::string OpDef::scope() const { | |||
| return m_scope; | |||
| } | |||
| void OpDef::set_scope(const std::string& scope) { | |||
| m_scope = scope; | |||
| } | |||
| const std::string OpDef::make_name() const { | |||
| if (m_scope.empty()) | |||
| return trait()->make_name(*this); | |||
| return m_scope + "." + trait()->make_name(*this); | |||
| } | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -75,6 +75,7 @@ using GradMaker = detail::OpMeth< | |||
| using Props = detail::OpMeth<decltype(OpDef::props)>; | |||
| using HashFunc = detail::OpMeth<size_t(const OpDef&)>; | |||
| using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | |||
| using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>; | |||
| struct OpTrait { | |||
| const char* name; | |||
| @@ -88,6 +89,7 @@ struct OpTrait { | |||
| Props props; | |||
| HashFunc hash; | |||
| IsSame is_same_st; | |||
| MakeNameFunc make_name; | |||
| OpTrait(const char* name); | |||
| static OpTrait* find_by_name(const char* name); | |||
| static OpTrait* find_by_typeinfo(Typeinfo* type); | |||
| @@ -104,7 +106,8 @@ struct OpTrait { | |||
| cb(make_backward_graph) \ | |||
| cb(props) \ | |||
| cb(hash) \ | |||
| cb(is_same_st) | |||
| cb(is_same_st) \ | |||
| cb(make_name) | |||
| struct OpTraitRegistry { | |||
| OpTrait* trait; | |||
| @@ -30,13 +30,14 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 3 ||nr_inp == 5, | |||
| "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); | |||
| OperatorNodeConfig config{bn_opr.make_name()}; | |||
| if (nr_inp == 3) { | |||
| return opr::BatchNorm::make( | |||
| inputs[0], inputs[1], inputs[2], bn_opr.param())[0] | |||
| inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0] | |||
| .node()->owner_opr(); | |||
| } else { | |||
| return opr::BatchNorm::make( | |||
| inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0] | |||
| inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param(), config)[0] | |||
| .node()->owner_opr(); | |||
| } | |||
| } | |||
| @@ -27,10 +27,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| def.cast_final_safe<Broadcast>(); | |||
| auto&& op = def.cast_final_safe<Broadcast>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
| return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr(); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); | |||
| } | |||
| bool valid_broadcast(const TensorShape& src_shape, | |||
| @@ -96,7 +97,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Reshape&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Reshape::make(inputs[0], inputs[1], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Reshape::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| @@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| auto disable = std::make_shared<DTypeScalar>(); | |||
| disable->set(0); | |||
| cg::OperatorNodeConfig config; | |||
| OperatorNodeConfig config{comm.make_name()}; | |||
| if (comm.comp_node.size() > 0) { | |||
| config.comp_node(CompNode::load(comm.comp_node)); | |||
| } | |||
| @@ -23,12 +23,12 @@ namespace { | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| def.cast_final_safe<CondTake>(); | |||
| auto&& op = def.cast_final_safe<CondTake>(); | |||
| auto&& graph = inputs[0]->owner_graph(); | |||
| opr::CondTake::Param param; | |||
| param.val = 1; | |||
| cg::OperatorNodeConfig config; | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| cg::OperatorNodeBase* opr = graph->insert_opr( | |||
| std::make_unique<opr::CondTake>( | |||
| inputs[0], inputs[1], param, config)); | |||
| @@ -31,7 +31,8 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& elemwise_opr = def.cast_final_safe<Elemwise>(); | |||
| return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr(); | |||
| OperatorNodeConfig config{elemwise_opr.make_name()}; | |||
| return opr::Elemwise::make(inputs, elemwise_opr.mode, config).node()->owner_opr(); | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| @@ -23,7 +23,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const CvtColor&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::CvtColor::make(inputs[0], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::CvtColor::make(inputs[0], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(CvtColor, CvtColor) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -32,7 +32,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send( | |||
| ssprintf("%s:%d", send.addr.data(), send.port)); | |||
| auto&& graph = inputs[0]->owner_graph(); | |||
| cg::OperatorNodeConfig config; | |||
| OperatorNodeConfig config{send.make_name()}; | |||
| cg::OperatorNodeBase* opr = | |||
| graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>( | |||
| send.key, inputs[0], group_client, true, config)); | |||
| @@ -42,11 +42,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send( | |||
| cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||
| const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& recv = def.cast_final_safe<RemoteRecv>(); | |||
| OperatorNodeConfig config{recv.cn}; | |||
| config.name(recv.make_name()); | |||
| auto group_client = std::make_shared<GroupClientProxy>( | |||
| ssprintf("%s:%d", recv.addr.data(), recv.port)); | |||
| auto&& graph = inputs[0]->owner_graph(); | |||
| return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | |||
| recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn}, | |||
| recv.key, inputs[0], *graph, group_client, config, | |||
| recv.shape, recv.dtype)); | |||
| } | |||
| @@ -21,8 +21,10 @@ namespace { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = def.cast_final_safe<MatrixInverse>(); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::MatrixInverse::make(inputs[0]); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::MatrixInverse::make(inputs[0], {}, config); | |||
| } | |||
| OP_TRAIT_REG(MatrixInverse, MatrixInverse) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -29,7 +29,9 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| param.iou_thresh = nms_keep.iou_thresh; | |||
| param.max_output = nms_keep.max_output; | |||
| return NMSKeepOpr::make(inputs[0], param).node()->owner_opr(); | |||
| OperatorNodeConfig config{nms_keep.make_name()}; | |||
| return NMSKeepOpr::make(inputs[0], param, config).node()->owner_opr(); | |||
| } | |||
| OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr) | |||
| @@ -79,11 +79,13 @@ public: | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& attr = def.cast_final_safe<OprAttr>(); | |||
| auto config = attr.config; | |||
| config.name(attr.make_name()); | |||
| mgb_assert(!inputs.empty()); | |||
| auto registry = serialization::OprRegistry::find_by_name(attr.type); | |||
| mgb_assert(registry, "operator %s not found", attr.type.c_str()); | |||
| OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; | |||
| return registry->loader(ctx, inputs, attr.config); | |||
| return registry->loader(ctx, inputs, config); | |||
| } | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | |||
| @@ -99,10 +101,15 @@ std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { | |||
| return {}; | |||
| } | |||
| std::string make_name(const OpDef& def) { | |||
| return "OprAttr"; | |||
| } | |||
| OP_TRAIT_REG(OprAttr, OprAttr) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .props(props) | |||
| .make_name(make_name) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| @@ -24,7 +24,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Resize&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Resize::make(inputs[0], inputs[1], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Resize::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Resize, Resize) | |||
| @@ -46,7 +46,8 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const Convolution&>(def); | |||
| return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy()); | |||
| OperatorNodeConfig config{conv.make_name()}; | |||
| return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } | |||
| OP_TRAIT_REG(Convolution, Convolution, opr::Convolution) | |||
| @@ -60,7 +61,7 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const ConvolutionBackwardData&>(def); | |||
| cg::OperatorNodeConfig config; | |||
| OperatorNodeConfig config{conv.make_name()}; | |||
| if (inputs.size() == 2) { | |||
| return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } else { | |||
| @@ -88,7 +89,8 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& ds = static_cast<const Dimshuffle&>(def); | |||
| return opr::Dimshuffle::make(inputs[0], ds.pattern); | |||
| OperatorNodeConfig config{ds.make_name()}; | |||
| return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config); | |||
| } | |||
| OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) | |||
| @@ -107,7 +109,8 @@ auto apply_on_var_node( | |||
| for (auto&& i : add_axis.axis) { | |||
| param.push_back(Desc::make_add(i)); | |||
| } | |||
| return opr::AxisAddRemove::make(inputs[0], param); | |||
| OperatorNodeConfig config{add_axis.make_name()}; | |||
| return opr::AxisAddRemove::make(inputs[0], param, config); | |||
| } | |||
| OP_TRAIT_REG(AddAxis, AddAxis) | |||
| @@ -125,7 +128,8 @@ auto apply_on_var_node( | |||
| for (auto&& i : remove_axis.axis) { | |||
| param.push_back(Desc::make_remove(i)); | |||
| } | |||
| return opr::AxisAddRemove::make(inputs[0], param); | |||
| OperatorNodeConfig config{remove_axis.make_name()}; | |||
| return opr::AxisAddRemove::make(inputs[0], param, config); | |||
| } | |||
| OP_TRAIT_REG(RemoveAxis, RemoveAxis) | |||
| @@ -138,7 +142,8 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& topk = static_cast<const TopK&>(def); | |||
| return opr::TopK::make(inputs[0], inputs[1], topk.param())[0] | |||
| OperatorNodeConfig config{topk.make_name()}; | |||
| return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0] | |||
| .node()->owner_opr(); | |||
| } | |||
| @@ -152,10 +157,12 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& reduce = static_cast<const Reduce&>(def); | |||
| OperatorNodeConfig config{reduce.make_name()}; | |||
| if (inputs.size() > 1) { | |||
| return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]); | |||
| return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); | |||
| } else { | |||
| return opr::Reduce::make(inputs[0], reduce.param()); | |||
| return opr::Reduce::make( | |||
| inputs[0], reduce.param(), (cg::VarNode*)nullptr, config); | |||
| } | |||
| } | |||
| @@ -175,7 +182,8 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& pool = static_cast<const AdaptivePooling&>(def); | |||
| return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param()); | |||
| OperatorNodeConfig config{pool.make_name()}; | |||
| return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config); | |||
| } | |||
| OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) | |||
| @@ -189,6 +197,7 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const ConvBias&>(def); | |||
| cg::OperatorNodeConfig config{conv.dtype}; | |||
| config.name(conv.make_name()); | |||
| if (inputs.size() == 2) { | |||
| return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } else if (inputs.size() == 3) { | |||
| @@ -210,6 +219,7 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const BatchConvBias&>(def); | |||
| cg::OperatorNodeConfig config{conv.dtype}; | |||
| config.name(conv.make_name()); | |||
| if (inputs.size() == 2) { | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } else if (inputs.size() == 3) { | |||
| @@ -230,7 +240,8 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& pool = static_cast<const Pooling&>(def); | |||
| return opr::Pooling::make(inputs[0], pool.param()); | |||
| OperatorNodeConfig config{pool.make_name()}; | |||
| return opr::Pooling::make(inputs[0], pool.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Pooling, Pooling) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -243,8 +254,9 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& matmul = static_cast<const MatrixMul&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{matmul.make_name()}; | |||
| return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(), | |||
| matmul.policy()); | |||
| matmul.policy(), config); | |||
| } | |||
| OP_TRAIT_REG(MatrixMul, MatrixMul) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -257,8 +269,9 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{matmul.make_name()}; | |||
| return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(), | |||
| matmul.policy()); | |||
| matmul.policy(), config); | |||
| } | |||
| OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -267,10 +280,12 @@ OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||
| namespace { namespace dot { | |||
| auto apply_on_var_node( | |||
| const OpDef&, | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = def.cast_final_safe<Dot>(); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Dot::make(inputs[0], inputs[1]); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Dot::make(inputs[0], inputs[1], config); | |||
| } | |||
| OP_TRAIT_REG(Dot, Dot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -282,7 +297,8 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& argsort = static_cast<const Argsort&>(def); | |||
| return opr::Argsort::make(inputs[0], argsort.param()); | |||
| OperatorNodeConfig config{argsort.make_name()}; | |||
| return opr::Argsort::make(inputs[0], argsort.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Argsort, Argsort) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -294,7 +310,8 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& argmax = static_cast<const Argmax&>(def); | |||
| return opr::Argmax::make(inputs[0], argmax.param()); | |||
| OperatorNodeConfig config{argmax.make_name()}; | |||
| return opr::Argmax::make(inputs[0], argmax.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Argmax, Argmax) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -306,7 +323,8 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& argmin = static_cast<const Argmin&>(def); | |||
| return opr::Argmin::make(inputs[0], argmin.param()); | |||
| OperatorNodeConfig config{argmin.make_name()}; | |||
| return opr::Argmin::make(inputs[0], argmin.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Argmin, Argmin) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -318,11 +336,13 @@ auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& warp = static_cast<const WarpPerspective&>(def); | |||
| OperatorNodeConfig config{warp.make_name()}; | |||
| if (inputs.size() == 3) { | |||
| return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param()); | |||
| return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config); | |||
| } else { | |||
| mgb_assert(inputs.size() == 4); | |||
| return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param()); | |||
| return opr::WarpPerspective::make( | |||
| inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(WarpPerspective, WarpPerspective) | |||
| @@ -336,7 +356,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& local = static_cast<const GroupLocal&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::GroupLocal::make(inputs[0], inputs[1], local.param()); | |||
| OperatorNodeConfig config{local.make_name()}; | |||
| return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config); | |||
| } | |||
| OP_TRAIT_REG(GroupLocal, GroupLocal) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -349,7 +370,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const IndexingOneHot&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -362,7 +384,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const IndexingSetOneHot&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -375,7 +398,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const TypeCvt&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::TypeCvt::make(inputs[0], op.dtype); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::TypeCvt::make(inputs[0], op.dtype, config); | |||
| } | |||
| OP_TRAIT_REG(TypeCvt, TypeCvt) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -388,6 +412,7 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Concat&>(def); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| config.name(op.make_name()); | |||
| return opr::Concat::make(inputs, op.axis, config); | |||
| } | |||
| OP_TRAIT_REG(Concat, Concat) | |||
| @@ -402,6 +427,7 @@ auto apply_on_var_node( | |||
| auto&& op = static_cast<const Copy&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| config.name(op.make_name()); | |||
| return opr::Copy::make(inputs[0], config); | |||
| } | |||
| OP_TRAIT_REG(Copy, Copy) | |||
| @@ -411,10 +437,12 @@ OP_TRAIT_REG(Copy, Copy) | |||
| namespace { namespace identity { | |||
| auto apply_on_var_node( | |||
| const OpDef&, | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = def.cast_final_safe<Identity>(); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::Identity::make(inputs[0]); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Identity::make(inputs[0], config); | |||
| } | |||
| OP_TRAIT_REG(Identity, Identity) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -427,7 +455,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const AssertEqual&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::AssertEqual::make(inputs[0],inputs[1],op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| @@ -443,7 +472,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const UniformRNG&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::UniformRNG::make(inputs[0], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::UniformRNG::make(inputs[0], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(UniformRNG, UniformRNG) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -456,7 +486,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const GaussianRNG&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::GaussianRNG::make(inputs[0], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::GaussianRNG::make(inputs[0], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(GaussianRNG, GaussianRNG) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -469,7 +500,9 @@ VarNodeArray apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ROIAlign&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr(); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| auto* opr = opr::ROIAlign::make( | |||
| inputs[0], inputs[1], op.param(), config).node()->owner_opr(); | |||
| return {opr->output(0), opr->output(1)}; | |||
| } | |||
| OP_TRAIT_REG(ROIAlign, ROIAlign) | |||
| @@ -484,7 +517,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const NvOf&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::NvOf::make(inputs[0], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::NvOf::make(inputs[0], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(NvOf, NvOf) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -499,6 +533,7 @@ auto apply_on_var_node( | |||
| auto&& op = static_cast<const Linspace&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| config.name(op.make_name()); | |||
| return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Linspace, Linspace) | |||
| @@ -513,6 +548,7 @@ auto apply_on_var_node( | |||
| auto&& op = static_cast<const Eye&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| config.name(op.make_name()); | |||
| opr::Eye::Param param{op.k, op.dtype.enumv()}; | |||
| return opr::Eye::make(inputs[0], param, config); | |||
| } | |||
| @@ -527,7 +563,10 @@ VarNodeArray apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ROIPooling&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr(); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| auto* opr = opr::ROIPooling::make( | |||
| inputs[0], inputs[1], inputs[2], op.param(), config | |||
| ).node()->owner_opr(); | |||
| return {opr->output(0), opr->output(1)}; | |||
| } | |||
| OP_TRAIT_REG(ROIPooling, ROIPooling) | |||
| @@ -541,7 +580,8 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Remap&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Remap::make(inputs[0], inputs[1], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Remap::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Remap, Remap) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -578,7 +618,8 @@ auto apply_on_var_node( \ | |||
| const OpDef& def, \ | |||
| const VarNodeArray& inputs) { \ | |||
| auto&& op = static_cast<const NAME&>(def); \ | |||
| return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \ | |||
| OperatorNodeConfig config{op.make_name()}; \ | |||
| return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \ | |||
| } \ | |||
| OP_TRAIT_REG(NAME, NAME) \ | |||
| .apply_on_var_node(apply_on_var_node) \ | |||
| @@ -609,30 +650,35 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const FakeQuant&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(FakeQuant, FakeQuant) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // fake_quant | |||
| namespace { namespace tqt { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const TQT&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::TQT::make(inputs[0], inputs[1], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::TQT::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(TQT, TQT) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // tqt | |||
| namespace { namespace elemwise_multi_type { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ElemwiseMultiType&>(def); | |||
| OperatorNodeConfig config{op.dtype}; | |||
| config.name(op.make_name()); | |||
| return opr::ElemwiseMultiType::make(inputs, op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) | |||
| @@ -646,7 +692,9 @@ auto apply_on_var_node( | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const SVD&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output(); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::SVD::make(inputs[0], op.param(), config)[0] | |||
| .node()->owner_opr()->usable_output(); | |||
| } | |||
| OP_TRAIT_REG(SVD, SVD) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -21,7 +21,8 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op_def = def.cast_final_safe<GetVarShape>(); | |||
| return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr(); | |||
| OperatorNodeConfig config{op_def.make_name()}; | |||
| return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr(); | |||
| } | |||
| DispatchMode decide_dispatch_mode( | |||
| @@ -152,7 +153,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node( | |||
| auto&& graph = inputs[0]->owner_graph(); | |||
| auto&& shapes = get_shapes(param.shapes); | |||
| cg::OperatorNodeConfig config; | |||
| OperatorNodeConfig config(param.make_name()); | |||
| cg::OperatorNodeBase* opr = | |||
| graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>( | |||
| inputs[0], param.offsets, shapes, config)); | |||
| @@ -189,7 +190,7 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node( | |||
| auto&& graph = inputs[0]->owner_graph(); | |||
| VarNodeArray inps(inputs.begin(), inputs.end() - 1); | |||
| cg::OperatorNodeConfig config; | |||
| OperatorNodeConfig config{param.make_name()}; | |||
| cg::OperatorNodeBase* opr = | |||
| graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>( | |||
| inps, inputs.back(), param.offsets, config)); | |||
| @@ -20,8 +20,9 @@ namespace { namespace tensorrt_runtime { | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const TensorRTRuntime&>(def); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| SymbolVarArray sinputs(inputs.begin(), inputs.end()); | |||
| return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs); | |||
| return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs, config); | |||
| } | |||
| OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| @@ -21,7 +21,8 @@ namespace { namespace warp_affine { | |||
| const VarNodeArray& inputs) { | |||
| mgb_assert(inputs.size() == 3); | |||
| auto&& op = static_cast<const WarpAffine&>(def); | |||
| return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(WarpAffine, WarpAffine) | |||
| @@ -36,6 +36,7 @@ class OpDef : public Hashable, | |||
| public NonCopyableObj, | |||
| public std::enable_shared_from_this<OpDef> { | |||
| mutable const OpTrait* m_trait = nullptr; | |||
| std::string m_scope; | |||
| public: | |||
| virtual ~OpDef() = default; | |||
| @@ -86,10 +87,14 @@ public: | |||
| const OpTrait* trait() const; | |||
| const char* name() const; | |||
| std::string to_string() const; | |||
| const std::string scope() const; | |||
| const std::string make_name() const; | |||
| void set_scope(const std::string& scope); | |||
| virtual size_t hash() const; | |||
| virtual bool is_same_st(const Hashable&) const; | |||
| @@ -113,9 +113,10 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| paramList.push_back("std::string scope_ = {}"); | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| ": " + llvm::join(initList, ", "), | |||
| " {}"); | |||
| " { set_scope(scope_); }"); | |||
| } | |||
| auto packedParams = op.getPackedParams(); | |||
| @@ -236,11 +237,19 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate make_name() | |||
| os << formatv( | |||
| "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
| ); | |||
| os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| os << "} // anonymous namespace\n"; | |||
| methods.push_back("hash"); | |||
| methods.push_back("is_same_st"); | |||
| methods.push_back("props"); | |||
| methods.push_back("make_name"); | |||
| } | |||
| if (!methods.empty()) { | |||
| os << formatv( | |||
| @@ -327,7 +336,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||
| targs.push_back(i.attr.getReturnType()); | |||
| } | |||
| os << llvm::join(targs, ", "); | |||
| os << ">()"; | |||
| os << ", std::string>()"; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv(", py::arg(\"{0}\")", i.name); | |||
| auto defaultValue = i.attr.getDefaultValue(); | |||
| @@ -337,7 +346,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||
| hasDefaultCtor = true; | |||
| } | |||
| } | |||
| os << ")"; | |||
| os << ", py::arg(\"scope\") = {})"; | |||
| } | |||
| if (hasDefaultCtor) { | |||
| os << "\n .def(py::init<>())"; | |||
| @@ -442,6 +451,10 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
| className, i.name)); | |||
| } | |||
| getsetters.push_back(formatv( | |||
| "{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},", | |||
| className)); | |||
| // generate tp_init | |||
| std::string initBody; | |||
| if (!op.getMgbAttributes().empty()) { | |||
| @@ -449,6 +462,7 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += formatv("\"{0}\", ", attr.name); | |||
| }); | |||
| initBody += "\"scope\", "; | |||
| initBody += "NULL};\n"; | |||
| initBody += " PyObject "; | |||
| std::vector<std::string> attrs; | |||
| @@ -456,12 +470,15 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
| attrs.push_back(formatv("*{0} = NULL", attr.name)); | |||
| }); | |||
| initBody += llvm::join(attrs, ", ") + ";\n"; | |||
| initBody += " PyObject *scope = NULL;\n"; | |||
| initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
| initBody += std::string(op.getMgbAttributes().size(), 'O'); | |||
| // an extra slot created for name | |||
| initBody += std::string(op.getMgbAttributes().size() + 1, 'O'); | |||
| initBody += "\", const_cast<char**>(kwlist)"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += formatv(" ,&{0}", attr.name); | |||
| initBody += formatv(", &{0}", attr.name); | |||
| }); | |||
| initBody += ", &scope"; | |||
| initBody += "))\n"; | |||
| initBody += " return -1;\n"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| @@ -483,6 +500,25 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
| } | |||
| )", className, attr.name); | |||
| }); | |||
| initBody += formatv(R"( | |||
| if (scope) {{ | |||
| try {{ | |||
| reinterpret_cast<PyOp({0})*>(self)->inst().set_scope( | |||
| pyobj_convert_generic<std::string>::from(scope)); | |||
| } catch(py::error_already_set& e) {{ | |||
| e.restore(); | |||
| return -1; | |||
| } catch(py::builtin_exception& e) {{ | |||
| e.set_error(); | |||
| return -1; | |||
| } catch(...) {{ | |||
| PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||
| return -1; | |||
| } | |||
| } | |||
| )", className); | |||
| } | |||
| initBody += "\n return 0;"; | |||
| @@ -241,6 +241,30 @@ private: | |||
| body += " return props_;\n"; | |||
| return body; | |||
| } | |||
| std::string getModeName() const { | |||
| std::string body = formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| getCppClassName() | |||
| ); | |||
| for (auto&& it : getMgbAttributes()) { | |||
| if (it.name == "mode") { | |||
| auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr); | |||
| body += " switch (op_.mode){\n"; | |||
| for (auto&& enumMember: enumAttr->getEnumMembers()) { | |||
| body += formatv( | |||
| " case {0}::{1}::{2}:\n", | |||
| getCppClassName(), enumAttr->getEnumName(), enumMember | |||
| ); | |||
| body += formatv(" return \"{0}\";\n", enumMember); | |||
| } | |||
| body += formatv( | |||
| " default: return \"{0}::Unknown\";\n", getCppClassName()); | |||
| body += " }\n"; | |||
| } | |||
| } | |||
| return body; | |||
| } | |||
| public: | |||
| static bool classof(const Operator* op) { | |||
| return op->getDef().isSubClassOf("MgbHashableOpMixin"); | |||
| @@ -264,6 +288,12 @@ public: | |||
| } | |||
| return getDefaultPropsFunction(); | |||
| } | |||
| std::string getNameFunctionTemplate() const { | |||
| if (getDef().getValueAsBit("usingModeName")) { | |||
| return getModeName(); | |||
| } | |||
| return formatv(" return \"{0}\";\n", getCppClassName()); | |||
| } | |||
| }; | |||
| } // namespace tblgen | |||
| @@ -476,6 +476,7 @@ def main(): | |||
| output_mgbvars = feeds["outputs"] | |||
| output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||
| output_mgbvars = [var._node for var in output_mgbvars] | |||
| inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") | |||
| inputs = sorted((i.name, i.dtype) for i in inputs) | |||
| @@ -242,6 +242,7 @@ class MgbPackedParamBase<string className, string accessor>: | |||
| class MgbHashableOpMixin { | |||
| string hashFunction = ?; | |||
| string cmpFunction = ?; | |||
| bit usingModeName = 0; | |||
| } | |||
| class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | |||
| @@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" | |||
| def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { | |||
| let inputs = (ins Variadic<AnyType>:$input); | |||
| let results = (outs AnyType); | |||
| let usingModeName = 1; | |||
| } | |||
| def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; | |||
| @@ -247,6 +248,7 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara | |||
| let extraArguments = (ins | |||
| MgbDTypeAttr:$dtype | |||
| ); | |||
| let usingModeName = 1; | |||
| } | |||
| def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | |||