| @@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| name = p + e | |||
| e = self._enums[(p, e)] | |||
| self._write_doc(e.name) | |||
| self._write("enum %s%s : uint {", p, e.name, indent=1) | |||
| attribute = "(bit_flags)" if e.combined else "" | |||
| self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1) | |||
| for idx, member in enumerate(e.members): | |||
| self._write_doc(member) | |||
| if e.combined: | |||
| self._write("%s=%d,", scramble_enum_member_name(str(member)), | |||
| 1<<idx) | |||
| else: | |||
| self._write("%s,", scramble_enum_member_name(str(member))) | |||
| self._write("%s,", scramble_enum_member_name(str(member))) | |||
| self._write("}\n", indent=-1) | |||
| def _write_doc(self, doc): | |||
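
A quick illustration of the emission rule introduced in the hunk above: bit-combined enums now carry the FlatBuffers "(bit_flags)" attribute and list plain member names, and flatc then treats member values as bit positions, so the n-th member ends up as 1 << n and the explicit "member = 1 << idx" lines are no longer needed. The sketch below is a standalone stand-in, not the real FlatBuffersWriter class, and the member names are illustrative.

def emit_fbs_enum(name, members, combined=False):
    # mirror of the branch above: only bit-combined enums get "(bit_flags)"
    attribute = "(bit_flags)" if combined else ""
    lines = ["enum %s : uint %s {" % (name, attribute)]
    lines += ["    %s," % m for m in members]
    lines.append("}")
    return "\n".join(lines)

print(emit_fbs_enum("Strategy", ["HEURISTIC", "PROFILE", "REPRODUCIBLE"], combined=True))
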
| @@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| return | |||
| self._write_doc(e.name) | |||
| self._used_enum.add(key) | |||
| self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, | |||
| scramble_enum_member_name(str(e.members[e.default]))) | |||
| if e.combined: | |||
| default = e.compose_combined_enum(e.default) | |||
| else: | |||
| default = scramble_enum_member_name(str(e.members[e.default])) | |||
| self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | |||
| def _resolve_const(self, v): | |||
| while v in self._cur_const_val: | |||
| @@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| return | |||
| self._used_enum.add((e.src_class, e.src_name)) | |||
| enum_name = e.src_class + e.src_name | |||
| self._write( | |||
| "%s:%s = %s;", e.name_field, enum_name, | |||
| scramble_enum_member_name(str(e.src_enum.members[e.get_default()]))) | |||
| s = e.src_enum | |||
| if s.combined: | |||
| default = s.compose_combined_enum(e.get_default()) | |||
| else: | |||
| default = scramble_enum_member_name(str(s.members[e.get_default()])) | |||
| self._write("%s:%s = %s;", e.name_field, enum_name, default) | |||
| def _get_fb_default(self, cppdefault): | |||
| if not isinstance(cppdefault, str): | |||
| @@ -73,11 +73,21 @@ class member_defs: | |||
| """define an enum; the result would contain both an enum class def and its | |||
| corresponding data field | |||
| :param default: index of default member value | |||
| :param default: | |||
| for normal enum class: index of default member value | |||
| for bit combined class: tuple of indices of default member values | |||
| For example, the following representations of the default value for a | |||
| bit combined class are all equivalent: | |||
| Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...) | |||
| Enum(members=('a', 'b', 'c'), default=(0, 1), ...) | |||
| Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...) | |||
| :attr name_field: name of the data field of this enum in the param | |||
| struct | |||
| :attr member_alias: list of (member, alias) pairs | |||
| :attr member_alias: | |||
| for normal enum class: list of (member, alias) pairs | |||
| for bit combined class: list of (tuple of members, alias) pairs | |||
| """ | |||
| __slots__ = ['name', 'name_field', 'members', 'default', | |||
| 'member_alias', 'combined'] | |||
| @@ -90,17 +100,11 @@ class member_defs: | |||
| name = member_defs.Doc.make(name) | |||
| assert name.id[0].isupper() | |||
| members = tuple(map(member_defs.Doc.make, members)) | |||
| if isinstance(default, str): | |||
| if default not in name_field: | |||
| raise ValueError( | |||
| "Default value '{}' does not exist.".format(default)) | |||
| default = name_field.index(default) | |||
| assert isinstance(default, int) | |||
| self.name = name | |||
| self.combined = combined | |||
| self.name_field = self.get_name_field(name.id, name_field) | |||
| self.members = members | |||
| self.default = default | |||
| self.default = self.normalize_enum_value(default) | |||
| self.all_enums[(param_name, name.id)] = self | |||
| @@ -114,6 +118,43 @@ class member_defs: | |||
| assert isinstance(name_field, str) | |||
| return name_field | |||
| def normalize_enum_value(self, value): | |||
| def normalize(v): | |||
| if isinstance(v, str): | |||
| if v not in self.members: | |||
| raise ValueError( | |||
| "enum member '{}' does not exist.".format(v)) | |||
| v = self.members.index(v) | |||
| assert isinstance(v, int) | |||
| return v | |||
| if self.combined: | |||
| if isinstance(value, int): | |||
| value = self.decompose_combined_enum(value) | |||
| assert isinstance(value, tuple) | |||
| value = tuple(normalize(i) for i in value) | |||
| return value | |||
| else: | |||
| return normalize(value) | |||
| @staticmethod | |||
| def decompose_combined_enum(v): | |||
| """Integer => tuple of the indexes of the enum members""" | |||
| assert isinstance(v, int) | |||
| idx = 0 | |||
| members = [] | |||
| while v > 0: | |||
| if v & 1: | |||
| members.append(idx) | |||
| idx += 1 | |||
| v >>= 1 | |||
| return tuple(members) | |||
| def compose_combined_enum(self, v): | |||
| """tuple of members => Integer""" | |||
| assert self.combined and isinstance(v, tuple) | |||
| norm_v = self.normalize_enum_value(v) | |||
| return sum(1 << i for i in norm_v) | |||
| class Field(Base): | |||
| """define a normal data field""" | |||
| __slots__ = ['name', 'dtype', 'default'] | |||
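
A self-contained round-trip check of the compose/decompose helpers added above; BitEnum below is a stripped-down stand-in for member_defs.Enum, not the real class.

class BitEnum:
    def __init__(self, members):
        self.members = tuple(members)

    @staticmethod
    def decompose_combined_enum(v):
        # integer bit mask -> tuple of member indices
        idx, out = 0, []
        while v > 0:
            if v & 1:
                out.append(idx)
            idx += 1
            v >>= 1
        return tuple(out)

    def compose_combined_enum(self, v):
        # tuple of member names/indices -> integer bit mask
        idxs = tuple(self.members.index(i) if isinstance(i, str) else i for i in v)
        return sum(1 << i for i in idxs)

e = BitEnum(('a', 'b', 'c'))
assert e.decompose_combined_enum(0b101) == (0, 2)
assert e.compose_combined_enum(('a', 'b')) == (1 << 0) | (1 << 1)
assert e.compose_combined_enum((0, 2)) == 0b101
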
| @@ -146,6 +187,10 @@ class member_defs: | |||
| src_name = name | |||
| self.src_name = src_name | |||
| self.default = default | |||
| # TODO: remove this assertion when aliasing of bit-combined enums is | |||
| # needed; adding mock param_defs to the current testing framework is too | |||
| # complicated, so currently we only allow aliasing of normal enums | |||
| assert not self.src_enum.combined | |||
| @property | |||
| def src_enum(self): | |||
| @@ -157,7 +202,7 @@ class member_defs: | |||
| set""" | |||
| if self.default is None: | |||
| return self.src_enum.default | |||
| return self.default | |||
| return self.src_enum.normalize_enum_value(self.default) | |||
| class ParamDef: | |||
| @@ -198,7 +243,7 @@ class ParamDef: | |||
| self.name.id, name, name_field, members, default, member_alias)) | |||
| return self | |||
| def add_bit_combination_enum(self, name, *members, default=0, | |||
| def add_bit_combination_enum(self, name, *members, default=tuple(), | |||
| name_field=None, member_alias=[]): | |||
| self.members.append(member_defs.Enum( | |||
| self.name.id, name, name_field, members, default, member_alias, True)) | |||
| @@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase): | |||
| ' for idx, v in enumerate(pdata):\n' | |||
| ' if isinstance(v, _EnumBase):\n' | |||
| ' pdata[idx] = _enum_member2num[id(v)]\n' | |||
| ' elif isinstance(v, _BitCombinedEnumBase):\n' | |||
| ' pdata[idx] = v._value_\n' | |||
| ' return tag + self._packer.pack(*pdata)\n' | |||
| '\n' | |||
| ) | |||
| self._write( | |||
| 'class _EnumBase(enum.Enum):\n' | |||
| # it's hard to mix a custom implementation into enum, so just copy-paste the class body instead | |||
| classbody = ( | |||
| ' @classmethod\n' | |||
| ' def __normalize(cls, val):\n' | |||
| ' if isinstance(val, str):\n' | |||
| @@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase): | |||
| ' return super()._missing_(value)\n' | |||
| '\n' | |||
| ) | |||
| self._write( | |||
| 'class _EnumBase(enum.Enum):\n' + classbody | |||
| ) | |||
| self._write( | |||
| 'class _BitCombinedEnumBase(enum.Flag):\n' + classbody | |||
| ) | |||
| if not self._imperative: | |||
| self._write( | |||
| 'def _as_dtype_num(dtype):\n' | |||
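
Why one shared class body works for both bases: members of _EnumBase (enum.Enum) keep their string values and are translated through _enum_member2num during serialization, while members of _BitCombinedEnumBase (enum.Flag) carry integer bit masks that can be OR-ed together and serialized directly via _value_, as the packer change above shows. A reduced sketch with an assumed Strategy enum (the member names and the simplified convert are illustrative, not the generated code):

import enum

class Strategy(enum.Flag):
    HEURISTIC = 1 << 0
    PROFILE = 1 << 1
    REPRODUCIBLE = 1 << 2

    @classmethod
    def convert(cls, val):
        # accept an existing member, a raw integer mask, or a member name
        if isinstance(val, cls):
            return val
        if isinstance(val, int):
            return cls(val)
        return cls[val.upper()]

combined = Strategy.convert("heuristic") | Strategy.PROFILE
assert combined._value_ == (1 << 0) | (1 << 1)
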
| @@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase): | |||
| def _on_member_enum(self, e): | |||
| qualname = '{}.{}'.format(self._cur_param_name, e.name) | |||
| self._write('class %s(_EnumBase):', e.name, indent=1) | |||
| if e.combined: | |||
| self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1) | |||
| else: | |||
| self._write('class %s(_EnumBase):', e.name, indent=1) | |||
| self._write_doc(e.name) | |||
| for idx, emem in enumerate(e.members): | |||
| self._write('%s = "%s"', emem, emem) | |||
| self._write_doc(emem) | |||
| if e.combined: | |||
| self._enum_member2num.append('id({}.{}):{}'.format( | |||
| qualname, emem, 1<<idx)) | |||
| self._write('%s = 1 << %d', emem, idx) | |||
| self._write_doc(emem) | |||
| else: | |||
| self._write('%s = "%s"', emem, emem) | |||
| self._write_doc(emem) | |||
| self._enum_member2num.append('id({}.{}):{}'.format( | |||
| qualname, emem, idx)) | |||
| for emem, emem_alis in e.member_alias: | |||
| self._write('%s = %s', emem_alis, emem) | |||
| for emem, emem_alias in e.member_alias: | |||
| if e.combined: | |||
| self._write('%s = %s', emem_alias, e.compose_combined_enum(emem)) | |||
| else: | |||
| self._write('%s = %s', emem_alias, emem) | |||
| self._unindent() | |||
| self._write('') | |||
| if e.combined: | |||
| default = e.compose_combined_enum(e.default) | |||
| else: | |||
| default = "'{}'".format(e.members[e.default]) | |||
| self._cur_fields.append(self.FieldDef( | |||
| name=e.name_field, | |||
| cvt='{}.convert({})'.format(qualname, e.name_field), | |||
| fmt='I', | |||
| default="'{}'".format(e.members[e.default]), | |||
| default=default, | |||
| type=qualname, | |||
| doc=None)) | |||
| @@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase): | |||
| self._write('%s = %s.%s', e.name, e.src_class, e.src_name) | |||
| s = e.src_enum | |||
| qualname = '{}.{}'.format(e.src_class, e.src_name) | |||
| if s.combined: | |||
| default = s.compose_combined_enum(e.get_default()) | |||
| else: | |||
| default = "'{}'".format(s.members[e.get_default()]) | |||
| self._cur_fields.append(self.FieldDef( | |||
| name=e.name_field, | |||
| cvt='{}.convert({})'.format(qualname, e.name_field), | |||
| fmt='I', | |||
| default="'{}'".format(s.members[e.get_default()]), | |||
| default=default, | |||
| type=qualname, | |||
| doc=None)) | |||
| @@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase): | |||
| v += ',' | |||
| self._write(v) | |||
| for mem, alias in e.member_alias: | |||
| self._write('%s = %s,', alias, mem) | |||
| if e.combined: | |||
| self._write('%s = %s,', alias, e.compose_combined_enum(mem)) | |||
| else: | |||
| self._write('%s = %s,', alias, mem) | |||
| self._write('};', indent=-1) | |||
| self._non_static_members.append(e) | |||
| self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | |||
| str(e.name).upper(), len(e.members)) | |||
| self._add_ctor_args(e.name, | |||
| '{}::{}'.format(e.name, e.members[e.default]), | |||
| e.name_field) | |||
| if e.combined: | |||
| default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default)) | |||
| else: | |||
| default = '{}::{}'.format(e.name, e.members[e.default]) | |||
| self._add_ctor_args(e.name, default, e.name_field) | |||
| def _on_member_enum_alias(self, e): | |||
| s = e.src_enum | |||
| @@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase): | |||
| self._non_static_members.append(e) | |||
| self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | |||
| str(e.name).upper(), len(s.members)) | |||
| self._add_ctor_args(e.name, | |||
| '{}::{}'.format(e.name, | |||
| s.members[e.get_default()]), | |||
| e.name_field) | |||
| if s.combined: | |||
| default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) | |||
| else: | |||
| default = '{}::{}'.format(e.name, s.members[e.get_default()]) | |||
| self._add_ctor_args(e.name, default, e.name_field) | |||
| def _on_member_field(self, f): | |||
| self._non_static_members.append(f) | |||
| @@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase): | |||
| return | |||
| # wrapped with default value | |||
| default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default) | |||
| if e.combined: | |||
| default_val = "static_cast<{}::{}>({})".format( | |||
| fullname, e.name, e.compose_combined_enum(e.default)) | |||
| else: | |||
| default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default]) | |||
| wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
| self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
| @@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase): | |||
| self._write("def {} : {};".format(td_class, enum_def)) | |||
| # wrapped with default value | |||
| default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default()) | |||
| s = e.src_enum | |||
| if s.combined: | |||
| default_val = "static_cast<{}::{}>({})".format( | |||
| fullname, e.name, s.compose_combined_enum(e.get_default())) | |||
| else: | |||
| default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()]) | |||
| wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
| self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
| @@ -185,7 +185,7 @@ SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order( | |||
| } | |||
| //! conv1x1 | |||
| im2col_prefer |= (FH == 1 && FW == 1); | |||
| //! x86 8x8x16 not optmized, so it will use fallback im2col+matmul | |||
| //! x86 8x8x16 not optimized, so it will use fallback im2col+matmul | |||
| if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) { | |||
| im2col_prefer = true; | |||
| } | |||
| @@ -8,6 +8,9 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from .._imperative_rt import make_const | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor | |||
| class Const: | |||
| def __init__(self, value=None, *, dtype=None, device=None): | |||
| @@ -19,7 +22,19 @@ class Const: | |||
| from ...tensor import Tensor | |||
| device = self.device | |||
| if device is None: | |||
| device = reference[0].device | |||
| if len(reference) != 0: | |||
| reference = reference[0] | |||
| assert isinstance( | |||
| reference, (SymbolVar, Tensor) | |||
| ), "Reference should be Tensor or VarNode" | |||
| if device is None: | |||
| device = reference.device | |||
| if isinstance(reference, SymbolVar): | |||
| cls = type(reference) | |||
| rst = cls(make_const(reference.graph, self.value, device, self.dtype)) | |||
| return (rst,) | |||
| return (Tensor(self.value, self.dtype, self.device, True),) | |||
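
The change above makes Const dispatch on its reference argument: given a SymbolVar reference it builds a graph constant of the same SymbolVar subclass, otherwise it falls back to an eager Tensor. A toy version of that dispatch; StubSymbolVar and StubTensor are stand-ins, not the megengine classes:

class StubSymbolVar:
    def __init__(self, graph, value, device="xpux"):
        self.graph, self.value, self.device = graph, value, device

class StubTensor:
    def __init__(self, value, dtype=None, device=None):
        self.value, self.dtype, self.device = value, dtype, device

def const_like(value, reference=None, dtype=None, device=None):
    if isinstance(reference, StubSymbolVar):
        # symbolic path: reuse the reference's graph and exact subclass
        cls = type(reference)
        return (cls(reference.graph, value, device or reference.device),)
    # eager path: plain tensor on the requested device
    return (StubTensor(value, dtype, device),)

(sym,) = const_like(1.0, reference=StubSymbolVar("graph0", None))
assert isinstance(sym, StubSymbolVar) and sym.graph == "graph0"
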
| @@ -13,7 +13,7 @@ from typing import Union | |||
| import numpy as np | |||
| from .._imperative_rt.common import CompNode | |||
| from .._imperative_rt.core2 import Tensor, apply | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||
| from ..ops import builtin | |||
| from ..ops.builtin import Elemwise, GetVarShape | |||
| from . import utils | |||
| @@ -230,7 +230,9 @@ def _todo(*_): | |||
| def _expand_args(args): | |||
| if len(args) == 1: | |||
| if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),): | |||
| if isinstance( | |||
| args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray), | |||
| ): | |||
| args = args[0] | |||
| return args | |||
| @@ -5,6 +5,7 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import re | |||
| from collections import namedtuple | |||
| from typing import Union | |||
| @@ -22,6 +23,12 @@ from .._imperative_rt.common import ( | |||
| ) | |||
| def get_dtype_bit(dtype_name: str): | |||
| numbers = re.findall(r"\d+", dtype_name) | |||
| assert len(numbers) == 1, "Unsupported dtype name: expected exactly one number." | |||
| return int(numbers[0]) | |||
| # normal dtype related | |||
| def is_lowbit(dtype): | |||
| return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) | |||
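
The new get_dtype_bit helper simply reads the bit width out of the dtype name. A standalone copy for illustration (same logic as above):

import re

def get_dtype_bit(dtype_name: str) -> int:
    numbers = re.findall(r"\d+", dtype_name)
    assert len(numbers) == 1, "Unsupported dtype name: expected exactly one number."
    return int(numbers[0])

assert get_dtype_bit("int8") == 8
assert get_dtype_bit("quint8") == 8
assert get_dtype_bit("qint32") == 32
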
| @@ -10,7 +10,7 @@ from typing import Iterable | |||
| import numpy as np | |||
| from .._imperative_rt.core2 import Tensor, apply | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||
| from .._trace_option import use_symbolic_shape | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| @@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
| return True | |||
| def get_index(i): | |||
| if not isinstance(i, (Tensor)): | |||
| if not isinstance(i, (Tensor, SymbolVar)): | |||
| if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | |||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)() | |||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||
| else: | |||
| (i,) = Const(i, dtype=np.int32, device=inp.device)() | |||
| (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||
| return i | |||
| assert isinstance(i, Tensor) | |||
| assert isinstance(i, (Tensor, SymbolVar)) | |||
| if i.dtype != np.bool_: | |||
| return i | |||
| _, ind = apply(builtin.CondTake(), i, i) | |||
| @@ -197,9 +197,9 @@ def try_condtake(tensor, index): | |||
| ): | |||
| return [] | |||
| if isinstance(index, np.ndarray): | |||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)() | |||
| assert isinstance(index, Tensor) | |||
| if not isinstance(tensor, Tensor): | |||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||
| assert isinstance(index, (Tensor, SymbolVar)) | |||
| if not isinstance(tensor, (Tensor, SymbolVar)): | |||
| raise TypeError("input must be a tensor") | |||
| if tensor.device != index.device: | |||
| raise ValueError( | |||
| @@ -214,11 +214,16 @@ def getitem(tensor, index): | |||
| return try_result[0] | |||
| tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) | |||
| for v in tensors: | |||
| if v.shape is None: | |||
| break | |||
| if isinstance(v.shape, v.__class__): | |||
| break | |||
| if len(v.shape) > 0 and v.shape[0] == 0: | |||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() | |||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||
| tensor | |||
| ) | |||
| return empty_tensor | |||
| if use_subtensor: | |||
| op = builtin.Subtensor(items=items) | |||
| else: | |||
| @@ -235,8 +240,8 @@ def setitem(tensor, index, value): | |||
| if len(try_result) == 2: | |||
| index = try_result[1] | |||
| tensor = tensor.reshape(-1) | |||
| if not isinstance(value, Tensor): | |||
| (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() | |||
| if not isinstance(value, (Tensor, SymbolVar)): | |||
| (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) | |||
| tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | |||
| if use_subtensor: | |||
| op = builtin.Subtensor(items=items) | |||
| @@ -11,8 +11,9 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from .._imperative_rt import VarNode | |||
| from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||
| from .._imperative_rt import make_const | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device | |||
| from .._wrap import device as as_device | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| from .dtype import is_dtype_equal, is_quantize | |||
| @@ -38,13 +39,9 @@ def set_convert_inputs(flag): | |||
| def concatenate(inputs, axis=0, *, device=None): | |||
| dtype = dtype_promotion(inputs) | |||
| device = get_device(inputs) | |||
| def convert(x): | |||
| return convert_single_value(x, dtype=dtype, device=device) | |||
| inputs = tuple(map(convert, inputs)) | |||
| inputs = convert_inputs(*inputs) | |||
| if device is None: | |||
| device = get_device(inputs) | |||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | |||
| return result | |||
| @@ -60,7 +57,7 @@ def astype(x, dtype): | |||
| def convert_single_value(v, *, dtype=None, device=None): | |||
| if isinstance(v, (Tensor, VarNode)): | |||
| if isinstance(v, (Tensor, SymbolVar)): | |||
| if not is_quantize(v.dtype): | |||
| v = astype(v, dtype) | |||
| else: | |||
| @@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None): | |||
| return v | |||
| def convert_inputs(*args: Tensor): | |||
| def convert_inputs(*args, device=None): | |||
| if not _enable_convert_inputs: | |||
| return args | |||
| dtype = dtype_promotion(args) | |||
| device = get_device(args) | |||
| if device is None: | |||
| device = get_device(args) | |||
| device = as_device(device) | |||
| graph = None | |||
| sym_type = None | |||
| for a in args: | |||
| if isinstance(a, SymbolVar): | |||
| if graph is None: | |||
| graph = a.var.graph | |||
| sym_type = type(a) | |||
| else: | |||
| assert graph == a.var.graph | |||
| args = list(args) | |||
| if graph is not None: | |||
| for i in range(len(args)): | |||
| if not isinstance(args[i], SymbolVar): | |||
| rst = make_const(graph, np.array(args[i]), device.to_c(), dtype) | |||
| args[i] = sym_type(rst) | |||
| def convert(value): | |||
| if value is None: | |||
| return value | |||
| return convert_single_value(value, dtype=dtype, device=device) | |||
| return convert_single_value(value, dtype=dtype, device=device.to_c()) | |||
| return tuple(map(convert, args)) | |||
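
The core of the new convert_inputs behaviour: if any argument is a SymbolVar, every plain value is lifted onto that SymbolVar's graph as a constant before the usual dtype/device conversion. A minimal sketch of that lifting step with a stub class (FakeSymbolVar stands in for megengine's SymbolVar/make_const machinery):

import numpy as np

class FakeSymbolVar:
    def __init__(self, graph, value):
        self.graph, self.value = graph, value

def lift_to_graph(*args):
    # find the graph of the first symbolic argument, if any
    graph = next((a.graph for a in args if isinstance(a, FakeSymbolVar)), None)
    if graph is None:
        return args
    # wrap every remaining plain value as a constant on the same graph
    return tuple(a if isinstance(a, FakeSymbolVar) else FakeSymbolVar(graph, np.array(a))
                 for a in args)

sym = FakeSymbolVar(graph="g0", value=None)
out = lift_to_graph(sym, 1.0, [2, 3])
assert all(isinstance(o, FakeSymbolVar) and o.graph == "g0" for o in out)
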
| @@ -98,14 +113,14 @@ def result_type(*args): | |||
| def isscalar(x): | |||
| if isinstance(x, Tensor): | |||
| if isinstance(x, (Tensor, SymbolVar)): | |||
| return x._isscalar() | |||
| return np.isscalar(x) | |||
| def setscalar(x): | |||
| if isinstance(x, Tensor): | |||
| if isinstance(x, (Tensor, SymbolVar)): | |||
| x._setscalar() | |||
| else: | |||
| raise NotImplementedError("Unsupport type {}".format(type(x))) | |||
| @@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
| if not isinstance(x, collections.abc.Sequence): | |||
| raise TypeError | |||
| if any(isinstance(i, Tensor) for i in x): | |||
| if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | |||
| x = concatenate(x, device=device) | |||
| if dtype is not None: | |||
| x = astype(x, dtype) | |||
| @@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
| def _expand_int(s, i): | |||
| if isinstance(i, Tensor): | |||
| if isinstance(i, (Tensor, SymbolVar)): | |||
| i_np = i.numpy() | |||
| if i_np.ndim == 0: | |||
| s.append(int(i_np)) | |||
| @@ -40,7 +40,7 @@ def set_execution_strategy(option): | |||
| * HEURISTIC uses heuristic to choose the fastest algorithm. | |||
| * PROFILE runs possible algorithms on real device to find the best one. | |||
| * REPRODUCIBLE uses the algorithms that are reproducible. | |||
| * OPTMIZED uses the algorithms that is optimized. | |||
| * OPTIMIZED uses the algorithms that are optimized. | |||
| The default strategy is HEURISTIC, this options can be combined to | |||
| form a combination option, e.g. PROFILE | REPRODUCIBLE | |||
| @@ -9,8 +9,7 @@ | |||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._imperative_rt.graph import VarNode | |||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Elemwise | |||
| from ..core.tensor import utils | |||
| @@ -72,7 +71,7 @@ __all__ = [ | |||
| def _elwise(*args, mode): | |||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) | |||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args)) | |||
| if len(tensor_args) == 0: | |||
| dtype = utils.dtype_promotion(args) | |||
| first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | |||
| @@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union | |||
| import numpy as np | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||
| from ..core._wrap import device as as_device | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Copy, Identity | |||
| @@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten | |||
| return result | |||
| def full(shape, value, dtype="float32", device=None): | |||
| def full(shape, value, dtype="float32", device=None) -> Tensor: | |||
| """ | |||
| Returns a tensor with given shape and value. | |||
| """ | |||
| @@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None): | |||
| return broadcast_to(x, shape) | |||
| def ones(shape, dtype="float32", device=None): | |||
| def ones(shape, dtype="float32", device=None) -> Tensor: | |||
| """ | |||
| Returns a ones tensor with given shape. | |||
| @@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None): | |||
| return full(shape, 1.0, dtype=dtype, device=device) | |||
| def zeros(shape, dtype="float32", device=None): | |||
| def zeros(shape, dtype="float32", device=None) -> Tensor: | |||
| """ | |||
| Returns a zero tensor with given shape. | |||
| """ | |||
| return full(shape, 0.0, dtype=dtype, device=device) | |||
| def zeros_like(inp: Tensor) -> Tensor: | |||
| def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
| """ | |||
| Returns a zero tensor with the same shape as input tensor. | |||
| @@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor: | |||
| [0 0 0]] | |||
| """ | |||
| return zeros(inp.shape, dtype=inp.dtype, device=inp.device) | |||
| return full_like(inp, 0.0) | |||
| def ones_like(inp: Tensor) -> Tensor: | |||
| def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
| """ | |||
| Returns a ones tensor with the same shape as input tensor. | |||
| """ | |||
| return ones(inp.shape, dtype=inp.dtype, device=inp.device) | |||
| return full_like(inp, 1.0) | |||
| def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||
| def full_like( | |||
| inp: Union[Tensor, SymbolVar], value: Union[int, float] | |||
| ) -> Union[Tensor, SymbolVar]: | |||
| """ | |||
| Returns a tensor filled with given value with the same shape as input tensor. | |||
| """ | |||
| return full(inp.shape, value, dtype=inp.dtype, device=inp.device) | |||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||
| if inp.shape == (): | |||
| return x | |||
| return broadcast_to(x, inp.shape) | |||
| def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | |||
| @@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||
| if len(inps) == 1: | |||
| return inps[0] | |||
| dtype = dtype_promotion(inps) | |||
| inps = convert_inputs(*inps, device=device) | |||
| if device is None: | |||
| device = get_device(inps) | |||
| device = as_device(device) | |||
| def convert(x): | |||
| return convert_single_value(x, dtype=dtype, device=device) | |||
| inps = tuple(map(convert, inps)) | |||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | |||
| return result | |||
| @@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0): | |||
| Ntotal, axis, Nsections | |||
| ) | |||
| ) | |||
| func = ( | |||
| floor_div | |||
| if isinstance(Nsections, (SymbolVar, Tensor)) | |||
| else lambda x, y: x // y | |||
| ) | |||
| div_points = [0] + [ | |||
| floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||
| func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||
| ] | |||
| for i in range(2, Nsections + 1): | |||
| div_points[i] = div_points[i - 1] + div_points[i] | |||
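
For the plain-int case the section sizes are computed with integer division and then accumulated into boundaries; a pure-Python sketch of that arithmetic (floor_div is only needed when Nsections is a Tensor or SymbolVar):

def split_points(Ntotal, Nsections):
    div_points = [0] + [(Ntotal + Nsections - i - 1) // Nsections
                        for i in range(Nsections)]
    for i in range(2, Nsections + 1):
        div_points[i] = div_points[i - 1] + div_points[i]
    return div_points  # cumulative boundaries; the last element equals Ntotal

# 10 elements split into 3 sections -> sizes 4, 3, 3
assert split_points(10, 3) == [0, 4, 7, 10]
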
| @@ -925,11 +931,15 @@ def linspace( | |||
| if not (cur_device is None or device == cur_device): | |||
| raise ValueError("ambiguous device for linspace opr") | |||
| if not isinstance(start, Tensor): | |||
| is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num]) | |||
| if any(is_symbolvar) and not all(is_symbolvar): | |||
| raise TypeError("start, stop and num should either all be SymbolVar or none of them") | |||
| if not isinstance(start, (Tensor, SymbolVar)): | |||
| start = Tensor(start, device=device) | |||
| if not isinstance(stop, Tensor): | |||
| if not isinstance(stop, (Tensor, SymbolVar)): | |||
| stop = Tensor(stop, device=device) | |||
| if not isinstance(num, Tensor): | |||
| if not isinstance(num, (Tensor, SymbolVar)): | |||
| num = Tensor(num, device=device) | |||
| op = builtin.Linspace(comp_node=device) | |||
| @@ -983,7 +993,7 @@ def arange( | |||
| stop = stop.astype("float32") | |||
| if isinstance(step, Tensor): | |||
| step = step.astype("float32") | |||
| num = ceil(Tensor((stop - start) / step, device=device)) | |||
| num = ceil((stop - start) / step) | |||
| stop = start + step * (num - 1) | |||
| result = linspace(start, stop, num, device=device) | |||
| if np.dtype(dtype) == np.int32: | |||
| @@ -607,10 +607,10 @@ class Module(metaclass=ABCMeta): | |||
| def __getattribute__(self, name: str): | |||
| value = super().__getattribute__(name) | |||
| if name == "_name": | |||
| if name == "__dict__": | |||
| return value | |||
| if isinstance(value, (Tensor, Module)): | |||
| value._name = name | |||
| for prefix, variable in _expand_structure(name, value): | |||
| variable._name = prefix | |||
| return value | |||
| def __setattr__(self, name: str, value): | |||
| @@ -23,7 +23,7 @@ class Concat(QuantizedModule): | |||
| self.output_dtype = dtype | |||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
| new_inps = (x.astype(self.output_dtype) for x in inps) | |||
| new_inps = tuple(x.astype(self.output_dtype) for x in inps) | |||
| return F.concat(new_inps, axis) | |||
| @classmethod | |||
| @@ -92,6 +92,7 @@ class Sequential(Module): | |||
| return [getattr(self, key) for key in self.layer_keys] | |||
| def forward(self, inp): | |||
| for layer in self.layer_values: | |||
| # avoid layer_values as a name prefix, see Module.__getattribute__ | |||
| for layer in [getattr(self, key) for key in self.layer_keys]: | |||
| inp = layer(inp) | |||
| return inp | |||
| @@ -6,6 +6,7 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import copy | |||
| from abc import ABCMeta, abstractmethod | |||
| from collections.abc import Iterable | |||
| from typing import Dict | |||
| @@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta): | |||
| cur_id += 1 | |||
| for param, st in self._state.items(): | |||
| _st = copy.copy(st) | |||
| if not keep_var: | |||
| for k, v in st.items(): | |||
| st[k] = v.numpy() | |||
| state[param2id[param]] = st | |||
| _st[k] = v.numpy() | |||
| state[param2id[param]] = _st | |||
| for group in self.param_groups: | |||
| param_group = {k: v for k, v in group.items() if k != "params"} | |||
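
The state_dict fix above converts values inside a shallow copy so the optimizer's live state (which still holds Tensors) is not overwritten when keep_var is False. The same pattern in miniature, with plain lists standing in for Tensors and v.numpy():

import copy

live_state = {"exp_avg": [1, 2, 3]}                 # stand-in for Tensor-valued state
snapshot = copy.copy(live_state)                    # _st = copy.copy(st)
snapshot["exp_avg"] = list(live_state["exp_avg"])   # analogous to _st[k] = v.numpy()
snapshot["exp_avg"].append(4)
assert live_state["exp_avg"] == [1, 2, 3]           # original optimizer state untouched
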
| @@ -6,8 +6,16 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .fake_quant import FakeQuantize | |||
| from .observer import Observer | |||
| from .fake_quant import TQT, FakeQuantize | |||
| from .observer import ( | |||
| ExponentialMovingAverageObserver, | |||
| HistogramObserver, | |||
| MinMaxObserver, | |||
| Observer, | |||
| PassiveObserver, | |||
| SyncExponentialMovingAverageObserver, | |||
| SyncMinMaxObserver, | |||
| ) | |||
| from .qconfig import ( | |||
| QConfig, | |||
| calibration_qconfig, | |||
| @@ -8,14 +8,19 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import argparse | |||
| import logging | |||
| import re | |||
| import numpy as np | |||
| from megengine.core.tensor.dtype import is_quantize | |||
| from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level | |||
| from megengine.utils.module_stats import ( | |||
| print_flops_stats, | |||
| print_params_stats, | |||
| enable_receptive_field, | |||
| get_op_stats, | |||
| get_param_stats, | |||
| print_op_stats, | |||
| print_param_stats, | |||
| print_summary, | |||
| sizeof_fmt, | |||
| ) | |||
| from megengine.utils.network import Network | |||
| @@ -40,34 +45,41 @@ def visualize( | |||
| :param log_params: whether print and record params size. | |||
| :param log_flops: whether print and record op flops. | |||
| """ | |||
| try: | |||
| from tensorboard.compat.proto.attr_value_pb2 import AttrValue | |||
| from tensorboard.compat.proto.config_pb2 import RunMetadata | |||
| from tensorboard.compat.proto.graph_pb2 import GraphDef | |||
| from tensorboard.compat.proto.node_def_pb2 import NodeDef | |||
| from tensorboard.compat.proto.step_stats_pb2 import ( | |||
| AllocatorMemoryUsed, | |||
| DeviceStepStats, | |||
| NodeExecStats, | |||
| StepStats, | |||
| ) | |||
| from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto | |||
| from tensorboard.compat.proto.versions_pb2 import VersionDef | |||
| from tensorboardX import SummaryWriter | |||
| except ImportError: | |||
| logger.error( | |||
| "TensorBoard and TensorboardX are required for visualize.", exc_info=True | |||
| ) | |||
| return | |||
| if log_path: | |||
| try: | |||
| from tensorboard.compat.proto.attr_value_pb2 import AttrValue | |||
| from tensorboard.compat.proto.config_pb2 import RunMetadata | |||
| from tensorboard.compat.proto.graph_pb2 import GraphDef | |||
| from tensorboard.compat.proto.node_def_pb2 import NodeDef | |||
| from tensorboard.compat.proto.step_stats_pb2 import ( | |||
| AllocatorMemoryUsed, | |||
| DeviceStepStats, | |||
| NodeExecStats, | |||
| StepStats, | |||
| ) | |||
| from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto | |||
| from tensorboard.compat.proto.versions_pb2 import VersionDef | |||
| from tensorboardX import SummaryWriter | |||
| except ImportError: | |||
| logger.error( | |||
| "TensorBoard and TensorboardX are required for visualize.", | |||
| exc_info=True, | |||
| ) | |||
| return | |||
| # FIXME: remove this after resolving "span dist too large" warning | |||
| old_level = set_mgb_log_level(logging.ERROR) | |||
| enable_receptive_field() | |||
| graph = Network.load(model_path) | |||
| writer = SummaryWriter(log_path) | |||
| def process_name(name): | |||
| return name.replace(".", "/").encode(encoding="utf-8") | |||
| # node names that start with a dot or contain a float constant lead to a display bug | |||
| if not re.match(r"^[+-]?\d*\.\d*", name): | |||
| name = name.replace(".", "/") | |||
| return name.encode(encoding="utf-8") | |||
| summary = [["item", "value"]] | |||
| node_list = [] | |||
| flops_list = [] | |||
| params_list = [] | |||
| @@ -84,78 +96,90 @@ def visualize( | |||
| node_oup = node.outputs[0] | |||
| inp_list = [process_name(var.owner.name) for var in node.inputs] | |||
| attr = { | |||
| "_output_shapes": AttrValue( | |||
| list=AttrValue.ListValue( | |||
| shape=[ | |||
| TensorShapeProto( | |||
| dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape] | |||
| ) | |||
| ] | |||
| ) | |||
| ), | |||
| } | |||
| if hasattr(node, "calc_flops"): | |||
| flops_num = node.calc_flops() | |||
| if log_path: | |||
| # detail format see tensorboard/compat/proto/attr_value.proto | |||
| attr = { | |||
| "_output_shapes": AttrValue( | |||
| list=AttrValue.ListValue( | |||
| shape=[ | |||
| TensorShapeProto( | |||
| dim=[ | |||
| TensorShapeProto.Dim(size=d) for d in node_oup.shape | |||
| ] | |||
| ) | |||
| ] | |||
| ) | |||
| ), | |||
| "params": AttrValue(s=str(node.params).encode(encoding="utf-8")), | |||
| "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), | |||
| } | |||
| flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||
| if flops_stats is not None: | |||
| # add op flops attr | |||
| attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8")) | |||
| flops_list.append( | |||
| dict( | |||
| name=node.name, | |||
| class_name=node.type, | |||
| input_shapes=[i.shape for i in node.inputs], | |||
| output_shapes=[o.shape for o in node.outputs], | |||
| flops_num=flops_num, | |||
| flops_cum=0, | |||
| if log_path and "flops_num" in flops_stats: | |||
| attr["flops"] = AttrValue( | |||
| s=sizeof_fmt(flops_stats["flops_num"]).encode(encoding="utf-8") | |||
| ) | |||
| ) | |||
| flops_stats["name"] = node.name | |||
| flops_stats["class_name"] = node.type | |||
| flops_list.append(flops_stats) | |||
| if node.type == "ImmutableTensor": | |||
| param_dim = np.prod(node_oup.shape) | |||
| # TODO: consider other quantize dtypes | |||
| param_bytes = 1 if is_quantize(node_oup.dtype) else 4 | |||
| param_stats = get_param_stats(node.numpy()) | |||
| # add tensor size attr | |||
| attr["size"] = AttrValue( | |||
| s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") | |||
| ) | |||
| params_list.append( | |||
| dict( | |||
| name=node.name, | |||
| shape=node_oup.shape, | |||
| param_dim=param_dim, | |||
| bits=param_bytes * 8, | |||
| size=param_dim * param_bytes, | |||
| size_cum=0, | |||
| mean="{:.2g}".format(node.numpy().mean()), | |||
| std="{:.2g}".format(node.numpy().std()), | |||
| if log_path: | |||
| attr["size"] = AttrValue( | |||
| s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||
| ) | |||
| param_stats["name"] = node.name | |||
| params_list.append(param_stats) | |||
| if log_path: | |||
| node_list.append( | |||
| NodeDef( | |||
| name=process_name(node.name), | |||
| op=node.type, | |||
| input=inp_list, | |||
| attr=attr, | |||
| ) | |||
| ) | |||
| # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug | |||
| if len(node.name.split(".")) <= 2 and node not in graph.input_vars: | |||
| continue | |||
| node_list.append( | |||
| NodeDef( | |||
| name=process_name(node.name), op=node.type, input=inp_list, attr=attr, | |||
| ) | |||
| ) | |||
| # summary | |||
| extra_info = { | |||
| "#ops": len(graph.all_oprs), | |||
| "#params": len(params_list), | |||
| } | |||
| total_flops, total_params = 0, 0 | |||
| total_flops, total_param_dims, total_param_size = 0, 0, 0 | |||
| if log_params: | |||
| total_params = print_params_stats(params_list, bar_length_max) | |||
| total_param_dims, total_param_size = print_param_stats( | |||
| params_list, bar_length_max | |||
| ) | |||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) | |||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
| if log_flops: | |||
| total_flops = print_flops_stats(flops_list, bar_length_max) | |||
| total_flops = print_op_stats(flops_list, bar_length_max) | |||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
| if log_params and log_flops: | |||
| extra_info["flops/param_size"] = "{:3.3f}".format( | |||
| total_flops / total_param_size | |||
| ) | |||
| graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | |||
| if log_path: | |||
| graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | |||
| device = "/device:CPU:0" | |||
| stepstats = RunMetadata( | |||
| step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) | |||
| ) | |||
| writer._get_file_writer().add_graph((graph_def, stepstats)) | |||
| device = "/device:CPU:0" | |||
| stepstats = RunMetadata( | |||
| step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) | |||
| ) | |||
| writer = SummaryWriter(log_path) | |||
| writer._get_file_writer().add_graph((graph_def, stepstats)) | |||
| print_summary(**extra_info) | |||
| # FIXME: remove this after resolving "span dist too large" warning | |||
| _imperative_rt_logger.set_log_level(old_level) | |||
| return total_params, total_flops | |||
| return total_param_size, total_flops | |||
| def main(): | |||
| @@ -164,7 +188,7 @@ def main(): | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument("model_path", help="dumped model path.") | |||
| parser.add_argument("log_path", help="tensorboard log path.") | |||
| parser.add_argument("--log_path", help="tensorboard log path.") | |||
| parser.add_argument( | |||
| "--bar_length_max", | |||
| type=int, | |||
| @@ -179,7 +203,20 @@ def main(): | |||
| parser.add_argument( | |||
| "--log_flops", action="store_true", help="whether print and record op flops.", | |||
| ) | |||
| visualize(**vars(parser.parse_args())) | |||
| parser.add_argument( | |||
| "--all", | |||
| action="store_true", | |||
| help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.", | |||
| ) | |||
| args = parser.parse_args() | |||
| if args.all: | |||
| args.log_params = True | |||
| args.log_flops = True | |||
| if not args.log_path: | |||
| args.log_path = "./log" | |||
| kwargs = vars(args) | |||
| kwargs.pop("all") | |||
| visualize(**kwargs) | |||
| if __name__ == "__main__": | |||
| @@ -5,16 +5,17 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import contextlib | |||
| from functools import partial | |||
| import numpy as np | |||
| import tabulate | |||
| import megengine as mge | |||
| import megengine.core.tensor.dtype as dtype | |||
| import megengine.module as m | |||
| import megengine.module.qat as qatm | |||
| import megengine.module.quantized as qm | |||
| from megengine.core.tensor.dtype import get_dtype_bit | |||
| from megengine.functional.tensor import zeros | |||
| try: | |||
| @@ -26,61 +27,99 @@ logger = mge.get_logger(__name__) | |||
| logger.setLevel("INFO") | |||
| CALC_FLOPS = {} | |||
| _calc_flops_dict = {} | |||
| _calc_receptive_field_dict = {} | |||
| def _register_modules(*modules): | |||
| def _receptive_field_fallback(module, inputs, outputs): | |||
| if not _receptive_field_enabled: | |||
| return | |||
| assert not hasattr(module, "_rf") | |||
| assert not hasattr(module, "_stride") | |||
| if len(inputs) == 0: | |||
| # TODO: support other dimension | |||
| module._rf = (1, 1) | |||
| module._stride = (1, 1) | |||
| return module._rf, module._stride | |||
| rf, stride = preprocess_receptive_field(module, inputs, outputs) | |||
| module._rf = rf | |||
| module._stride = stride | |||
| return rf, stride | |||
| # key tuple, impl_dict, fallback | |||
| _iter_list = [ | |||
| ("flops_num", _calc_flops_dict, None), | |||
| ( | |||
| ("receptive_field", "stride"), | |||
| _calc_receptive_field_dict, | |||
| _receptive_field_fallback, | |||
| ), | |||
| ] | |||
| _receptive_field_enabled = False | |||
| def _register_dict(*modules, dict=None): | |||
| def callback(impl): | |||
| for module in modules: | |||
| CALC_FLOPS[module] = impl | |||
| dict[module] = impl | |||
| return impl | |||
| return callback | |||
| @_register_modules( | |||
| m.Conv2d, | |||
| m.ConvTranspose2d, | |||
| m.LocalConv2d, | |||
| qm.Conv2d, | |||
| qm.ConvRelu2d, | |||
| qm.ConvBn2d, | |||
| qm.ConvBnRelu2d, | |||
| qatm.Conv2d, | |||
| qatm.ConvRelu2d, | |||
| qatm.ConvBn2d, | |||
| qatm.ConvBnRelu2d, | |||
| def register_flops(*modules): | |||
| return _register_dict(*modules, dict=_calc_flops_dict) | |||
| def register_receptive_field(*modules): | |||
| return _register_dict(*modules, dict=_calc_receptive_field_dict) | |||
| def enable_receptive_field(): | |||
| global _receptive_field_enabled | |||
| _receptive_field_enabled = True | |||
| def disable_receptive_field(): | |||
| global _receptive_field_enabled | |||
| _receptive_field_enabled = False | |||
| @register_flops( | |||
| m.Conv1d, m.Conv2d, m.Conv3d, m.ConvTranspose2d, m.LocalConv2d, m.DeformableConv2d | |||
| ) | |||
| def count_convNd(module, input, output): | |||
| def flops_convNd(module: m.Conv2d, inputs, outputs): | |||
| bias = 1 if module.bias is not None else 0 | |||
| group = module.groups | |||
| ic = input[0].shape[1] | |||
| oc = output[0].shape[1] | |||
| goc = oc // group | |||
| gic = ic // group | |||
| N = output[0].shape[0] | |||
| HW = np.prod(output[0].shape[2:]) | |||
| # N x Cout x H x W x (Cin x Kw x Kh + bias) | |||
| return N * HW * goc * (gic * np.prod(module.kernel_size) + bias) | |||
| return np.prod(outputs[0].shape) * ( | |||
| module.in_channels // module.groups * np.prod(module.kernel_size) + bias | |||
| ) | |||
| @_register_modules(m.ConvTranspose2d) | |||
| def count_deconvNd(module, input, output): | |||
| return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size) | |||
| @register_flops(m.Linear) | |||
| def flops_linear(module: m.Linear, inputs, outputs): | |||
| bias = module.out_features if module.bias is not None else 0 | |||
| return np.prod(outputs[0].shape) * module.in_features + bias | |||
| @_register_modules(m.Linear, qatm.Linear, qm.Linear) | |||
| def count_linear(module, input, output): | |||
| return np.prod(output[0].shape) * module.in_features | |||
| @register_flops(m.BatchMatMulActivation) | |||
| def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs): | |||
| bias = 1 if module.bias is not None else 0 | |||
| x = inputs[0] | |||
| w = module.weight | |||
| batch_size = x.shape[0] | |||
| n, p = x.shape[1:] | |||
| _, m = w.shape[1:] | |||
| return n * (p + bias) * m * batch_size | |||
| # does not need import qat and quantized module since they inherit from float module. | |||
| hook_modules = ( | |||
| m.Conv2d, | |||
| m.ConvTranspose2d, | |||
| m.LocalConv2d, | |||
| m.BatchNorm2d, | |||
| m.conv._ConvNd, | |||
| m.Linear, | |||
| m.BatchMatMulActivation, | |||
| ) | |||
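
With the registry above, per-module FLOPs formulas become pluggable. A hedged example of registering a formula for a hypothetical module type (MyPool and the standalone registry here are illustrative; the real code uses the module-level _calc_flops_dict and megengine modules):

import numpy as np

_calc_flops_dict = {}

def register_flops(*modules):
    def callback(impl):
        for module in modules:
            _calc_flops_dict[module] = impl
        return impl
    return callback

class MyPool:
    kernel_size = (2, 2)

@register_flops(MyPool)
def flops_pool(module, inputs, outputs):
    # one multiply-accumulate per kernel element per output element
    return int(np.prod(outputs[0].shape) * np.prod(module.kernel_size))

stats = _calc_flops_dict[MyPool](MyPool(), [np.zeros((1, 3, 4, 4))], [np.zeros((1, 3, 2, 2))])
assert stats == 48
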
| @@ -106,22 +145,63 @@ def sizeof_fmt(num, suffix="B"): | |||
| return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) | |||
| def print_flops_stats(flops, bar_length_max=20): | |||
| flops_list = [i["flops_num"] for i in flops] | |||
| max_flops_num = max(flops_list + [0]) | |||
| # calc total flops and set flops_cum | |||
| def preprocess_receptive_field(module, inputs, outputs): | |||
| # TODO: support other dimensions | |||
| pre_rf = ( | |||
| max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs), | |||
| max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs), | |||
| ) | |||
| pre_stride = ( | |||
| max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs), | |||
| max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs), | |||
| ) | |||
| return pre_rf, pre_stride | |||
| def get_op_stats(module, inputs, outputs): | |||
| rst = { | |||
| "input_shapes": [i.shape for i in inputs], | |||
| "output_shapes": [o.shape for o in outputs], | |||
| } | |||
| valid_flag = False | |||
| for key, _dict, fallback in _iter_list: | |||
| for _type in _dict: | |||
| if isinstance(module, _type): | |||
| value = _dict[_type](module, inputs, outputs) | |||
| valid_flag = True | |||
| break | |||
| else: | |||
| if fallback is not None: | |||
| value = fallback(module, inputs, outputs) | |||
| continue | |||
| if isinstance(key, tuple): | |||
| assert isinstance(value, tuple) | |||
| for k, v in zip(key, value): | |||
| rst[k] = v | |||
| else: | |||
| rst[key] = value | |||
| if valid_flag: | |||
| return rst | |||
| else: | |||
| return None | |||
| return | |||
| def print_op_stats(flops, bar_length_max=20): | |||
| max_flops_num = max([i["flops_num"] for i in flops] + [0]) | |||
| total_flops_num = 0 | |||
| for d in flops: | |||
| total_flops_num += int(d["flops_num"]) | |||
| d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") | |||
| for i in flops: | |||
| f = i["flops_num"] | |||
| i["flops"] = sizeof_fmt(f, suffix="OPs") | |||
| r = i["ratio"] = f / total_flops_num | |||
| i["percentage"] = "{:.2f}%".format(r * 100) | |||
| bar_length = int(f / max_flops_num * bar_length_max) | |||
| i["bar"] = "#" * bar_length | |||
| for d in flops: | |||
| ratio = d["ratio"] = d["flops_num"] / total_flops_num | |||
| d["percentage"] = "{:.2f}%".format(ratio * 100) | |||
| bar_length = int(d["flops_num"] / max_flops_num * bar_length_max) | |||
| d["bar"] = "#" * bar_length | |||
| d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs") | |||
| header = [ | |||
| "name", | |||
| @@ -133,10 +213,13 @@ def print_flops_stats(flops, bar_length_max=20): | |||
| "percentage", | |||
| "bar", | |||
| ] | |||
| if _receptive_field_enabled: | |||
| header.insert(4, "receptive_field") | |||
| header.insert(5, "stride") | |||
| total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") | |||
| total_var_size = sum( | |||
| sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops | |||
| sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops | |||
| ) | |||
| flops.append( | |||
| dict(name="total", flops=total_flops_str, output_shapes=total_var_size) | |||
| @@ -147,30 +230,44 @@ def print_flops_stats(flops, bar_length_max=20): | |||
| return total_flops_num | |||
| def print_params_stats(params, bar_length_max=20): | |||
| def get_param_stats(param: np.ndarray): | |||
| nbits = get_dtype_bit(param.dtype.name) | |||
| shape = param.shape | |||
| param_dim = np.prod(param.shape) | |||
| param_size = param_dim * nbits // 8 | |||
| return { | |||
| "dtype": param.dtype, | |||
| "shape": shape, | |||
| "mean": "{:.3g}".format(param.mean()), | |||
| "std": "{:.3g}".format(param.std()), | |||
| "param_dim": param_dim, | |||
| "nbits": nbits, | |||
| "size": param_size, | |||
| } | |||
| def print_param_stats(params, bar_length_max=20): | |||
| max_size = max([d["size"] for d in params] + [0]) | |||
| total_param_dims, total_param_size = 0, 0 | |||
| for d in params: | |||
| total_param_dims += int(d["param_dim"]) | |||
| total_param_size += int(d["size"]) | |||
| d["size"] = sizeof_fmt(d["size"]) | |||
| d["size_cum"] = sizeof_fmt(total_param_size) | |||
| for d in params: | |||
| ratio = d["param_dim"] / total_param_dims | |||
| ratio = d["size"] / total_param_size | |||
| d["ratio"] = ratio | |||
| d["percentage"] = "{:.2f}%".format(ratio * 100) | |||
| # construct bar | |||
| max_ratio = max([d["ratio"] for d in params]) | |||
| for d in params: | |||
| bar_length = int(d["ratio"] / max_ratio * bar_length_max) | |||
| bar_length = int(d["size"] / max_size * bar_length_max) | |||
| d["size_bar"] = "#" * bar_length | |||
| d["size"] = sizeof_fmt(d["size"]) | |||
| param_size = sizeof_fmt(total_param_size) | |||
| params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) | |||
| header = [ | |||
| "name", | |||
| "dtype", | |||
| "shape", | |||
| "mean", | |||
| "std", | |||
| @@ -186,7 +283,13 @@ def print_params_stats(params, bar_length_max=20): | |||
| "param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) | |||
| ) | |||
| return total_param_size | |||
| return total_param_dims, total_param_size | |||
| def print_summary(**kwargs): | |||
| data = [["item", "value"]] | |||
| data.extend(list(kwargs.items())) | |||
| logger.info("summary\n" + tabulate.tabulate(data)) | |||
| def module_stats( | |||
| @@ -205,71 +308,53 @@ def module_stats( | |||
| :param log_params: whether print and record params size. | |||
| :param log_flops: whether print and record op flops. | |||
| """ | |||
| disable_receptive_field() | |||
| def get_byteswidth(tensor): | |||
| if dtype.is_quantize(tensor.dtype): | |||
| return 1 | |||
| # elif dtype.is_bfloat16(tensor.dtype): | |||
| # return 2 | |||
| else: | |||
| return 4 | |||
| def module_stats_hook(module, input, output, name=""): | |||
| def module_stats_hook(module, inputs, outputs, name=""): | |||
| class_name = str(module.__class__).split(".")[-1].split("'")[0] | |||
| flops_fun = CALC_FLOPS.get(type(module)) | |||
| if callable(flops_fun): | |||
| flops_num = flops_fun(module, input, output) | |||
| if not isinstance(output, (list, tuple)): | |||
| output = [output] | |||
| flops.append( | |||
| dict( | |||
| name=name, | |||
| class_name=class_name, | |||
| input_shapes=[i.shape for i in input], | |||
| output_shapes=[o.shape for o in output], | |||
| flops_num=flops_num, | |||
| flops_cum=0, | |||
| ) | |||
| ) | |||
| flops_stats = get_op_stats(module, inputs, outputs) | |||
| if flops_stats is not None: | |||
| flops_stats["name"] = name | |||
| flops_stats["class_name"] = class_name | |||
| flops.append(flops_stats) | |||
| if hasattr(module, "weight") and module.weight is not None: | |||
| w = module.weight | |||
| value = w.numpy() | |||
| param_dim = np.prod(w.shape) | |||
| param_bytes = get_byteswidth(w) | |||
| params.append( | |||
| dict( | |||
| name=name + "-w", | |||
| shape=w.shape, | |||
| param_dim=param_dim, | |||
| bits=param_bytes * 8, | |||
| size=param_dim * param_bytes, | |||
| size_cum=0, | |||
| mean="{:.2g}".format(value.mean()), | |||
| std="{:.2g}".format(value.std()), | |||
| ) | |||
| ) | |||
| param_stats = get_param_stats(w.numpy()) | |||
| param_stats["name"] = name + "-w" | |||
| params.append(param_stats) | |||
| if hasattr(module, "bias") and module.bias is not None: | |||
| b = module.bias | |||
| value = b.numpy() | |||
| param_dim = np.prod(b.shape) | |||
| param_bytes = get_byteswidth(b) | |||
| params.append( | |||
| dict( | |||
| name=name + "-b", | |||
| shape=b.shape, | |||
| param_dim=param_dim, | |||
| bits=param_bytes * 8, | |||
| size=param_dim * param_bytes, | |||
| size_cum=0, | |||
| mean="{:.2g}".format(value.mean()), | |||
| std="{:.2g}".format(value.std()), | |||
| ) | |||
| ) | |||
| param_stats = get_param_stats(b.numpy()) | |||
| param_stats["name"] = name + "-b" | |||
| params.append(param_stats) | |||
| @contextlib.contextmanager | |||
| def adjust_stats(module, training=False): | |||
| """Adjust module to training/eval mode temporarily. | |||
| Args: | |||
| module (M.Module): used module. | |||
| training (bool): training mode. True for train mode, False for eval mode. | |||
| """ | |||
| def recursive_backup_stats(module, mode): | |||
| for m in module.modules(): | |||
| # save prev status to _prev_training | |||
| m._prev_training = m.training | |||
| m.train(mode, recursive=False) | |||
| def recursive_recover_stats(module): | |||
| for m in module.modules(): | |||
| # recover prev status and delete attribute | |||
| m.training = m._prev_training | |||
| delattr(m, "_prev_training") | |||
| recursive_backup_stats(module, mode=training) | |||
| yield module | |||
| recursive_recover_stats(module) | |||
| # multiple inputs to the network | |||
| if not isinstance(input_size[0], tuple): | |||
| @@ -286,15 +371,28 @@ def module_stats( | |||
| ) | |||
| inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] | |||
| model.eval() | |||
| model(*inputs) | |||
| with adjust_stats(model, training=False) as model: | |||
| model(*inputs) | |||
| for h in hooks: | |||
| h.remove() | |||
| total_flops, total_params = 0, 0 | |||
| extra_info = { | |||
| "#params": len(params), | |||
| } | |||
| total_flops, total_param_dims, total_param_size = 0, 0, 0 | |||
| if log_params: | |||
| total_params = print_params_stats(params, bar_length_max) | |||
| total_param_dims, total_param_size = print_param_stats(params, bar_length_max) | |||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) | |||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
| if log_flops: | |||
| total_flops = print_flops_stats(flops, bar_length_max) | |||
| total_flops = print_op_stats(flops, bar_length_max) | |||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
| if log_params and log_flops: | |||
| extra_info["flops/param_size"] = "{:3.3f}".format( | |||
| total_flops / total_param_size | |||
| ) | |||
| print_summary(**extra_info) | |||
| return total_params, total_flops | |||
| return total_param_size, total_flops | |||
| @@ -11,12 +11,14 @@ import fnmatch | |||
| import itertools | |||
| import re | |||
| from collections import OrderedDict | |||
| from typing import Dict, List | |||
| from typing import Dict, List, Sequence | |||
| import numpy as np | |||
| from ..core._imperative_rt import ComputingGraph | |||
| from ..core._imperative_rt.core2 import SymbolVar | |||
| from ..core.tensor import megbrain_graph as G | |||
| from ..logger import get_logger | |||
| from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | |||
| from .network_node import ( | |||
| Host2DeviceCopy, | |||
| @@ -27,6 +29,8 @@ from .network_node import ( | |||
| str_to_mge_class, | |||
| ) | |||
| logger = get_logger(__name__) | |||
| class Network: | |||
| def __init__(self): | |||
| @@ -60,12 +64,12 @@ class Network: | |||
| ) | |||
| outputs = [new_outputs[i] for i in outspec] | |||
| self._orig_outputs = outputs | |||
| self.add_dep_oprs(*outputs) | |||
| for x in self._orig_outputs: | |||
| self.output_vars.append(self._get_var(x)) | |||
| self.add_dep_oprs() | |||
| for x in self._orig_inputs: | |||
| self.input_vars.append(self._get_var(x)) | |||
| for x in self._orig_outputs: | |||
| self.output_vars.append(self._get_var(x)) | |||
| self.graph = self._orig_outputs[0].graph | |||
| return self | |||
| @@ -83,6 +87,58 @@ class Network: | |||
| for o in opr.outputs: | |||
| self.all_vars_map[o.var.id] = o | |||
| def optimize_for_inference(self, dest_vars, **kwargs): | |||
| r""" | |||
| Applies optimize_for_inference pass for operator graph. | |||
| :param dest_vars: list of output vars in the operator graph | |||
| :Keyword Arguments: | |||
| * enable_io16xc32 -- | |||
| whether to use float16 for I/O between oprs and use | |||
| float32 as internal computation precision. Note the output var would be | |||
| changed to float16. | |||
| * enable_ioc16 -- | |||
| whether to use float16 for both I/O and computation | |||
| precision. | |||
| * enable_hwcd4 -- | |||
| whether to use NHWCD4 data layout. This is faster on some | |||
| OpenCL backends. | |||
| * enable_nchw88 -- | |||
| whether to use NCHW88 data layout, currently | |||
| used in X86 AVX backend. | |||
| * enable_nchw44 -- | |||
| whether to use NCHW44 data layout, currently | |||
| used in arm backend. | |||
| * enable_nchw44_dot -- | |||
| whether to use NCHW44_dot data layout, currently | |||
| used in armv8.2+dotprod backend. | |||
| * enable_nchw4 -- | |||
| whether to use NCHW4 data layout, currently | |||
| used in nvidia backend (based on cudnn). | |||
| * enable_nchw32 -- | |||
| whether to use NCHW32 data layout, currently | |||
| used in nvidia backend with tensorcore (based on cudnn). | |||
| * enable_chwn4 -- | |||
| whether to use CHWN4 data layout, currently | |||
| used in nvidia backend with tensorcore. | |||
| * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity | |||
| into one opr. | |||
| * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||
| input for inference on nvidia backend (this optimization pass will | |||
| result in a precision mismatch between the outputs of training and | |||
| inference) | |||
| """ | |||
| if not isinstance(dest_vars, Sequence): | |||
| dest_vars = [dest_vars] | |||
| dest_vars = list(G.VarNode(var.var) for var in dest_vars) | |||
| new_vars = G.optimize_for_inference(dest_vars, **kwargs) | |||
| return list(self._get_var(var) for var in new_vars) | |||
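A sketch of the intended call pattern on a loaded Network (the file name is hypothetical):

```python
from megengine.utils.network import Network

net = Network.load("model.mge")          # hypothetical dumped model
new_outputs = net.optimize_for_inference(net.output_vars, enable_io16xc32=True)
# new_outputs are VarNodes of this Network; the original output vars are untouched
```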
| def dump( | |||
| self, | |||
| file, | |||
| @@ -122,47 +178,22 @@ class Network: | |||
| :Keyword Arguments: | |||
| * enable_io16xc32 -- | |||
| whether to use float16 for I/O between oprs and use | |||
| float32 as internal computation precision. Note the output var would be | |||
| changed to float16. | |||
| * enable_ioc16 -- | |||
| whether to use float16 for both I/O and computation | |||
| precision. | |||
| * enable_hwcd4 -- | |||
| whether to use NHWCD4 data layout. This is faster on some | |||
| OpenCL backend. | |||
| * enable_nchw88 -- | |||
| whether to use NCHW88 data layout, currently | |||
| used in X86 AVX backend. | |||
| * enable_nchw44 -- | |||
| whether to use NCHW44 data layout, currently | |||
| used in arm backend. | |||
| * enable_nchw44_dot -- | |||
| whether to use NCHW44_dot data layout, currently | |||
| used in armv8.2+dotprod backend. | |||
| * enable_nchw4 -- | |||
| whether to use NCHW4 data layout, currently | |||
| used in nvidia backend(based on cudnn). | |||
| * enable_nchw32 -- | |||
| whether to use NCHW32 data layout, currently | |||
| used in nvidia backend with tensorcore(based on cudnn). | |||
| * enable_chwn4 -- | |||
| whether to use CHWN4 data layout, currently | |||
| used in nvidia backend with tensorcore. | |||
| * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||
| into one opr. | |||
| * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||
| input for inference on nvidia backend(this optimization pass will | |||
| result in mismatch of the precision of output of training and | |||
| inference) | |||
| See also :py:meth:`optimize_for_inference`. | |||
| """ | |||
| self._compile() | |||
| out = [G.VarNode(var.var) for var in self.output_vars] | |||
| if kwargs.pop("arg_names", False): | |||
| logger.warning( | |||
| '"arg_names" is not supported in Network.dump, rename input vars directly' | |||
| ) | |||
| if kwargs.pop("output_names", False): | |||
| logger.warning( | |||
| '"output_names" is not supported in Network.dump, rename output vars directly' | |||
| ) | |||
| if optimize_for_inference: | |||
| out = G.optimize_for_inference(out, **kwargs) | |||
| @@ -197,6 +228,8 @@ class Network: | |||
| def add_output(self, *vars: VarNode): | |||
| """Adds vars into the network output node list | |||
| """ | |||
| if not all([var.owner for var in vars]): | |||
| self.add_dep_oprs(*vars) | |||
| for var in vars: | |||
| if var not in self.output_vars: | |||
| self.output_vars.append(var) | |||
| @@ -209,21 +242,25 @@ class Network: | |||
| self.output_vars.remove(var) | |||
| def add_dep_oprs(self, *vars): | |||
| """Adds dependent opnodes and varnodes of vars into network | |||
| """ | |||
| oprs = get_oprs_seq(vars, False, False) | |||
| for mge_opr in oprs: | |||
| if len(vars) == 0: | |||
| vars = self.output_vars | |||
| q = list(vars) | |||
| while len(q) > 0: | |||
| cur = q.pop(0) | |||
| if cur.owner is not None: | |||
| continue | |||
| if cur.name is None: | |||
| cur.name = cur.var.name | |||
| self.all_vars_map[cur.var.id] = cur | |||
| mge_opr = cur.var.owner | |||
| if get_opr_type(mge_opr) == "Host2DeviceCopy": | |||
| self._orig_inputs.extend(mge_opr.outputs) | |||
| opr = self._add_opr(mge_opr) | |||
| if opr is not None: | |||
| for x in mge_opr.inputs: | |||
| opr.add_inp_var(self._get_var(x)) | |||
| # set out var | |||
| for x in mge_opr.outputs: | |||
| opr.add_out_var(self._get_var(x)) | |||
| return [self.all_vars_map[var.id] for var in vars] | |||
| cur.owner = self._add_opr(mge_opr) | |||
| if cur.owner is None: | |||
| cur.owner = self.all_oprs_map[mge_opr.id] | |||
| continue | |||
| q.extend(cur.owner.inputs) | |||
| return list(vars) | |||
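The rewritten traversal walks backward from vars whose `owner` is not yet set. A sketch of the typical call after building new vars with functional ops on existing VarNodes (the `net` object and var names are illustrative):

```python
import megengine.functional as F

vara = net.var_filter.name("a").as_unique()
varb = net.var_filter.name("b").as_unique()
out = F.relu(F.mul(vara, varb))   # new VarNode; its owner OpNode is not attached yet
net.add_dep_oprs(out)             # pulls the new OpNodes/VarNodes into the Network
```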
| def modify_opr_names(self, modifier): | |||
| """Modifies names of operators **inplace**; useful for merging loaded | |||
| @@ -275,6 +312,9 @@ class Network: | |||
| Replaces vars in the graph. | |||
| :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | |||
| """ | |||
| if not all([var.owner for var in repl_dict.values()]): | |||
| self.add_dep_oprs(*list(repl_dict.values())) | |||
| for var in self.all_vars: | |||
| if var in repl_dict: | |||
| repl_var = repl_dict[var] | |||
| @@ -282,6 +322,7 @@ class Network: | |||
| idx = owner.outputs.index(repl_var) | |||
| owner.outputs[idx] = var | |||
| var.__dict__.update(repl_var.__dict__) | |||
| var.var = repl_var.var | |||
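With the owner check above, replacement vars built from functional ops no longer need an explicit `add_dep_oprs` call first; this mirrors the updated `test_replace_var` later in this diff. A short sketch (names are illustrative, continuing the previous snippet):

```python
old_out = net.opr_filter.has_input(vara).as_unique().outputs[0]
new_out = F.relu(F.mul(vara, varb))      # owner not set; resolved inside replace_vars
net.replace_vars({old_out: new_out})
```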
| def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | |||
| """ | |||
| @@ -297,6 +338,7 @@ class Network: | |||
| for ind, var in enumerate(opr.outputs): | |||
| var.owner = repl_dict[opr] | |||
| var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||
| var.var = repl_dict[opr].outputs[ind].var | |||
| def get_opr_by_type(self, oprcls, unique=True): | |||
| assert issubclass(oprcls, OpNode) | |||
| @@ -381,11 +423,16 @@ class Network: | |||
| return self.opr_filter.as_dict() | |||
| # used for loading and building graph | |||
| def _add_opr(self, x): | |||
| def _add_opr(self, opr): | |||
| # TODO: use megbrain C++ RTTI to replace type string | |||
| if x.id not in self.all_oprs_map: | |||
| self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) | |||
| return self.all_oprs_map[x.id] | |||
| if opr.id not in self.all_oprs_map: | |||
| opnode = str_to_mge_class(get_opr_type(opr)).load(opr) | |||
| self.all_oprs_map[opr.id] = opnode | |||
| for var in opr.inputs: | |||
| opnode.add_inp_var(self._get_var(var)) | |||
| for var in opr.outputs: | |||
| opnode.add_out_var(self._get_var(var)) | |||
| return opnode | |||
| else: | |||
| return None | |||
| @@ -397,7 +444,7 @@ class Network: | |||
| def _get_var(self, x): | |||
| # auto convert to VarNode of Network | |||
| if x.id not in self.all_vars_map: | |||
| if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: | |||
| self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | |||
| return self.all_vars_map[x.id] | |||
| @@ -652,7 +699,7 @@ class NodeFilterHasInput(NodeFilter): | |||
| assert isinstance( | |||
| i, OpNode | |||
| ), "has_input() must be used with OpNode; " "got {!r}".format(i) | |||
| if self.var in i.inputs: | |||
| if any(self.var is _ for _ in i.inputs): | |||
| yield i | |||
| @@ -663,6 +710,7 @@ class NodeFilterName(NodeFilter): | |||
| def __init__(self, node_iter, pattern, ignorecase): | |||
| super().__init__(node_iter) | |||
| self.pattern = pattern | |||
| self._re = self.make_re(pattern, ignorecase) | |||
| @classmethod | |||
| @@ -676,5 +724,5 @@ class NodeFilterName(NodeFilter): | |||
| def __iter__(self): | |||
| for i in self._iter: | |||
| if self._re.match(i.name): | |||
| if self.pattern == i.name or self._re.match(i.name): | |||
| yield i | |||
| @@ -6,27 +6,41 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import abc | |||
| import json | |||
| import sys | |||
| from typing import Callable | |||
| from typing import Callable, Sequence | |||
| import numpy as np | |||
| from ..core import _imperative_rt as rt | |||
| from ..core._imperative_rt.core2 import SymbolVar | |||
| from ..core._wrap import Device | |||
| from ..core.ops import builtin | |||
| from ..core.tensor.megbrain_graph import InputNode | |||
| from ..core.tensor.array_method import ArrayMethodMixin | |||
| from ..core.tensor.indexing import getitem as _getitem | |||
| from ..core.tensor.indexing import setitem as _setitem | |||
| from ..core.tensor.megbrain_graph import InputNode, OutputNode | |||
| from ..tensor import Tensor | |||
| from .comp_graph_tools import replace_vars | |||
| from .module_stats import ( | |||
| preprocess_receptive_field, | |||
| register_flops, | |||
| register_receptive_field, | |||
| ) | |||
| class NetworkNode: | |||
| pass | |||
| class VarNode(NetworkNode): | |||
| def __init__(self, owner_opr=None, name=None): | |||
| self.var = None | |||
| class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)): | |||
| pass | |||
| class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
| def __init__(self, var=None, *, owner_opr=None, name=None): | |||
| SymbolVar.__init__(self, var) | |||
| self.owner = owner_opr | |||
| self.name = name | |||
| self.id = id(self) | |||
| @@ -53,6 +67,40 @@ class VarNode(NetworkNode): | |||
| def dtype(self): | |||
| return self.var.dtype if self.var else None | |||
| def __bool__(self): | |||
| return False | |||
| __index__ = None | |||
| __int__ = None | |||
| __float__ = None | |||
| __complex__ = None | |||
| def __hash__(self): | |||
| return id(self) | |||
| @property | |||
| def _tuple_shape(self): | |||
| return self.var.shape | |||
| def numpy(self): | |||
| o = OutputNode(self.var) | |||
| self.graph.compile(o.outputs).execute() | |||
| return o.get_value().numpy() | |||
| def __getitem__(self, index): | |||
| return _getitem(self, index) | |||
| def __setitem__(self, index, value): | |||
| if index is not Ellipsis: | |||
| value = _setitem(self, index, value) | |||
| if self.owner is not None: | |||
| idx = self.owner.outputs.index(self) | |||
| self.owner.outputs[idx] = VarNode( | |||
| self.var, owner_opr=self.owner, name=self.var.name | |||
| ) | |||
| self.var = value.var | |||
| self.owner = None | |||
| def set_owner_opr(self, owner_opr): | |||
| self.owner = owner_opr | |||
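With `SymbolVar` and `ArrayMethodMixin` as bases, a VarNode can now be inspected much like a Tensor. A minimal sketch (the var name and the `net` Network are hypothetical):

```python
v = net.var_filter.name("conv1_out").as_unique()
print(v.dtype)
value = v.numpy()   # compiles an OutputNode on the owning graph and executes it
first = v[0]        # indexing is routed through the same getitem path as Tensor
```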
| @@ -130,7 +178,7 @@ class Host2DeviceCopy(OpNode): | |||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
| self._opr = outputs.owner | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(self, self.name)) | |||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
| self.outputs[0].var = outputs | |||
| assert self.outputs[0].owner is self | |||
| @@ -168,8 +216,8 @@ class ImmutableTensor(OpNode): | |||
| def set_value(self, data, device=None): | |||
| assert self.graph is not None | |||
| cn = device if device else self.device | |||
| assert isinstance(data, (int, float, np.ndarray)) | |||
| if isinstance(data, (int, float)): | |||
| assert isinstance(data, (int, float, Sequence, np.ndarray)) | |||
| if not isinstance(data, np.ndarray): | |||
| data = np.array(data) | |||
| if data.dtype == np.float64: | |||
| data = data.astype(np.float32) | |||
| @@ -177,7 +225,7 @@ class ImmutableTensor(OpNode): | |||
| data = data.astype(np.int32) | |||
| varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(self, self.name)) | |||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
| self.outputs[0].var = varnode | |||
| self._opr = varnode.owner | |||
| @@ -225,8 +273,21 @@ class Elemwise(OpNode): | |||
| type = "Elemwise" | |||
| opdef = builtin.Elemwise | |||
| def calc_flops(self): | |||
| return np.prod(self.outputs[0].shape) | |||
| class ElemwiseMultiType(OpNode): | |||
| type = "ElemwiseMultiType" | |||
| opdef = builtin.ElemwiseMultiType | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(ElemwiseMultiType, cls).load(opr) | |||
| obj.params["dtype"] = opr.outputs[0].dtype | |||
| return obj | |||
| @register_flops(Elemwise, ElemwiseMultiType) | |||
| def flops_elemwise(opnode: Elemwise, inputs, outputs): | |||
| return np.prod(outputs[0].shape) | |||
| class Reduce(OpNode): | |||
| @@ -255,20 +316,24 @@ class MatrixMul(OpNode): | |||
| type = "MatrixMul" | |||
| opdef = builtin.MatrixMul | |||
| def calc_flops(self): | |||
| assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2 | |||
| mid_shape = self.inputs[0].shape[1] | |||
| return np.prod(self.outputs[0].shape) * mid_shape | |||
| @register_flops(MatrixMul) | |||
| def flops_matmul(opnode: MatrixMul, inputs, outputs): | |||
| assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2 | |||
| mid_shape = inputs[0].shape[1] | |||
| return np.prod(outputs[0].shape) * mid_shape | |||
| class BatchedMatrixMul(OpNode): | |||
| type = "BatchedMatmul" | |||
| opdef = builtin.BatchedMatrixMul | |||
| def calc_flops(self): | |||
| assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3 | |||
| mid_shape = self.inputs[0].shape[2] | |||
| return np.prod(self.outputs[0].shape) * mid_shape | |||
| @register_flops(BatchedMatrixMul) | |||
| def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs): | |||
| assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3 | |||
| mid_shape = inputs[0].shape[2] | |||
| return np.prod(outputs[0].shape) * mid_shape | |||
| class Dot(OpNode): | |||
| @@ -285,18 +350,6 @@ class ConvolutionForward(OpNode): | |||
| type = "Convolution" | |||
| opdef = builtin.Convolution | |||
| def calc_flops(self): | |||
| param_W_shape = self.inputs[1].shape | |||
| kh = param_W_shape[-2] | |||
| kw = param_W_shape[-1] | |||
| if len(param_W_shape) == 5: | |||
| num_input = param_W_shape[2] | |||
| else: | |||
| num_input = param_W_shape[1] | |||
| NCHW = np.prod(self.outputs[0].shape) | |||
| # N x Cout x H x W x (Cin x Kw x Kh) | |||
| return NCHW * (num_input * kw * kh) | |||
| class ConvolutionBackwardData(OpNode): | |||
| type = "ConvTranspose" | |||
| @@ -343,17 +396,41 @@ class ConvBiasForward(OpNode): | |||
| obj.params["dtype"] = opr.outputs[0].dtype | |||
| return obj | |||
| def calc_flops(self): | |||
| param_W_shape = self.inputs[1].shape | |||
| kh = param_W_shape[-2] | |||
| kw = param_W_shape[-1] | |||
| if len(param_W_shape) == 5: | |||
| num_input = param_W_shape[2] | |||
| else: | |||
| num_input = param_W_shape[1] | |||
| NCHW = np.prod(self.outputs[0].shape) | |||
| # N x Cout x H x W x (Cin x Kw x Kh + bias) | |||
| return NCHW * (num_input * kw * kh + 1) | |||
| @register_flops( | |||
| ConvolutionForward, ConvBiasForward, | |||
| ) | |||
| def flops_conv(opnode: ConvolutionForward, inputs, outputs): | |||
| param_W_shape = inputs[1].shape | |||
| kh = param_W_shape[-2] | |||
| kw = param_W_shape[-1] | |||
| if len(param_W_shape) == 5: | |||
| num_input = param_W_shape[2] | |||
| else: | |||
| num_input = param_W_shape[1] | |||
| NCHW = np.prod(outputs[0].shape) | |||
| bias = 1 if isinstance(opnode, ConvBiasForward) else 0 | |||
| # N x Cout x H x W x (Cin x Kw x Kh + bias) | |||
| return NCHW * (num_input * kw * kh + bias) | |||
| @register_receptive_field(ConvolutionForward, ConvBiasForward) | |||
| def receptive_field(opnode: ConvolutionForward, inputs, outputs): | |||
| pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs) | |||
| param_W_shape = inputs[1].shape | |||
| kh = param_W_shape[-2] | |||
| kw = param_W_shape[-1] | |||
| rf = ( | |||
| kh * pre_stride[0] + pre_rf[0] - pre_stride[0], | |||
| kw * pre_stride[1] + pre_rf[1] - pre_stride[1], | |||
| ) | |||
| stride = ( | |||
| opnode.params["stride_h"] * pre_stride[0], | |||
| opnode.params["stride_w"] * pre_stride[1], | |||
| ) | |||
| opnode._rf = rf | |||
| opnode._stride = stride | |||
| return rf, stride | |||
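These decorators replace the per-class `calc_flops` methods with a registry. A hedged sketch of registering a formula for another OpNode subclass in this module (the choice of op and the formula are purely illustrative):

```python
@register_flops(Dot)
def flops_dot(opnode: Dot, inputs, outputs):
    # roughly one multiply-accumulate per element of either input vector
    return np.prod(inputs[0].shape)
```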
| class BatchConvBiasForward(OpNode): | |||
| @@ -652,20 +729,6 @@ class AssertEqual(OpNode): | |||
| opdef = builtin.AssertEqual | |||
| class ElemwiseMultiType(OpNode): | |||
| type = "ElemwiseMultiType" | |||
| opdef = builtin.ElemwiseMultiType | |||
| @classmethod | |||
| def load(cls, opr): | |||
| obj = super(ElemwiseMultiType, cls).load(opr) | |||
| obj.params["dtype"] = opr.outputs[0].dtype | |||
| return obj | |||
| def calc_flops(self): | |||
| return np.prod(self.outputs[0].shape) | |||
| class CvtColorForward(OpNode): | |||
| type = "CvtColor" | |||
| opdef = builtin.CvtColor | |||
| @@ -266,7 +266,7 @@ void init_graph_rt(py::module m) { | |||
| {"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }}, | |||
| {"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }}, | |||
| {"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }}, | |||
| {"OPTMIZED", [&]() { stg = _AlgoStrategy::OPTMIZED; }}, | |||
| {"OPTIMIZED", [&]() { stg = _AlgoStrategy::OPTIMIZED; }}, | |||
| }; | |||
| auto it = m.find(strategy); | |||
| mgb_assert(it != m.end(), "Invalid strategy string!"); | |||
| @@ -87,9 +87,13 @@ struct pyobj_convert_generic { | |||
| } | |||
| }; | |||
| template<typename T, typename SFINAE=void> | |||
| struct EnumTrait; | |||
| template <typename T> | |||
| struct EnumTrait { | |||
| struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> { | |||
| static constexpr bool is_bit_combined = false; | |||
| static constexpr std::underlying_type_t<T> max = 0; | |||
| }; | |||
| template <typename T> | |||
| @@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper { | |||
| return ret; | |||
| } | |||
| } | |||
| static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { | |||
| PyObject* obj = type->tp_alloc(type, 0); | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); | |||
| return obj; | |||
| } | |||
| static int py_init(PyObject* self, PyObject* args, PyObject*) { | |||
| int input = 1; | |||
| if (PyArg_ParseTuple(args, "|i", &input)){ | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = | |||
| static_cast<T>(input); | |||
| static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) { | |||
| if (!PyTuple_Size(args)) { | |||
| PyObject* obj = type->tp_alloc(type, 0); | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T(); | |||
| return obj; | |||
| } | |||
| else { | |||
| PyObject* input; | |||
| if (!PyArg_ParseTuple(args, "|O", &input)) { | |||
| return nullptr; | |||
| } | |||
| T value; | |||
| try { | |||
| value = pyobj_convert_generic<T>::from(input); | |||
| } CATCH_ALL(nullptr); | |||
| PyObject* obj = type->tp_alloc(type, 0); | |||
| reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value; | |||
| return obj; | |||
| } | |||
| return 0; | |||
| } | |||
| static PyObject* py_repr(PyObject* self) { | |||
| return pyobj_convert_generic<std::string>::to( | |||
| @@ -325,6 +336,12 @@ struct pyobj_convert_generic<T, | |||
| static T from(PyObject* obj) { | |||
| if (PyObject_TypeCheck(obj, &Wrapper::type)) { | |||
| return reinterpret_cast<Wrapper*>(obj)->value; | |||
| } else if(PyLong_Check(obj)) { | |||
| auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj); | |||
| mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError, | |||
| "out of range, cannot convert %zu to %s", | |||
| static_cast<size_t>(value), Wrapper::name); | |||
| return static_cast<T>(value); | |||
| } | |||
| // try as string | |||
| // TODO: type check | |||
| @@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
| if (ctx.op->same_type<BackwardGraph>()) { | |||
| ctx.backward = true; | |||
| } | |||
| if (py::isinstance<cg::VarNode>(py::handle(args[0]))){ | |||
| SmallVector<cg::VarNode*> vinputs(nargs); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>(); | |||
| } | |||
| auto op = ctx.op.get(); | |||
| return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr(); | |||
| } | |||
| if (py::isinstance<PySymbolVar>(py::handle(args[0]))){ | |||
| SmallVector<cg::VarNode*> vinputs(nargs); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node; | |||
| } | |||
| auto op = ctx.op.get(); | |||
| auto rst = OpDef::apply_on_var_node(*op, vinputs); | |||
| auto ret = pybind11::tuple(rst.size()); | |||
| auto typeobj = py::handle(args[0]).get_type(); | |||
| for (size_t i = 0; i<rst.size(); ++i) { | |||
| ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic)); | |||
| } | |||
| return ret.release().ptr(); | |||
| } | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | |||
| @@ -686,9 +691,9 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||
| continue; | |||
| } | |||
| if (py::isinstance<cg::VarNode>(py::handle(handle))){ | |||
| auto var = py::handle(handle).cast<cg::VarNode *>(); | |||
| mgb::DType type = var->dtype(); | |||
| if (py::isinstance<PySymbolVar>(py::handle(handle))){ | |||
| auto var = py::handle(handle).cast<PySymbolVar*>(); | |||
| mgb::DType type = var->m_node->dtype(); | |||
| auto && descr = npy::dtype_mgb2np_descr(type); | |||
| Py_INCREF(descr.get()); | |||
| tensors.emplace_back(descr.get()); | |||
| @@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { | |||
| bool valid = false; | |||
| CompNode cn; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; | |||
| TensorWrapper* tw = TensorWrapper::try_cast(handle); | |||
| bool is_var = py::isinstance<cg::VarNode>(py::handle(handle)); | |||
| if (tw || is_var) { | |||
| bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle)); | |||
| if (tw || is_symvar) { | |||
| if (!valid) { | |||
| cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
| cn = tw ? tw->m_tensor->comp_node() | |||
| : py::handle(handle) | |||
| .cast<PySymbolVar*>() | |||
| ->m_node->comp_node(); | |||
| valid = true; | |||
| } else { | |||
| CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
| CompNode cn1 = tw ? tw->m_tensor->comp_node() | |||
| : py::handle(handle) | |||
| .cast<PySymbolVar*>() | |||
| ->m_node->comp_node(); | |||
| if (cn1 != cn) { | |||
| throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||
| cn.to_string().c_str(), cn1.to_string().c_str())); | |||
| cn.to_string().c_str(), | |||
| cn1.to_string().c_str())); | |||
| } | |||
| } | |||
| } | |||
| @@ -849,6 +861,32 @@ void init_tensor(py::module m) { | |||
| .def("__call__", &TensorWeakRef::operator()) | |||
| .def("_use_cnt", &TensorWeakRef::_use_cnt); | |||
| py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar") | |||
| .def_property_readonly( | |||
| "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); }) | |||
| .def_property("var", [](PySymbolVar* v) { return v->m_node; }, | |||
| [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; }) | |||
| .def_property_readonly( | |||
| "device", | |||
| [](PySymbolVar* v) { return v->m_node->comp_node(); }) | |||
| .def_property_readonly( | |||
| "graph", | |||
| [](PySymbolVar* v) { return v->m_node->owner_graph(); }) | |||
| .def_property_readonly( | |||
| "shape", | |||
| [](PySymbolVar* v) -> const TensorShape* { | |||
| auto&& mgr = v->m_node->owner_graph() | |||
| ->static_infer_manager(); | |||
| return mgr.infer_shape_fallible(v->m_node); | |||
| }) | |||
| .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | |||
| .def("_setscalar", | |||
| [](PySymbolVar* v) { return v->is_scalar = true; }) | |||
| .def(py::init([](cg::VarNode* node) { | |||
| return std::make_shared<PySymbolVar>(node); | |||
| }), | |||
| py::arg() = nullptr); | |||
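On the Python side this binding is what the new `VarNode` in network_node.py builds on. A hedged sketch of the exposed surface (obtaining `raw_var` from a loaded graph is left abstract here):

```python
from megengine.core._imperative_rt.core2 import SymbolVar

sv = SymbolVar(raw_var)             # raw_var: a cg::VarNode from a loaded graph (hypothetical)
print(sv.dtype, sv.device, sv.shape)
sv.var = raw_var                    # the wrapped VarNode is a read/write property
```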
| static PyMethodDef method_defs[] = { | |||
| MGE_PY_INTERFACE(apply, py_apply), | |||
| MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), | |||
| @@ -181,6 +181,12 @@ struct TensorWrapper { | |||
| PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | |||
| }; | |||
| struct PySymbolVar { | |||
| cg::VarNode* m_node = nullptr; | |||
| bool is_scalar = false; | |||
| PySymbolVar() = default; | |||
| PySymbolVar(VarNode *m): m_node(m){} | |||
| }; | |||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | |||
| @@ -2,9 +2,11 @@ import io | |||
| import numpy as np | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine import tensor | |||
| from megengine.jit import trace | |||
| from megengine.utils.network_node import VarNode | |||
| def _default_compare_fn(x, y): | |||
| @@ -14,8 +16,23 @@ def _default_compare_fn(x, y): | |||
| np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
| def make_tensor(x, network=None, device=None): | |||
| if network is not None: | |||
| if isinstance(x, VarNode): | |||
| return VarNode(x.var) | |||
| return network.make_const(x, device=device) | |||
| else: | |||
| return tensor(x, device=device) | |||
| def opr_test( | |||
| cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||
| cases, | |||
| func, | |||
| compare_fn=_default_compare_fn, | |||
| ref_fn=None, | |||
| test_trace=True, | |||
| network=None, | |||
| **kwargs | |||
| ): | |||
| """ | |||
| :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | |||
| @@ -44,7 +61,7 @@ def opr_test( | |||
| if not isinstance(results, (tuple, list)): | |||
| results = (results,) | |||
| for r, e in zip(results, expected): | |||
| if not isinstance(r, tensor): | |||
| if not isinstance(r, (tensor, VarNode)): | |||
| r = tensor(r) | |||
| compare_fn(r, e) | |||
| @@ -72,9 +89,9 @@ def opr_test( | |||
| raise ValueError("the input func should be callable") | |||
| inp, outp = get_param(cases, 0) | |||
| inp_tensor = [tensor(inpi) for inpi in inp] | |||
| inp_tensor = [make_tensor(inpi, network) for inpi in inp] | |||
| if test_trace: | |||
| if test_trace and not network: | |||
| copied_inp = inp_tensor.copy() | |||
| for symbolic in [False, True]: | |||
| traced_func = trace(symbolic=symbolic)(func) | |||
| @@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| ) | |||
| step += 1 | |||
| check_func(ori_params, net.parameters(), step) | |||
| try_state_dict = { | |||
| "net": net.state_dict(), | |||
| "opt": opt.state_dict(), | |||
| } | |||
| def test_sgd(): | |||
| @@ -10,12 +10,17 @@ import collections | |||
| import numpy as np | |||
| import pytest | |||
| from utils import make_tensor | |||
| import megengine | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| import megengine.functional as F | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.ops import builtin | |||
| from megengine.tensor import Tensor | |||
| from megengine.utils.network import Network | |||
| from megengine.utils.network_node import VarNode | |||
| def cvt_to_shape_desc(val, inpvar, config=None): | |||
| @@ -387,108 +392,130 @@ def test_batched_mesh_indexing(): | |||
| # high level | |||
| def get_value(x): | |||
| if isinstance(x, VarNode): | |||
| var = x.var | |||
| o = G.OutputNode(var) | |||
| graph = x.graph | |||
| graph.compile(o.outputs).execute() | |||
| return o.get_value().numpy() | |||
| else: | |||
| return x.numpy() | |||
| @pytest.mark.parametrize("test_varnode", [True, False]) | |||
| def test_advance_indexing_high_level(test_varnode): | |||
| if test_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def test_advance_indexing_high_level(): | |||
| x = np.arange(25).reshape(5, 5).astype("int32") | |||
| d = np.arange(15).reshape(3, 5).astype("int32") | |||
| xx = Tensor(x) | |||
| xx = make_tensor(x, network) | |||
| np.testing.assert_equal(x[1, :], xx[1, :].numpy()) | |||
| np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||
| np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy()) | |||
| np.testing.assert_equal(x[1, :], get_value(xx[1, :])) | |||
| np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||
| np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :])) | |||
| np.testing.assert_equal(x[:, :], xx[:, :].numpy()) | |||
| np.testing.assert_equal(x[1, 1], xx[1, 1].numpy()) | |||
| np.testing.assert_equal(x[:, :], get_value(xx[:, :])) | |||
| np.testing.assert_equal(x[1, 1], get_value(xx[1, 1])) | |||
| yy = xx[(0, 4, 2), :] | |||
| np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy()) | |||
| np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy)) | |||
| x_ = x.copy() | |||
| x_[(0, 4, 2), :] = d | |||
| xx_ = Tensor(xx) | |||
| xx_ = make_tensor(xx, network) | |||
| xx_[(0, 4, 2), :] = d | |||
| np.testing.assert_equal(x_, xx_.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx_)) | |||
| x = np.arange(27).reshape(3, 3, 3).astype("int32") | |||
| xx = Tensor(x) | |||
| xx = make_tensor(x, network) | |||
| np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy()) | |||
| np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy()) | |||
| np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy()) | |||
| np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy()) | |||
| np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy()) | |||
| np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||
| np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy()) | |||
| np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :])) | |||
| np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1])) | |||
| np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :])) | |||
| np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1])) | |||
| np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1])) | |||
| np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||
| np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2])) | |||
| x_ = x.copy() | |||
| x_[1, 1, 1] = -1 | |||
| xx[1, 1, 1] = -1 | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x_[:, 1, 1] = -2 | |||
| xx[:, 1, 1] = x_[:, 1, 1] | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x_[0:1, :, 1] = -3 | |||
| xx[0:1, :, 1] = x_[0:1, :, 1] | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x_[0:1, :, 1] = -4 | |||
| y = Tensor(x_) | |||
| y = make_tensor(x_, network) | |||
| xx[0:1, :, 1] = y[0:1, :, 1] | |||
| np.testing.assert_equal(y.numpy(), xx.numpy()) | |||
| np.testing.assert_equal(get_value(y), get_value(xx)) | |||
| x[:] = 1 | |||
| xx[:] = 1 | |||
| np.testing.assert_equal(x, xx.numpy()) | |||
| np.testing.assert_equal(x, get_value(xx)) | |||
| x = np.arange(9).reshape(3, 3).astype("int32") | |||
| xx = Tensor(x) | |||
| xx = make_tensor(x, network) | |||
| y = np.array([1, 2]) | |||
| yy = Tensor(y) | |||
| np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) | |||
| np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) | |||
| np.testing.assert_equal(x[:, y], xx[:, y].numpy()) | |||
| np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) | |||
| yy = make_tensor(y, network) | |||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) | |||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) | |||
| np.testing.assert_equal(x[:, y], get_value(xx[:, y])) | |||
| np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) | |||
| x_ = x.copy() | |||
| x_[:, y[0]] = -1 | |||
| xx_ = Tensor(x_) | |||
| xx_ = make_tensor(x_, network) | |||
| xx[:, yy[0]] = xx_[:, yy[0]] | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x_[:, y] = -1 | |||
| xx_ = Tensor(x_) | |||
| xx_ = make_tensor(x_, network) | |||
| xx[:, yy] = xx_[:, yy] | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x = np.arange(9).reshape(3, 3).astype("int32") | |||
| xx = Tensor(x) | |||
| xx = make_tensor(x, network) | |||
| y = np.array([1]) | |||
| yy = Tensor(y) | |||
| np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) | |||
| np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) | |||
| np.testing.assert_equal(x[:, y], xx[:, y].numpy()) | |||
| yy = make_tensor(y, network) | |||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) | |||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) | |||
| np.testing.assert_equal(x[:, y], get_value(xx[:, y])) | |||
| np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) | |||
| np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) | |||
| x = np.arange(9).reshape(3, 3).astype("int32") | |||
| xx = Tensor(x) | |||
| np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy()) | |||
| np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy()) | |||
| def test_advance_indexing_with_bool(): | |||
| xx = make_tensor(x, network) | |||
| np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0])) | |||
| np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0])) | |||
| @pytest.mark.parametrize( | |||
| "test_varnode", [True, False], | |||
| ) | |||
| def test_advance_indexing_with_bool(test_varnode): | |||
| if test_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| a = np.arange(9).reshape(3, 3).astype(np.float32) | |||
| b = np.array([1, 2, 3]) | |||
| c = np.array([1, 2, 3]) | |||
| aa = Tensor(a) | |||
| bb = Tensor(b) | |||
| cc = Tensor(c) | |||
| np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy()) | |||
| aa = make_tensor(a, network) | |||
| bb = make_tensor(b, network) | |||
| cc = make_tensor(c, network) | |||
| np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2])) | |||
| a[b == 1, c == 2] = -1.0 | |||
| aa[bb == 1, cc == 2] = -1.0 | |||
| np.testing.assert_equal(a, aa.numpy()) | |||
| np.testing.assert_equal(a, get_value(aa)) | |||
| a = np.arange(9).reshape(3, 3).astype(np.float32) | |||
| b = np.array([False, True, True]) | |||
| @@ -11,13 +11,16 @@ import platform | |||
| import numpy as np | |||
| import pytest | |||
| from utils import opr_test | |||
| from utils import make_tensor, opr_test | |||
| import megengine.functional as F | |||
| from megengine import tensor | |||
| from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.tensor import megbrain_graph as G | |||
| from megengine.core.tensor.utils import astensor1d | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| from megengine.utils.network import Network | |||
| from megengine.utils.network_node import VarNode | |||
| def test_eye(): | |||
| @@ -38,7 +41,13 @@ def test_eye(): | |||
| ) | |||
| def test_concat(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_concat(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def get_data_shape(length: int): | |||
| return (length, 2, 3) | |||
| @@ -50,18 +59,30 @@ def test_concat(): | |||
| return F.concat([data1, data2]) | |||
| cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] | |||
| opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y])) | |||
| opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) | |||
| def test_concat_device(): | |||
| data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0") | |||
| data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1") | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_concat_device(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0") | |||
| data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1") | |||
| out = F.concat([data1, data2], device="cpu0") | |||
| assert str(out.device).split(":")[0] == "cpu0" | |||
| def test_stack(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_stack(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| data1 = np.random.random((3, 2, 2)).astype("float32") | |||
| data2 = np.random.random((3, 2, 2)).astype("float32") | |||
| data3 = np.random.random((3, 2, 2)).astype("float32") | |||
| @@ -72,12 +93,20 @@ def test_stack(): | |||
| def run(data1, data2): | |||
| return F.stack([data1, data2], axis=ai) | |||
| opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai)) | |||
| opr_test( | |||
| cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network | |||
| ) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_split(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def test_split(): | |||
| data = np.random.random((2, 3, 4, 5)).astype(np.float32) | |||
| inp = tensor(data) | |||
| inp = make_tensor(data, network) | |||
| mge_out0 = F.split(inp, 2, axis=3) | |||
| mge_out1 = F.split(inp, [3], axis=3) | |||
| @@ -106,26 +135,42 @@ def test_split(): | |||
| assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" | |||
| def test_reshape(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_reshape(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x = np.arange(6, dtype="float32") | |||
| xx = tensor(x) | |||
| xx = make_tensor(x, network) | |||
| y = x.reshape(1, 2, 3) | |||
| for shape in [ | |||
| (1, 2, 3), | |||
| (1, -1, 3), | |||
| (1, tensor(-1), 3), | |||
| (1, make_tensor(-1, network), 3), | |||
| np.array([1, -1, 3], dtype="int32"), | |||
| tensor([1, -1, 3]), | |||
| make_tensor([1, -1, 3], network), | |||
| ]: | |||
| yy = F.reshape(xx, shape) | |||
| np.testing.assert_equal(yy.numpy(), y) | |||
| def test_reshape_shape_inference(): | |||
| x_shape_known = tensor([1, 2, 3, 4], dtype="float32") | |||
| x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum()) | |||
| tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_reshape_shape_inference(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x_shape_known = make_tensor([1, 2, 3, 4], network) | |||
| x_shape_unknown = F.broadcast_to( | |||
| make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum() | |||
| ) | |||
| tshp_unknown = astensor1d( | |||
| (make_tensor([2], network), make_tensor([2], network)), x_shape_known | |||
| ) | |||
| tshp_known = astensor1d((2, 2), x_shape_known) | |||
| tshp_known_unspec = astensor1d((2, -1), x_shape_known) | |||
| @@ -146,12 +191,18 @@ def test_reshape_shape_inference(): | |||
| {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, | |||
| {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, | |||
| ] | |||
| opr_test(cases, func, compare_fn=check_shape, test_trace=True) | |||
| opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_squeeze(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def test_squeeze(): | |||
| x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) | |||
| xx = tensor(x) | |||
| xx = make_tensor(x, network) | |||
| for axis in [None, 3, -4, (3, -4)]: | |||
| y = np.squeeze(x, axis) | |||
| @@ -159,9 +210,15 @@ def test_squeeze(): | |||
| np.testing.assert_equal(y, yy.numpy()) | |||
| def test_expand_dims(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_expand_dims(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x = np.arange(6, dtype="float32").reshape(2, 3) | |||
| xx = tensor(x) | |||
| xx = make_tensor(x, network) | |||
| for axis in [2, -3, (3, -4), (1, -4)]: | |||
| y = np.expand_dims(x, axis) | |||
| @@ -169,11 +226,17 @@ def test_expand_dims(): | |||
| np.testing.assert_equal(y, yy.numpy()) | |||
| def test_elemwise_dtype_promotion(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_elemwise_dtype_promotion(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x = np.random.rand(2, 3).astype("float32") | |||
| y = np.random.rand(1, 3).astype("float16") | |||
| xx = tensor(x) | |||
| yy = tensor(y) | |||
| xx = make_tensor(x, network) | |||
| yy = make_tensor(y, network) | |||
| z = xx * yy | |||
| np.testing.assert_equal(z.numpy(), x * y) | |||
| @@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion(): | |||
| np.testing.assert_equal(z.numpy(), x - y) | |||
| def test_linspace(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_linspace(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| cases = [ | |||
| {"input": [1, 9, 9]}, | |||
| {"input": [3, 10, 8]}, | |||
| @@ -193,6 +262,7 @@ def test_linspace(): | |||
| cases, | |||
| F.linspace, | |||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| cases = [ | |||
| @@ -203,20 +273,28 @@ def test_linspace(): | |||
| cases, | |||
| F.linspace, | |||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| cases = [ | |||
| {"input": [1, tensor(9), 9]}, | |||
| {"input": [tensor(1), 9, tensor(9)]}, | |||
| {"input": [1, make_tensor(9, network), 9]}, | |||
| {"input": [make_tensor(1, network), 9, make_tensor(9, network)]}, | |||
| ] | |||
| opr_test( | |||
| cases, | |||
| F.linspace, | |||
| ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| def test_arange(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_arange(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| cases = [ | |||
| {"input": [1, 9, 1]}, | |||
| {"input": [2, 10, 2]}, | |||
| @@ -225,6 +303,7 @@ def test_arange(): | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| cases = [ | |||
| @@ -235,6 +314,7 @@ def test_arange(): | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| cases = [ | |||
| @@ -245,20 +325,33 @@ def test_arange(): | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| def test_round(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_round(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| data1_shape = (15,) | |||
| data2_shape = (25,) | |||
| data1 = np.random.random(data1_shape).astype(np.float32) | |||
| data2 = np.random.random(data2_shape).astype(np.float32) | |||
| cases = [{"input": data1}, {"input": data2}] | |||
| opr_test(cases, F.round, ref_fn=np.round) | |||
| opr_test(cases, F.round, ref_fn=np.round, network=network) | |||
| def test_flatten(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_flatten(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| data0_shape = (2, 3, 4, 5) | |||
| data1_shape = (4, 5, 6, 7) | |||
| data0 = np.random.random(data0_shape).astype(np.float32) | |||
| @@ -273,7 +366,7 @@ def test_flatten(): | |||
| {"input": data0, "output": output0}, | |||
| {"input": data1, "output": output1}, | |||
| ] | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn) | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, network=network) | |||
| output0 = (2, 3 * 4 * 5) | |||
| output1 = (4, 5 * 6 * 7) | |||
| @@ -281,7 +374,7 @@ def test_flatten(): | |||
| {"input": data0, "output": output0}, | |||
| {"input": data1, "output": output1}, | |||
| ] | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1) | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network) | |||
| output0 = (2, 3, 4 * 5) | |||
| output1 = (4, 5, 6 * 7) | |||
| @@ -289,7 +382,7 @@ def test_flatten(): | |||
| {"input": data0, "output": output0}, | |||
| {"input": data1, "output": output1}, | |||
| ] | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2) | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network) | |||
| output0 = (2, 3 * 4, 5) | |||
| output1 = (4, 5 * 6, 7) | |||
| @@ -297,10 +390,23 @@ def test_flatten(): | |||
| {"input": data0, "output": output0}, | |||
| {"input": data1, "output": output1}, | |||
| ] | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) | |||
| opr_test( | |||
| cases, | |||
| F.flatten, | |||
| compare_fn=compare_fn, | |||
| start_axis=1, | |||
| end_axis=2, | |||
| network=network, | |||
| ) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_broadcast(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def test_broadcast(): | |||
| input1_shape = (20, 30) | |||
| output1_shape = (30, 20, 30) | |||
| data1 = np.random.random(input1_shape).astype(np.float32) | |||
| @@ -309,14 +415,19 @@ def test_broadcast(): | |||
| output2_shape = (20, 10, 20) | |||
| data2 = np.random.random(input2_shape).astype(np.float32) | |||
| input3_shape = (10, 10) | |||
| output3_shape = (10, 10) | |||
| data3 = np.random.random(input3_shape).astype(np.float32) | |||
| def compare_fn(x, y): | |||
| assert x.shape[0] == y | |||
| cases = [ | |||
| {"input": [data1, output1_shape], "output": output1_shape}, | |||
| {"input": [data2, output2_shape], "output": output2_shape}, | |||
| {"input": [data3, output3_shape], "output": output3_shape}, | |||
| ] | |||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network) | |||
| x = F.ones((2, 1, 3)) | |||
| with pytest.raises(RuntimeError): | |||
| @@ -329,35 +440,41 @@ def test_broadcast(): | |||
| F.broadcast_to(x, (1, 3)) | |||
| def test_utils_astensor1d(): | |||
| reference = tensor(0) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_utils_astensor1d(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| reference = make_tensor(0, network) | |||
| # literal | |||
| x = [1, 2, 3] | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert type(xx) is tensor | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), x) | |||
| # numpy array | |||
| x = np.asarray([1, 2, 3], dtype="int32") | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert type(xx) is tensor | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x) | |||
| # tensor | |||
| x = tensor([1, 2, 3], dtype="int32") | |||
| x = make_tensor([1, 2, 3], network) | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert type(xx) is tensor | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), x.numpy()) | |||
| # mixed | |||
| x = [1, tensor(2), 3] | |||
| x = [1, make_tensor(2, network), 3] | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert type(xx) is tensor | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), [1, 2, 3]) | |||
| @@ -377,35 +494,60 @@ def test_device(): | |||
| np.testing.assert_almost_equal(y5.numpy(), y6.numpy()) | |||
| def test_identity(): | |||
| x = tensor(np.random.random((5, 10)).astype(np.float32)) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_identity(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x = make_tensor(np.random.random((5, 10)).astype(np.float32), network) | |||
| y = F.copy(x) | |||
| np.testing.assert_equal(y.numpy(), x) | |||
| def copy_test(dst, src): | |||
| def copy_test(dst, src, network): | |||
| data = np.random.random((2, 3)).astype(np.float32) | |||
| x = tensor(data, device=src) | |||
| x = make_tensor(data, device=src, network=network) | |||
| y = F.copy(x, dst) | |||
| assert np.allclose(data, y.numpy()) | |||
| z = x.to(dst) | |||
| assert np.allclose(data, z.numpy()) | |||
| if network is None: | |||
| z = x.to(dst) | |||
| assert np.allclose(data, z.numpy()) | |||
| @pytest.mark.require_ngpu(1) | |||
| def test_copy_h2d(): | |||
| copy_test("cpu0", "gpu0") | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_copy_h2d(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| copy_test("cpu0", "gpu0", network=network) | |||
| @pytest.mark.require_ngpu(1) | |||
| def test_copy_d2h(): | |||
| copy_test("gpu0", "cpu0") | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_copy_d2h(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| copy_test("gpu0", "cpu0", network=network) | |||
| @pytest.mark.require_ngpu(2) | |||
| def test_copy_d2d(): | |||
| copy_test("gpu0", "gpu1") | |||
| copy_test("gpu0:0", "gpu0:1") | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_copy_d2d(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| copy_test("gpu0", "gpu1", network=network) | |||
| copy_test("gpu0:0", "gpu0:1", network=network) | |||
| @pytest.mark.parametrize( | |||
| @@ -420,7 +562,13 @@ def test_copy_d2d(): | |||
| ((), 10, None), | |||
| ], | |||
| ) | |||
| def test_repeat(shape, repeats, axis): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_repeat(shape, repeats, axis, is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def repeat_func(inp): | |||
| return F.repeat(inp=inp, repeats=repeats, axis=axis) | |||
| @@ -432,7 +580,10 @@ def test_repeat(shape, repeats, axis): | |||
| cases = [{"input": np.array(1.23)}] | |||
| opr_test( | |||
| cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis), | |||
| cases, | |||
| repeat_func, | |||
| ref_fn=lambda inp: np.repeat(inp, repeats, axis), | |||
| network=network, | |||
| ) | |||
| @@ -445,14 +596,16 @@ def test_repeat(shape, repeats, axis): | |||
| ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), | |||
| ], | |||
| ) | |||
| def test_tile(shape, reps): | |||
| @pytest.mark.parametrize("is_varnode", [True]) | |||
| def test_tile(shape, reps, is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def tile_func(inp): | |||
| return F.tile(inp=inp, reps=reps) | |||
| cases = [ | |||
| {"input": np.random.randn(*shape).astype("float32")}, | |||
| ] | |||
| cases = [{"input": np.random.randn(*shape).astype("float32")}] | |||
| opr_test( | |||
| cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), | |||
| ) | |||
| opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network) | |||
| @@ -30,7 +30,10 @@ min_max_fakequant_qconfig = QConfig( | |||
| act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||
| ) | |||
| inp_scale = np.float32(np.random.rand() + 1) | |||
| def gen_inp_scale(): | |||
| return np.float32(np.random.rand() + 1) | |||
| min_val = np.random.randint(-127, 0, size=(2,)).astype("float32") | |||
| max_val = np.random.randint(1, 127, size=(2,)).astype("float32") | |||
| @@ -116,6 +119,7 @@ def test_dequant_stub(): | |||
| q_net.eval() | |||
| x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | |||
| inp_scale = gen_inp_scale() | |||
| x = fake_quant_act(x, inp_scale) | |||
| x.qparams.scale = inp_scale | |||
| @@ -192,6 +196,7 @@ def test_linear(): | |||
| init_qat_net(qat_net) | |||
| x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | |||
| inp_scale = gen_inp_scale() | |||
| x = fake_quant_act(x, inp_scale) | |||
| x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) | |||
| @@ -235,6 +240,7 @@ def test_conv(module): | |||
| init_qat_net(qat_net) | |||
| x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32")) | |||
| inp_scale = gen_inp_scale() | |||
| x = fake_quant_act(x, inp_scale) | |||
| x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) | |||
| @@ -269,3 +275,41 @@ def test_conv(module): | |||
| np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5) | |||
| np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale) | |||
| np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale) | |||
| def test_concat(): | |||
| normal_net = Float.Concat() | |||
| normal_net.eval() | |||
| qat_net = QAT.Concat() | |||
| qat_net.eval() | |||
| disable_observer(qat_net) | |||
| propagate_qconfig(qat_net, min_max_fakequant_qconfig) | |||
| init_qat_net(qat_net) | |||
| inps = [] | |||
| inps_int8 = [] | |||
| for i in range(3): | |||
| inp_scale = gen_inp_scale() | |||
| inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))) | |||
| inps[i] = fake_quant_act(inps[i], inp_scale) | |||
| inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) | |||
| inps_int8.append(quant(inps[i], inp_scale)) | |||
| qat_from_float = QAT.Concat.from_float_module(normal_net) | |||
| qat_from_float.eval() | |||
| disable_fake_quant(qat_from_float) | |||
| disable_observer(qat_from_float) | |||
| q_net = Q.Concat.from_qat_module(qat_net) | |||
| q_net.eval() | |||
| normal = normal_net(inps) | |||
| qat_without_fakequant = qat_from_float(inps) | |||
| fake_quant_normal = fake_quant_act(normal_net(inps), act_scale) | |||
| qat = qat_net(inps) | |||
| q = q_net(inps_int8).numpy() * act_scale | |||
| np.testing.assert_allclose(qat_without_fakequant, normal) | |||
| np.testing.assert_allclose(qat, fake_quant_normal.numpy()) | |||
| np.testing.assert_allclose(q, fake_quant_normal.numpy()) | |||
| @@ -123,6 +123,35 @@ def test_with_submodule(symbolic): | |||
| assert ops[-1].outputs[0].name == "simple.linear.ADD" | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_with_submodule_in_container(symbolic): | |||
| class Simple(M.Module): | |||
| def __init__(self, name): | |||
| super().__init__() | |||
| self.name = name | |||
| self.l0 = [M.Linear(3, 3) for _ in range(2)] | |||
| self.l1 = tuple(self.l0) | |||
| self.l2 = dict(zip(["l2-0", "l2-1"], self.l0)) | |||
| def forward(self, x): | |||
| for i in range(2): | |||
| x = self.l0[i](x) | |||
| x = self.l1[i](x) | |||
| x = self.l2["l2-%d" % i](x) | |||
| return x | |||
| m = Simple("simple") | |||
| ops = _dump_and_load(m, symbolic) | |||
| assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD" | |||
| assert ops[-1].name == "simple.l2.l2-1.ADD" | |||
| assert ops[-2].name == "simple.l2.l2-1.MatrixMul" | |||
| assert ops[-3].name == "simple.l1.1.ADD" | |||
| assert ops[-4].name == "simple.l1.1.MatrixMul" | |||
| assert ops[-5].name == "simple.l0.1.ADD" | |||
| assert ops[-6].name == "simple.l0.1.MatrixMul" | |||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||
| def test_named_submodule(symbolic): | |||
| class Simple(M.Module): | |||
| @@ -264,4 +293,4 @@ def test_quantized_module_user_naming_param(symbolic): | |||
| (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"] | |||
| for var in matrix_mul_op.inputs: | |||
| assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight") | |||
| # BUG: the bias's name does not meet expectations because of the astype operator added after quantization | |||
| # WONTFIX: the bias's name does not meet expectations because of the astype operator added after quantization | |||
| @@ -34,13 +34,11 @@ def test_replace_var(): | |||
| vara = graph.var_filter.name("a").as_unique() | |||
| varb = graph.var_filter.name("b").as_unique() | |||
| out = F.mul(vara.var, varb.var) | |||
| out = F.mul(vara, varb) | |||
| out = F.relu(out) | |||
| var_list = graph.add_dep_oprs(out) | |||
| opnode = list(graph.opr_filter.has_input(vara)) | |||
| repl_dict = {opnode[0].outputs[0]: var_list[0]} | |||
| repl_dict = {opnode[0].outputs[0]: out} | |||
| graph.replace_vars(repl_dict) | |||
| modified_model = io.BytesIO() | |||
| @@ -72,14 +70,12 @@ def test_replace_opr(): | |||
| vara = graph.var_filter.name("a").as_unique() | |||
| varb = graph.var_filter.name("b").as_unique() | |||
| out1 = F.sub(vara.var, varb.var) | |||
| out1 = F.sub(vara, varb) | |||
| out1 = F.relu(out1) | |||
| var_list = graph.add_dep_oprs(out1) | |||
| repl_opr = as_oprnode(var_list) | |||
| out1 = graph.add_dep_oprs(out1) | |||
| orig_opr = graph.opr_filter.has_input(vara).as_unique() | |||
| repl_dict = {orig_opr: repl_opr} | |||
| repl_dict = {orig_opr: out1[0].owner} | |||
| graph.replace_oprs(repl_dict) | |||
| modified_model1 = io.BytesIO() | |||
| graph.dump(modified_model1) | |||
| @@ -171,8 +167,7 @@ def test_add_input(): | |||
| inp_c = graph.make_input_node((2,), np.int32, name="c") | |||
| varo = graph.var_filter.name("o").as_unique() | |||
| out = F.add(varo.var, inp_c.var) | |||
| out = graph.add_dep_oprs(out)[0] | |||
| out = F.add(varo, inp_c) | |||
| out.name = "o1" | |||
| graph.remove_output(varo) | |||
| graph.add_output(out) | |||
| @@ -206,12 +201,11 @@ def test_add_output(): | |||
| var_a = net.var_filter.name("a").as_unique() | |||
| var_b = net.var_filter.name("b").as_unique() | |||
| y = F.add(var_a.var, var_b.var) | |||
| y = F.add(var_a, var_b) | |||
| y = F.sigmoid(y) | |||
| new_vars = net.add_dep_oprs(y)[0] | |||
| new_vars.name = "o1" | |||
| net.add_output(new_vars) | |||
| y.name = "o1" | |||
| net.add_output(y) | |||
| modified_model = io.BytesIO() | |||
| net.dump(modified_model) | |||
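| The four graph-editing hunks above all reflect the same API simplification: functional ops such as F.mul, F.sub, and F.add now accept the network's variable wrappers directly, and the returned variable can be renamed and registered as an output without the old add_dep_oprs/as_oprnode round trip. A condensed sketch of the new pattern; it assumes `net` is the network object loaded from a dumped model with variables named "a" and "b", exactly as the surrounding tests set up outside these hunks: | |||
| import io | |||
| import megengine.functional as F | |||
| def add_sigmoid_output(net): | |||
|     # `net` is assumed to be the loaded network object used by these tests | |||
|     var_a = net.var_filter.name("a").as_unique() | |||
|     var_b = net.var_filter.name("b").as_unique() | |||
|     y = F.sigmoid(F.add(var_a, var_b))  # wrappers are accepted directly; no .var needed | |||
|     y.name = "o1" | |||
|     net.add_output(y) | |||
|     buf = io.BytesIO() | |||
|     net.dump(buf) | |||
|     return buf | |||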
| @@ -466,6 +466,19 @@ def test_topk(): | |||
| check_pygraph_dump(fwd, [x], [top, indices]) | |||
| def test_nvof(): | |||
| if not is_cuda_available(): | |||
| return | |||
| src_shape = (4, 5, 224, 224, 4) | |||
| src = np.random.randint(0, 255, src_shape).astype("uint8") | |||
| src = Tensor(src) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(src): | |||
| return F.nn.nvof(src, precision=1) | |||
| result = fwd(src) | |||
| check_pygraph_dump(fwd, [src], [result]) | |||
| def test_random(): | |||
| @@ -6,5 +6,5 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| __version__ = "1.3.0.dev" | |||
| __version__ = "1.3.1" | |||
| @@ -24,14 +24,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| return Broadcast::make(); | |||
| } | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = def.cast_final_safe<Broadcast>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); | |||
| return opr::Broadcast::make(inputs[0], inputs[1], config); | |||
| } | |||
| bool valid_broadcast(const TensorShape& src_shape, | |||
| @@ -1,6 +1,7 @@ | |||
| # mgb tablegen executable | |||
| set(TABLE_TARGET mgb-mlir-autogen) | |||
| add_executable(${TABLE_TARGET} autogen.cpp) | |||
| file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) | |||
| add_executable(${TABLE_TARGET} ${SRCS}) | |||
| target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR}) | |||
| target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport) | |||
| set(MGB_TABLEGEN_EXE ${TABLE_TARGET}) | |||
| @@ -1,8 +1,17 @@ | |||
| #include <iostream> | |||
| #include <unordered_map> | |||
| #include <functional> | |||
| #include "./helper.h" | |||
| /** | |||
| * \file imperative/tablegen/autogen.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./targets/cpp_class.h" | |||
| #include "./targets/pybind11.h" | |||
| #include "./targets/python_c_extension.h" | |||
| using llvm::raw_ostream; | |||
| using llvm::RecordKeeper; | |||
| @@ -27,731 +36,7 @@ llvm::cl::opt<ActionType> action( | |||
| clEnumValN(CPython, "gen-python-c-extension", | |||
| "Generate python c extensions"))); | |||
| using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
| using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
| using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; | |||
| using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; | |||
| using MgbOp = mlir::tblgen::MgbOpBase; | |||
| using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; | |||
| llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||
| // Note: we have already registered the corresponding attr wrappers | |||
| // for the following basic ctypes, so we needn't handle them here | |||
| /* auto&& attr_type_name = attr.getAttrDefName(); | |||
| if (attr_type_name == "UI32Attr") { | |||
| return "uint32_t"; | |||
| } | |||
| if (attr_type_name == "UI64Attr") { | |||
| return "uint64_t"; | |||
| } | |||
| if (attr_type_name == "I32Attr") { | |||
| return "int32_t"; | |||
| } | |||
| if (attr_type_name == "F32Attr") { | |||
| return "float"; | |||
| } | |||
| if (attr_type_name == "F64Attr") { | |||
| return "double"; | |||
| } | |||
| if (attr_type_name == "StrAttr") { | |||
| return "std::string"; | |||
| } | |||
| if (attr_type_name == "BoolAttr") { | |||
| return "bool"; | |||
| }*/ | |||
| auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||
| if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||
| return e->getEnumName(); | |||
| } | |||
| return attr.getUnderlyingType(); | |||
| } | |||
| static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
| os << formatv( | |||
| "class {0} : public OpDefImplBase<{0}> {{\n" | |||
| " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||
| "public:\n", | |||
| op.getCppClassName() | |||
| ); | |||
| // handle enum alias | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| os << formatv( | |||
| " using {0} = {1};\n", | |||
| attr->getEnumName(), attr->getUnderlyingType() | |||
| ); | |||
| } | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| auto defaultValue = i.attr.getDefaultValue().str(); | |||
| if (!defaultValue.empty()) { | |||
| defaultValue = formatv(" = {0}", defaultValue); | |||
| } | |||
| os << formatv( | |||
| " {0} {1}{2};\n", | |||
| attr_to_ctype(i.attr), i.name, defaultValue | |||
| ); | |||
| } | |||
| auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||
| os << formatv( | |||
| " {0}({1}){2}{3}\n", | |||
| op.getCppClassName(), paramList, memInitList, body | |||
| ); | |||
| }; | |||
| gen_ctor("", "", " = default;"); | |||
| if (!op.getMgbAttributes().empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| paramList.push_back("std::string scope_ = {}"); | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| ": " + llvm::join(initList, ", "), | |||
| " { set_scope(scope_); }"); | |||
| } | |||
| auto packedParams = op.getPackedParams(); | |||
| if (!packedParams.empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&p : packedParams) { | |||
| auto&& paramFields = p.getFields(); | |||
| auto&& paramType = p.getFullName(); | |||
| auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||
| paramList.push_back( | |||
| paramFields.empty() ? paramType.str() | |||
| : formatv("{0} {1}", paramType, paramName) | |||
| ); | |||
| for (auto&& i : paramFields) { | |||
| initList.push_back(formatv( | |||
| "{0}({1}.{0})", i.name, paramName | |||
| )); | |||
| } | |||
| } | |||
| for (auto&& i : op.getExtraArguments()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
| " {}"); | |||
| } | |||
| if (!packedParams.empty()) { | |||
| for (auto&& p : packedParams) { | |||
| auto accessor = p.getAccessor(); | |||
| if (!accessor.empty()) { | |||
| os << formatv( | |||
| " {0} {1}() const {{\n", | |||
| p.getFullName(), accessor | |||
| ); | |||
| std::vector<llvm::StringRef> fields; | |||
| for (auto&& i : p.getFields()) { | |||
| fields.push_back(i.name); | |||
| } | |||
| os << formatv( | |||
| " return {{{0}};\n", | |||
| llvm::join(fields, ", ") | |||
| ); | |||
| os << " }\n"; | |||
| } | |||
| } | |||
| } | |||
| if (auto decl = op.getExtraOpdefDecl()) { | |||
| os << decl.getValue(); | |||
| } | |||
| os << formatv( | |||
| "};\n\n" | |||
| ); | |||
| } | |||
| static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| if (attr->supportToString()) { | |||
| std::vector<std::string> case_body; | |||
| std::string ename = formatv("{0}::{1}", | |||
| op.getCppClassName(), attr->getEnumName()); | |||
| llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||
| case_body.push_back(formatv( | |||
| "case {0}::{1}: return \"{1}\";", ename, v)); | |||
| }); | |||
| os << formatv(R"( | |||
| template <> | |||
| struct ToStringTrait<{0}> { | |||
| std::string operator()({0} e) const { | |||
| switch (e) { | |||
| {1} | |||
| default: | |||
| return "{0}::Unknown"; | |||
| } | |||
| } | |||
| }; | |||
| )", ename, llvm::join(case_body, "\n")); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| auto&& className = op.getCppClassName(); | |||
| os << formatv( | |||
| "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className | |||
| ); | |||
| auto formatMethImpl = [&](auto&& meth) { | |||
| return formatv( | |||
| "{0}_{1}_impl", className, meth | |||
| ); | |||
| }; | |||
| std::vector<std::string> methods; | |||
| if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) { | |||
| os << "namespace {\n"; | |||
| // generate hash() | |||
| mlir::tblgen::FmtContext ctx; | |||
| os << formatv( | |||
| "size_t {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("hash") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate is_same_st() | |||
| os << formatv( | |||
| "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", | |||
| formatMethImpl("is_same_st") | |||
| ); | |||
| os << formatv( | |||
| " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
| " &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(a_);\n" | |||
| " static_cast<void>(b_);\n", | |||
| className | |||
| ); | |||
| os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
| os << "}\n"; | |||
| // generate props() | |||
| os << formatv( | |||
| "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("props") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate make_name() | |||
| os << formatv( | |||
| "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| os << "} // anonymous namespace\n"; | |||
| methods.push_back("hash"); | |||
| methods.push_back("is_same_st"); | |||
| methods.push_back("props"); | |||
| methods.push_back("make_name"); | |||
| } | |||
| if (!methods.empty()) { | |||
| os << formatv( | |||
| "OP_TRAIT_REG({0}, {0})", op.getCppClassName() | |||
| ); | |||
| for (auto&& i : methods) { | |||
| os << formatv( | |||
| "\n .{0}({1})", i, formatMethImpl(i) | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| } | |||
| struct EnumContext { | |||
| std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||
| }; | |||
| static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
| auto className = op.getCppClassName(); | |||
| os << formatv( | |||
| "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
| className | |||
| ); | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = | |||
| llvm::cast<MgbEnumAttr>(aliasBase) | |||
| .getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
| className, attr->getEnumName() | |||
| ); | |||
| std::vector<std::string> body; | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| os << formatv( | |||
| "\n .value(\"{2}\", {0}::{1}::{2})", | |||
| className, attr->getEnumName(), i | |||
| ); | |||
| body.push_back(formatv( | |||
| "if (str == \"{2}\") return {0}::{1}::{2};", | |||
| className, attr->getEnumName(), i | |||
| )); | |||
| } | |||
| if (attr->getEnumCombinedFlag()) { | |||
| //! define operator | | |||
| os << formatv( | |||
| "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " | |||
| "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" | |||
| "\n })", | |||
| className, attr->getEnumName()); | |||
| //! define operator & | |||
| os << formatv( | |||
| "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" | |||
| "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" | |||
| "\n })", | |||
| className, attr->getEnumName()); | |||
| } | |||
| os << formatv( | |||
| "\n .def(py::init([](const std::string& in) {" | |||
| "\n auto&& str = normalize_enum(in);" | |||
| "\n {0}" | |||
| "\n throw py::cast_error(\"invalid enum value \" + in);" | |||
| "\n }));\n", | |||
| llvm::join(body, "\n ") | |||
| ); | |||
| os << formatv( | |||
| "py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
| className, attr->getEnumName() | |||
| ); | |||
| enumAlias.emplace(enumID, | |||
| std::make_pair(className, attr->getEnumName())); | |||
| } else { | |||
| os << formatv( | |||
| "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||
| className, attr->getEnumName(), | |||
| iter->second.first, iter->second.second | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| // generate op class binding | |||
| os << formatv("{0}Inst", className); | |||
| bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
| if (!hasDefaultCtor) { | |||
| os << "\n .def(py::init<"; | |||
| std::vector<llvm::StringRef> targs; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| targs.push_back(i.attr.getReturnType()); | |||
| } | |||
| os << llvm::join(targs, ", "); | |||
| os << ", std::string>()"; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv(", py::arg(\"{0}\")", i.name); | |||
| auto defaultValue = i.attr.getDefaultValue(); | |||
| if (!defaultValue.empty()) { | |||
| os << formatv(" = {0}", defaultValue); | |||
| } else { | |||
| hasDefaultCtor = true; | |||
| } | |||
| } | |||
| os << ", py::arg(\"scope\") = {})"; | |||
| } | |||
| if (hasDefaultCtor) { | |||
| os << "\n .def(py::init<>())"; | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv( | |||
| "\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
| i.name, className | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| static std::string gen_op_def_python_c_extension_enum( | |||
| raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
| llvm::StringRef className) { | |||
| std::string body; | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| auto enumName = attr->getEnumName(); | |||
| body += "{\n"; | |||
| body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, | |||
| enumName); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> const char* EnumWrapper<{0}::{1}>::name = " | |||
| "\"{0}.{1}\";\n", | |||
| className, enumName); | |||
| std::vector<std::string> pairStr; | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<std::string, {0}::{1}> | |||
| EnumWrapper<{0}::{1}>::str2type = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| pairStr.clear(); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<{0}::{1}, std::string> | |||
| EnumWrapper<{0}::{1}>::type2str = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| body += formatv(R"( | |||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
| e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); | |||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| e_type.tp_doc = "{0}.{1}"; | |||
| e_type.tp_base = &PyBaseObject_Type; | |||
| e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; | |||
| e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; | |||
| mgb_assert(PyType_Ready(&e_type) >= 0); | |||
| )", | |||
| className, enumName); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| body += formatv(R"({{ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
| })", | |||
| className, enumName, i); | |||
| } | |||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
| } | |||
| body += formatv(R"( | |||
| PyType_Modified(&e_type); | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", | |||
| enumName); | |||
| body += "}\n"; | |||
| return body; | |||
| } | |||
| static std::string gen_op_def_python_c_extension_bit_combined_enum( | |||
| raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, | |||
| llvm::StringRef className) { | |||
| std::string body; | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| auto enumName = attr->getEnumName(); | |||
| body += "{\n"; | |||
| body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", | |||
| className, enumName); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "template<> PyTypeObject " | |||
| "BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> PyNumberMethods " | |||
| "BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " | |||
| "= \"{0}.{1}\";\n", | |||
| className, enumName); | |||
| os << formatv( | |||
| "template<> struct EnumTrait<{0}::{1}> {{ static constexpr " | |||
| "bool is_bit_combined = true;};\n", | |||
| className, enumName); | |||
| std::vector<std::string> pairStr; | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<std::string, {0}::{1}> | |||
| BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| pairStr.clear(); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| pairStr.push_back( | |||
| formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", | |||
| className, enumName, i)); | |||
| } | |||
| os << formatv(R"( | |||
| template<> std::unordered_map<{0}::{1}, std::string> | |||
| BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ | |||
| {2} | |||
| }; | |||
| )", | |||
| className, enumName, llvm::join(pairStr, ", ")); | |||
| body += formatv(R"( | |||
| e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; | |||
| e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); | |||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| e_type.tp_doc = "{0}.{1}"; | |||
| e_type.tp_base = &PyBaseObject_Type; | |||
| e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; | |||
| e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; | |||
| e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; | |||
| e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; | |||
| auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; | |||
| number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; | |||
| number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; | |||
| e_type.tp_as_number = &number_method; | |||
| mgb_assert(PyType_Ready(&e_type) >= 0); | |||
| )", | |||
| className, enumName); | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| body += formatv(R"({{ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); | |||
| })", | |||
| className, enumName, i); | |||
| } | |||
| enumAlias.emplace(enumID, std::make_pair(className, enumName)); | |||
| } | |||
| body += formatv(R"( | |||
| PyType_Modified(&e_type); | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", | |||
| enumName); | |||
| body += "}\n"; | |||
| return body; | |||
| } | |||
| static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { | |||
| auto className = op.getCppClassName(); | |||
| std::string body; | |||
| // generate PyType for enum class member | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| if (attr->getEnumCombinedFlag()) { | |||
| body += gen_op_def_python_c_extension_bit_combined_enum( | |||
| os, ctx, attr, className); | |||
| } else { | |||
| body += gen_op_def_python_c_extension_enum(os, ctx, attr, | |||
| className); | |||
| } | |||
| } | |||
| } | |||
| // generate getsetters | |||
| std::vector<std::string> getsetters; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| getsetters.push_back(formatv( | |||
| "{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},", | |||
| className, i.name)); | |||
| } | |||
| // generate tp_init | |||
| std::string initBody; | |||
| if (!op.getMgbAttributes().empty()) { | |||
| initBody += "static const char* kwlist[] = {"; | |||
| std::vector<llvm::StringRef> attr_name_list; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| attr_name_list.push_back(attr.name); | |||
| }); | |||
| attr_name_list.push_back("scope"); | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| initBody += formatv("\"{0}\", ", attr); | |||
| }); | |||
| initBody += "NULL};\n"; | |||
| initBody += " PyObject "; | |||
| std::vector<std::string> attr_init; | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| attr_init.push_back(formatv("*{0} = NULL", attr)); | |||
| }); | |||
| initBody += llvm::join(attr_init, ", ") + ";\n"; | |||
| initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
| // one extra slot is reserved for the trailing "scope" argument | |||
| initBody += std::string(attr_name_list.size(), 'O'); | |||
| initBody += "\", const_cast<char**>(kwlist)"; | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| initBody += formatv(", &{0}", attr); | |||
| }); | |||
| initBody += "))\n"; | |||
| initBody += " return -1;\n"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += formatv(R"( | |||
| if ({1}) {{ | |||
| try {{ | |||
| reinterpret_cast<PyOp({0})*>(self)->inst().{1} = | |||
| pyobj_convert_generic<decltype({0}::{1})>::from({1}); | |||
| } CATCH_ALL(-1) | |||
| } | |||
| )", className, attr.name); | |||
| }); | |||
| initBody += formatv(R"( | |||
| if (scope) {{ | |||
| try {{ | |||
| reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
| ->set_scope(pyobj_convert_generic<std::string>::from(scope)); | |||
| } CATCH_ALL(-1) | |||
| } | |||
| )", className); | |||
| } | |||
| initBody += "\n return 0;"; | |||
| os << formatv(R"( | |||
| PyOpDefBegin({0}) // {{ | |||
| static PyGetSetDef py_getsetters[]; | |||
| static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
| // }; | |||
| PyOpDefEnd({0}) | |||
| PyGetSetDef PyOp({0})::py_getsetters[] = {{ | |||
| {1} | |||
| {{NULL} /* Sentinel */ | |||
| }; | |||
| int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{ | |||
| {2} | |||
| } | |||
| void _init_py_{0}(py::module m) {{ | |||
| using py_op = PyOp({0}); | |||
| auto& py_type = PyOpType({0}); | |||
| py_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| py_type.tp_name = "megengine.core._imperative_rt.ops.{0}"; | |||
| py_type.tp_basicsize = sizeof(PyOp({0})); | |||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| py_type.tp_doc = "{0}"; | |||
| py_type.tp_base = &PyOpType(OpDef); | |||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
| py_type.tp_new = py_new_generic<py_op>; | |||
| py_type.tp_init = py_op::py_init; | |||
| py_type.tp_getset = py_op::py_getsetters; | |||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||
| {3} | |||
| PyType_Modified(&py_type); | |||
| m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type)); | |||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second); | |||
| } | |||
| )", | |||
| op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body); | |||
| } | |||
| static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
| std::function<void(raw_ostream&, MgbOp&)> callback) { | |||
| auto op_base_class = keeper.getClass("Op"); | |||
| ASSERT(op_base_class, "could not find base class Op"); | |||
| for (auto&& i: keeper.getDefs()) { | |||
| auto&& r = i.second; | |||
| if (r->isSubClassOf(op_base_class)) { | |||
| auto op = mlir::tblgen::Operator(r.get()); | |||
| if (op.getDialectName().str() == "mgb") { | |||
| std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; | |||
| callback(os, llvm::cast<MgbOp>(op)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | |||
| for_each_operator(os, keeper, gen_op_def_c_header_single); | |||
| for_each_operator(os, keeper, gen_to_string_trait_for_enum); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { | |||
| for_each_operator(os, keeper, gen_op_def_c_body_single); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | |||
| EnumContext ctx; | |||
| using namespace std::placeholders; | |||
| for_each_operator(os, keeper, | |||
| std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) { | |||
| EnumContext ctx; | |||
| using namespace std::placeholders; | |||
| for_each_operator(os, keeper, | |||
| std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx))); | |||
| os << "#define INIT_ALL_OP(m)"; | |||
| for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) { | |||
| os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName()); | |||
| }); | |||
| os << "\n"; | |||
| return false; | |||
| } | |||
| using namespace mlir::tblgen; | |||
| int main(int argc, char **argv) { | |||
| llvm::InitLLVM y(argc, argv); | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * \file imperative/tablegen/emitter.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <unordered_map> | |||
| #include <stdexcept> | |||
| #include "llvm/ADT/StringRef.h" | |||
| #include "llvm/Support/raw_ostream.h" | |||
| namespace mlir::tblgen { | |||
| struct Environment { | |||
| std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias; | |||
| }; | |||
| struct EmitterBase { | |||
| EmitterBase(raw_ostream& os_): os(os_) {} | |||
| EmitterBase(raw_ostream& os_, Environment& env): os(os_), env_p(&env) {} | |||
| protected: | |||
| void newline() { os << "\n"; } | |||
| Environment& env() { | |||
| if (env_p) { | |||
| return *env_p; | |||
| } | |||
| throw std::runtime_error("access global environment via non-environment emitter"); | |||
| } | |||
| raw_ostream& os; | |||
| Environment* env_p = nullptr; | |||
| }; | |||
| } // namespace mlir::tblgen | |||
| @@ -1,3 +1,16 @@ | |||
| /** | |||
| * \file imperative/tablegen/helper.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -278,5 +291,28 @@ public: | |||
| } | |||
| }; | |||
| using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
| using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
| using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; | |||
| using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; | |||
| using MgbOp = mlir::tblgen::MgbOpBase; | |||
| using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; | |||
| static inline void foreach_operator(llvm::RecordKeeper &keeper, | |||
| std::function<void(MgbOp&)> callback) { | |||
| auto op_base_class = keeper.getClass("Op"); | |||
| ASSERT(op_base_class, "could not find base class Op"); | |||
| for (auto&& i: keeper.getDefs()) { | |||
| auto&& r = i.second; | |||
| if (r->isSubClassOf(op_base_class)) { | |||
| auto op = mlir::tblgen::Operator(r.get()); | |||
| if (op.getDialectName().str() == "mgb") { | |||
| std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; | |||
| callback(llvm::cast<MgbOp>(op)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace tblgen | |||
| } // namespace mlir | |||
| @@ -0,0 +1,309 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/cpp_class.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./cpp_class.h" | |||
| #include "../emitter.h" | |||
| namespace mlir::tblgen { | |||
| namespace { | |||
| llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||
| // Note: we have already registered the corresponding attr wrappers | |||
| // for the following basic ctypes, so we needn't handle them here | |||
| /* auto&& attr_type_name = attr.getAttrDefName(); | |||
| if (attr_type_name == "UI32Attr") { | |||
| return "uint32_t"; | |||
| } | |||
| if (attr_type_name == "UI64Attr") { | |||
| return "uint64_t"; | |||
| } | |||
| if (attr_type_name == "I32Attr") { | |||
| return "int32_t"; | |||
| } | |||
| if (attr_type_name == "F32Attr") { | |||
| return "float"; | |||
| } | |||
| if (attr_type_name == "F64Attr") { | |||
| return "double"; | |||
| } | |||
| if (attr_type_name == "StrAttr") { | |||
| return "std::string"; | |||
| } | |||
| if (attr_type_name == "BoolAttr") { | |||
| return "bool"; | |||
| }*/ | |||
| auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||
| if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||
| return e->getEnumName(); | |||
| } | |||
| return attr.getUnderlyingType(); | |||
| } | |||
| class OpDefEmitter final: public EmitterBase { | |||
| public: | |||
| OpDefEmitter(MgbOp& op_, raw_ostream& os_): | |||
| EmitterBase(os_), op(op_) {} | |||
| void emit_header(); | |||
| void emit_tpl_spl(); | |||
| void emit_body(); | |||
| private: | |||
| MgbOp& op; | |||
| }; | |||
| void OpDefEmitter::emit_header() { | |||
| os << formatv( | |||
| "class {0} : public OpDefImplBase<{0}> {{\n" | |||
| " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||
| "public:\n", | |||
| op.getCppClassName() | |||
| ); | |||
| // handle enum alias | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| os << formatv( | |||
| " using {0} = {1};\n", | |||
| attr->getEnumName(), attr->getUnderlyingType() | |||
| ); | |||
| } | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| auto defaultValue = i.attr.getDefaultValue().str(); | |||
| if (!defaultValue.empty()) { | |||
| defaultValue = formatv(" = {0}", defaultValue); | |||
| } | |||
| os << formatv( | |||
| " {0} {1}{2};\n", | |||
| attr_to_ctype(i.attr), i.name, defaultValue | |||
| ); | |||
| } | |||
| auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||
| os << formatv( | |||
| " {0}({1}){2}{3}\n", | |||
| op.getCppClassName(), paramList, memInitList, body | |||
| ); | |||
| }; | |||
| gen_ctor("", "", " = default;"); | |||
| if (!op.getMgbAttributes().empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| paramList.push_back("std::string scope_ = {}"); | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| ": " + llvm::join(initList, ", "), | |||
| " { set_scope(scope_); }"); | |||
| } | |||
| auto packedParams = op.getPackedParams(); | |||
| if (!packedParams.empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&p : packedParams) { | |||
| auto&& paramFields = p.getFields(); | |||
| auto&& paramType = p.getFullName(); | |||
| auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||
| paramList.push_back( | |||
| paramFields.empty() ? paramType.str() | |||
| : formatv("{0} {1}", paramType, paramName) | |||
| ); | |||
| for (auto&& i : paramFields) { | |||
| initList.push_back(formatv( | |||
| "{0}({1}.{0})", i.name, paramName | |||
| )); | |||
| } | |||
| } | |||
| for (auto&& i : op.getExtraArguments()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
| " {}"); | |||
| } | |||
| if (!packedParams.empty()) { | |||
| for (auto&& p : packedParams) { | |||
| auto accessor = p.getAccessor(); | |||
| if (!accessor.empty()) { | |||
| os << formatv( | |||
| " {0} {1}() const {{\n", | |||
| p.getFullName(), accessor | |||
| ); | |||
| std::vector<llvm::StringRef> fields; | |||
| for (auto&& i : p.getFields()) { | |||
| fields.push_back(i.name); | |||
| } | |||
| os << formatv( | |||
| " return {{{0}};\n", | |||
| llvm::join(fields, ", ") | |||
| ); | |||
| os << " }\n"; | |||
| } | |||
| } | |||
| } | |||
| if (auto decl = op.getExtraOpdefDecl()) { | |||
| os << decl.getValue(); | |||
| } | |||
| os << formatv( | |||
| "};\n\n" | |||
| ); | |||
| } | |||
| void OpDefEmitter::emit_tpl_spl() { | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| if (attr->supportToString()) { | |||
| std::vector<std::string> case_body; | |||
| std::string ename = formatv("{0}::{1}", | |||
| op.getCppClassName(), attr->getEnumName()); | |||
| llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ | |||
| case_body.push_back(formatv( | |||
| "case {0}::{1}: return \"{1}\";", ename, v)); | |||
| }); | |||
| os << formatv(R"( | |||
| template <> | |||
| struct ToStringTrait<{0}> { | |||
| std::string operator()({0} e) const { | |||
| switch (e) { | |||
| {1} | |||
| default: | |||
| return "{0}::Unknown"; | |||
| } | |||
| } | |||
| }; | |||
| )", ename, llvm::join(case_body, "\n")); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void OpDefEmitter::emit_body() { | |||
| auto&& className = op.getCppClassName(); | |||
| os << formatv( | |||
| "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className | |||
| ); | |||
| auto formatMethImpl = [&](auto&& meth) { | |||
| return formatv( | |||
| "{0}_{1}_impl", className, meth | |||
| ); | |||
| }; | |||
| std::vector<std::string> methods; | |||
| if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) { | |||
| os << "namespace {\n"; | |||
| // generate hash() | |||
| mlir::tblgen::FmtContext ctx; | |||
| os << formatv( | |||
| "size_t {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("hash") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate is_same_st() | |||
| os << formatv( | |||
| "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", | |||
| formatMethImpl("is_same_st") | |||
| ); | |||
| os << formatv( | |||
| " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
| " &&b_ = rhs_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(a_);\n" | |||
| " static_cast<void>(b_);\n", | |||
| className | |||
| ); | |||
| os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
| os << "}\n"; | |||
| // generate props() | |||
| os << formatv( | |||
| "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("props") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate make_name() | |||
| os << formatv( | |||
| "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| os << "} // anonymous namespace\n"; | |||
| methods.push_back("hash"); | |||
| methods.push_back("is_same_st"); | |||
| methods.push_back("props"); | |||
| methods.push_back("make_name"); | |||
| } | |||
| if (!methods.empty()) { | |||
| os << formatv( | |||
| "OP_TRAIT_REG({0}, {0})", op.getCppClassName() | |||
| ); | |||
| for (auto&& i : methods) { | |||
| os << formatv( | |||
| "\n .{0}({1})", i, formatMethImpl(i) | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| } | |||
| } // namespace | |||
| bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
| foreach_operator(keeper, [&](MgbOp& op) { | |||
| OpDefEmitter emitter(op, os); | |||
| emitter.emit_header(); | |||
| emitter.emit_tpl_spl(); | |||
| }); | |||
| return false; | |||
| } | |||
| bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
| foreach_operator(keeper, [&](MgbOp& op) { | |||
| OpDefEmitter emitter(op, os); | |||
| emitter.emit_body(); | |||
| }); | |||
| return false; | |||
| } | |||
| } // namespace mlir::tblgen | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/cpp_class.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "../helper.h" | |||
| namespace mlir::tblgen { | |||
| bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
| bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
| } // namespace mlir::tblgen | |||
| @@ -0,0 +1,142 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/pybind11.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./pybind11.h" | |||
| #include "../emitter.h" | |||
| namespace mlir::tblgen { | |||
| namespace { | |||
| class OpDefEmitter final: public EmitterBase { | |||
| public: | |||
| OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): | |||
| EmitterBase(os_, env_), op(op_) {} | |||
| void emit(); | |||
| private: | |||
| MgbOp& op; | |||
| }; | |||
| void OpDefEmitter::emit() { | |||
| auto className = op.getCppClassName(); | |||
| os << formatv( | |||
| "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
| className | |||
| ); | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = | |||
| llvm::cast<MgbEnumAttr>(aliasBase) | |||
| .getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = env().enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
| className, attr->getEnumName() | |||
| ); | |||
| std::vector<std::string> body; | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| os << formatv( | |||
| "\n .value(\"{2}\", {0}::{1}::{2})", | |||
| className, attr->getEnumName(), i | |||
| ); | |||
| body.push_back(formatv( | |||
| "if (str == \"{2}\") return {0}::{1}::{2};", | |||
| className, attr->getEnumName(), i | |||
| )); | |||
| } | |||
| if (attr->getEnumCombinedFlag()) { | |||
| //! define operator | | |||
| os << formatv( | |||
| "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " | |||
| "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" | |||
| "\n })", | |||
| className, attr->getEnumName()); | |||
| //! define operator & | |||
| os << formatv( | |||
| "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" | |||
| "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" | |||
| "\n })", | |||
| className, attr->getEnumName()); | |||
| } | |||
| os << formatv( | |||
| "\n .def(py::init([](const std::string& in) {" | |||
| "\n auto&& str = normalize_enum(in);" | |||
| "\n {0}" | |||
| "\n throw py::cast_error(\"invalid enum value \" + in);" | |||
| "\n }));\n", | |||
| llvm::join(body, "\n ") | |||
| ); | |||
| os << formatv( | |||
| "py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
| className, attr->getEnumName() | |||
| ); | |||
| enumAlias.emplace(enumID, | |||
| std::make_pair(className, attr->getEnumName())); | |||
| } else { | |||
| os << formatv( | |||
| "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", | |||
| className, attr->getEnumName(), | |||
| iter->second.first, iter->second.second | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| // generate op class binding | |||
| os << formatv("{0}Inst", className); | |||
| bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
| if (!hasDefaultCtor) { | |||
| os << "\n .def(py::init<"; | |||
| std::vector<llvm::StringRef> targs; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| targs.push_back(i.attr.getReturnType()); | |||
| } | |||
| os << llvm::join(targs, ", "); | |||
| os << ", std::string>()"; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv(", py::arg(\"{0}\")", i.name); | |||
| auto defaultValue = i.attr.getDefaultValue(); | |||
| if (!defaultValue.empty()) { | |||
| os << formatv(" = {0}", defaultValue); | |||
| } else { | |||
| hasDefaultCtor = true; | |||
| } | |||
| } | |||
| os << ", py::arg(\"scope\") = {})"; | |||
| } | |||
| if (hasDefaultCtor) { | |||
| os << "\n .def(py::init<>())"; | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv( | |||
| "\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
| i.name, className | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| } // namespace | |||
| bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
| Environment env; | |||
| using namespace std::placeholders; | |||
| foreach_operator(keeper, [&](MgbOp& op) { | |||
| OpDefEmitter(op, os, env).emit(); | |||
| }); | |||
| return false; | |||
| } | |||
| } // namespace mlir::tblgen | |||
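| From the Python side, the bindings this emitter produces behave like ordinary pybind11 classes: every attribute becomes a keyword argument of the generated constructor (plus the trailing scope argument), strings convert implicitly to the generated enums, and bit-combined enums gain __or__/__and__. A rough usage sketch follows; the op, attribute, and enum names (Example, mode, Format, ...) are placeholders, not names taken from the real opdef records: | |||
| from megengine.core._imperative_rt import ops | |||
| # keyword-argument construction produced by py::init<...> with py::arg(...) defaults | |||
| op = ops.Example(mode="SUM", scope="my_scope")  # hypothetical op and attribute names | |||
| # strings convert implicitly via py::implicitly_convertible + normalize_enum | |||
| op.mode = "MEAN" | |||
| # bit-combined enums expose the __or__ / __and__ operators generated above | |||
| fmt = ops.Example.Format.BIT_A | ops.Example.Format.BIT_B | |||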
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/pybind11.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "../helper.h" | |||
| namespace mlir::tblgen { | |||
| bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
| } // namespace mlir::tblgen | |||
| @@ -0,0 +1,314 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/python_c_extension.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "python_c_extension.h" | |||
| #include "../emitter.h" | |||
| namespace mlir::tblgen { | |||
| namespace { | |||
| struct Initproc { | |||
| std::string func; | |||
| Initproc(std::string&& s): func(std::move(s)) {} | |||
| std::string operator()(std::string argument) { | |||
| return formatv("{0}({1})", func, argument); | |||
| } | |||
| }; | |||
| class OpDefEmitter: public EmitterBase { | |||
| public: | |||
| OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): | |||
| EmitterBase(os_, env_), op(op_) { | |||
| ctx.withSelf(op.getCppClassName()); | |||
| } | |||
| Initproc emit(); | |||
| private: | |||
| void emit_class(); | |||
| void emit_py_init(); | |||
| void emit_py_getsetters(); | |||
| Initproc emit_initproc(); | |||
| MgbOp& op; | |||
| std::vector<Initproc> subclasses; | |||
| mlir::tblgen::FmtContext ctx; | |||
| }; | |||
| class EnumAttrEmitter: public EmitterBase { | |||
| public: | |||
| EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_): | |||
| EmitterBase(os_, env_), attr(attr_) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper"); | |||
| ctx.addSubst("opClass", parent); | |||
| ctx.addSubst("enumClass", attr->getEnumName()); | |||
| firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second; | |||
| } | |||
| Initproc emit(); | |||
| protected: | |||
| void emit_tpl_spl(); | |||
| Initproc emit_initproc(); | |||
| MgbEnumAttr* attr; | |||
| bool firstOccur; | |||
| mlir::tblgen::FmtContext ctx; | |||
| }; | |||
| Initproc EnumAttrEmitter::emit() { | |||
| emit_tpl_spl(); | |||
| return emit_initproc(); | |||
| } | |||
| void EnumAttrEmitter::emit_tpl_spl() { | |||
| if (!firstOccur) return; | |||
| os << tgfmt( | |||
| "template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n", | |||
| &ctx); | |||
| os << tgfmt( | |||
| "template<> const char* $enumTpl<$opClass::$enumClass>::name = " | |||
| "\"$opClass.$enumClass\";\n", | |||
| &ctx); | |||
| if (attr->getEnumCombinedFlag()) { | |||
| os << tgfmt( | |||
| "template<> PyNumberMethods " | |||
| "$enumTpl<$opClass::$enumClass>::number_methods={};\n", | |||
| &ctx); | |||
| os << tgfmt(R"( | |||
| template<> struct EnumTrait<$opClass::$enumClass> { | |||
| static constexpr bool is_bit_combined = true; | |||
| static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1; | |||
| }; | |||
| )", &ctx, attr->getEnumMembers().size()); | |||
| } | |||
| auto str2type = [&](auto&& i) -> std::string { | |||
| return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i); | |||
| }; | |||
| os << tgfmt(R"( | |||
| template<> std::unordered_map<std::string, $opClass::$enumClass> | |||
| $enumTpl<$opClass::$enumClass>::str2type = {$0}; | |||
| )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", ")); | |||
| auto type2str = [&](auto&& i) -> std::string { | |||
| return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i); | |||
| }; | |||
| os << tgfmt(R"( | |||
| template<> std::unordered_map<$opClass::$enumClass, std::string> | |||
| $enumTpl<$opClass::$enumClass>::type2str = {$0}; | |||
| )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", ")); | |||
| } | |||
| Initproc EnumAttrEmitter::emit_initproc() { | |||
| std::string initproc = formatv("_init_py_{0}_{1}", | |||
| ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass")); | |||
| os << tgfmt(R"( | |||
| void $0(PyTypeObject& py_type) { | |||
| auto& e_type = $enumTpl<$opClass::$enumClass>::type; | |||
| )", &ctx, initproc); | |||
| if (firstOccur) { | |||
| os << tgfmt(R"( | |||
| e_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass"; | |||
| e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>); | |||
| e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| e_type.tp_doc = "$opClass.$enumClass"; | |||
| e_type.tp_base = &PyBaseObject_Type; | |||
| e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr; | |||
| e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare; | |||
| )", &ctx); | |||
| if (attr->getEnumCombinedFlag()) { | |||
| // only bit-combined enums may create new instances, because bitwise operations | |||
| // produce fresh values; all other enums should always reuse the singleton members | |||
| os << tgfmt(R"( | |||
| e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; | |||
| auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; | |||
| number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; | |||
| number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; | |||
| e_type.tp_as_number = &number_method; | |||
| )", &ctx); | |||
| } | |||
| os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n"; | |||
| for (auto&& i : attr->getEnumMembers()) { | |||
| os << tgfmt(R"({ | |||
| PyObject* inst = e_type.tp_alloc(&e_type, 0); | |||
| reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; | |||
| mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0); | |||
| PyType_Modified(&e_type); | |||
| })", &ctx, i); | |||
| } | |||
| } | |||
| os << tgfmt(R"( | |||
| mgb_assert(PyDict_SetItemString( | |||
| py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0); | |||
| )", &ctx); | |||
| os << "}\n"; | |||
| return initproc; | |||
| } | |||
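The comment in the `getEnumCombinedFlag()` branch above captures the design: plain enum members live as singletons in `tp_dict`, while bit-combined enums additionally get `tp_new` plus the `nb_or`/`nb_and` number slots so that OR-ing two members can build a fresh instance. The following stand-alone CPython sketch (not MegEngine code; every name is invented) shows that pattern in isolation:

```cpp
#include <Python.h>

// A toy bit-flags wrapper type, analogous in shape to the generated enum wrapper.
struct PyBitEnum {
    PyObject_HEAD
    unsigned long value;
};

static PyObject* bitenum_new(PyTypeObject* type, PyObject*, PyObject*) {
    PyObject* self = type->tp_alloc(type, 0);
    if (self) reinterpret_cast<PyBitEnum*>(self)->value = 0;
    return self;
}

// `a | b` allocates a fresh instance instead of reusing a member singleton,
// because the combined value is generally not equal to any single member.
static PyObject* bitenum_or(PyObject* lhs, PyObject* rhs) {
    PyTypeObject* type = Py_TYPE(lhs);
    PyObject* ret = type->tp_alloc(type, 0);
    if (!ret) return nullptr;
    reinterpret_cast<PyBitEnum*>(ret)->value =
            reinterpret_cast<PyBitEnum*>(lhs)->value |
            reinterpret_cast<PyBitEnum*>(rhs)->value;
    return ret;
}

static PyNumberMethods bitenum_number_methods = {};
static PyTypeObject BitEnum_Type = {PyVarObject_HEAD_INIT(nullptr, 0)};

// Mirrors the generated _init_py_<Op>_<Enum>: ready the type, then install
// one singleton per member into tp_dict.
int init_bitenum_type() {
    BitEnum_Type.tp_name = "example.BitEnum";
    BitEnum_Type.tp_basicsize = sizeof(PyBitEnum);
    BitEnum_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    BitEnum_Type.tp_new = bitenum_new;
    bitenum_number_methods.nb_or = bitenum_or;
    BitEnum_Type.tp_as_number = &bitenum_number_methods;
    if (PyType_Ready(&BitEnum_Type) < 0) return -1;
    const char* members[] = {"HEURISTIC", "PROFILE", "OPTIMIZED"};
    for (int i = 0; i < 3; ++i) {
        PyObject* inst = BitEnum_Type.tp_alloc(&BitEnum_Type, 0);
        if (!inst) return -1;
        reinterpret_cast<PyBitEnum*>(inst)->value = 1ul << i;
        if (PyDict_SetItemString(BitEnum_Type.tp_dict, members[i], inst) < 0)
            return -1;
    }
    PyType_Modified(&BitEnum_Type);
    return 0;
}
```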
| Initproc OpDefEmitter::emit() { | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit()); | |||
| } | |||
| } | |||
| emit_class(); | |||
| emit_py_init(); | |||
| emit_py_getsetters(); | |||
| return emit_initproc(); | |||
| } | |||
| void OpDefEmitter::emit_class() { | |||
| os << tgfmt(R"( | |||
| PyOpDefBegin($_self) // { | |||
| static PyGetSetDef py_getsetters[]; | |||
| static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
| // }; | |||
| PyOpDefEnd($_self) | |||
| )", &ctx); | |||
| } | |||
| void OpDefEmitter::emit_py_init() { | |||
| std::string initBody; | |||
| if (!op.getMgbAttributes().empty()) { | |||
| initBody += "static const char* kwlist[] = {"; | |||
| std::vector<llvm::StringRef> attr_name_list; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| attr_name_list.push_back(attr.name); | |||
| }); | |||
| attr_name_list.push_back("scope"); | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| initBody += formatv("\"{0}\", ", attr); | |||
| }); | |||
| initBody += "NULL};\n"; | |||
| initBody += " PyObject "; | |||
| auto initializer = [&](auto&& attr) -> std::string { | |||
| return formatv("*{0} = NULL", attr); | |||
| }; | |||
| initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n"; | |||
| initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
| // one 'O' per entry of attr_name_list, which already includes the extra "scope" slot | |||
| initBody += std::string(attr_name_list.size(), 'O'); | |||
| initBody += "\", const_cast<char**>(kwlist)"; | |||
| llvm::for_each(attr_name_list, [&](auto&& attr) { | |||
| initBody += formatv(", &{0}", attr); | |||
| }); | |||
| initBody += "))\n"; | |||
| initBody += " return -1;\n"; | |||
| llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
| initBody += tgfmt(R"( | |||
| if ($0) { | |||
| try { | |||
| reinterpret_cast<PyOp($_self)*>(self)->inst().$0 = | |||
| pyobj_convert_generic<decltype($_self::$0)>::from($0); | |||
| } CATCH_ALL(-1) | |||
| } | |||
| )", &ctx, attr.name); | |||
| }); | |||
| initBody += tgfmt(R"( | |||
| if (scope) { | |||
| try { | |||
| reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
| ->set_scope(pyobj_convert_generic<std::string>::from(scope)); | |||
| } CATCH_ALL(-1) | |||
| } | |||
| )", &ctx); | |||
| } | |||
| initBody += "\n return 0;"; | |||
| os << tgfmt(R"( | |||
| int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||
| $0 | |||
| } | |||
| )", &ctx, initBody); | |||
| } | |||
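`emit_py_init` stitches together a fairly long `PyArg_ParseTupleAndKeywords` call: one optional `'O'` slot per attribute plus the trailing `scope` keyword, all parsed as raw `PyObject*` and converted afterwards. A rough, self-contained illustration of that calling convention follows (the `axis` attribute name is invented for the example; the real conversion through `pyobj_convert_generic` is elided):

```cpp
#include <Python.h>

// Shape of a generated py_init for a hypothetical op with one attribute "axis".
static int example_py_init(PyObject* /*self*/, PyObject* args, PyObject* kwds) {
    static const char* kwlist[] = {"axis", "scope", NULL};
    PyObject* axis = NULL;
    PyObject* scope = NULL;
    // "|OO": every argument is optional and arrives as a borrowed PyObject*.
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO",
                                     const_cast<char**>(kwlist), &axis, &scope))
        return -1;
    // The generated code would now convert `axis` into the C++ attribute and
    // pass `scope` to OpDef::set_scope inside try/CATCH_ALL blocks.
    return 0;
}
```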
| void OpDefEmitter::emit_py_getsetters() { | |||
| auto f = [&](auto&& attr) -> std::string { | |||
| return tgfmt( | |||
| "{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},", | |||
| &ctx, attr.name); | |||
| }; | |||
| os << tgfmt(R"( | |||
| PyGetSetDef PyOp($_self)::py_getsetters[] = { | |||
| $0 | |||
| {NULL} /* Sentinel */ | |||
| }; | |||
| )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n ")); | |||
| } | |||
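The `py_get_generic`/`py_set_generic` entries in the emitted table are helpers defined elsewhere in the binding layer (not part of this diff); the table itself is an ordinary `PyGetSetDef` array. A hedged stand-alone equivalent with dummy accessors and an invented `axis` attribute:

```cpp
#include <Python.h>

static PyObject* get_axis(PyObject* /*self*/, void* /*closure*/) {
    // placeholder: the generated getter converts the stored C++ attribute to Python
    return PyLong_FromLong(0);
}

static int set_axis(PyObject* /*self*/, PyObject* /*value*/, void* /*closure*/) {
    // placeholder: the generated setter converts and stores the attribute
    return 0;
}

static PyGetSetDef example_getsetters[] = {
        {const_cast<char*>("axis"), get_axis, set_axis,
         const_cast<char*>("axis"), NULL},
        {NULL}  /* Sentinel */
};
```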
| Initproc OpDefEmitter::emit_initproc() { | |||
| std::string initproc = formatv("_init_py_{0}", op.getCppClassName()); | |||
| std::string subclass_init_call; | |||
| for (auto&& i : subclasses) { | |||
| subclass_init_call += formatv(" {0};\n", i("py_type")); | |||
| } | |||
| os << tgfmt(R"( | |||
| void $0(py::module m) { | |||
| using py_op = PyOp($_self); | |||
| auto& py_type = PyOpType($_self); | |||
| py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
| py_type.tp_name = "megengine.core._imperative_rt.ops.$_self"; | |||
| py_type.tp_basicsize = sizeof(PyOp($_self)); | |||
| py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
| py_type.tp_doc = "$_self"; | |||
| py_type.tp_base = &PyOpType(OpDef); | |||
| py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
| py_type.tp_new = py_new_generic<py_op>; | |||
| py_type.tp_init = py_op::py_init; | |||
| py_type.tp_getset = py_op::py_getsetters; | |||
| mgb_assert(PyType_Ready(&py_type) >= 0); | |||
| $1 | |||
| PyType_Modified(&py_type); | |||
| m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type)); | |||
| mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second); | |||
| } | |||
| )", &ctx, initproc, subclass_init_call); | |||
| return initproc; | |||
| } | |||
| } // namespace | |||
| bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) { | |||
| Environment env; | |||
| using namespace std::placeholders; | |||
| std::vector<Initproc> initprocs; | |||
| foreach_operator(keeper, [&](MgbOp& op) { | |||
| initprocs.emplace_back(OpDefEmitter(op, os, env).emit()); | |||
| }); | |||
| os << "#define INIT_ALL_OP(m)"; | |||
| for(auto&& init : initprocs) { | |||
| os << formatv(" \\\n {0};", init("m")); | |||
| } | |||
| os << "\n"; | |||
| return false; | |||
| } | |||
| } // namespace mlir::tblgen | |||
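The final loop concatenates one `_init_py_<Op>(m)` call per generated operator into a single helper macro, so the module initializer only needs to invoke `INIT_ALL_OP(m)`. For two hypothetical ops `OpA` and `OpB` the emitted text would look roughly like:

```cpp
// Hedged sketch of the emitted macro; the real output has one line per MgbOp record.
#define INIT_ALL_OP(m) \
    _init_py_OpA(m); \
    _init_py_OpB(m);
```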
| @@ -0,0 +1,19 @@ | |||
| /** | |||
| * \file imperative/tablegen/targets/python_c_extension.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "../helper.h" | |||
| namespace mlir::tblgen { | |||
| bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper); | |||
| } // namespace mlir::tblgen | |||
| @@ -709,7 +709,7 @@ void run_test_st(Args &env) { | |||
| strategy = S::PROFILE; | |||
| } | |||
| } else if (env.use_fast_run) { | |||
| strategy = S::PROFILE | S::OPTMIZED; | |||
| strategy = S::PROFILE | S::OPTIMIZED; | |||
| } else if (env.reproducible) { | |||
| strategy = S::HEURISTIC | S::REPRODUCIBLE; | |||
| } | |||
| @@ -365,14 +365,16 @@ namespace mgb { | |||
| if (!m_free_task_block.empty()) { | |||
| ret = std::move(m_free_task_block.back()); | |||
| m_free_task_block.pop_back(); | |||
| break; | |||
| } else if (m_block_quota > 0) { | |||
| ret = std::make_unique<TaskBlock>(); | |||
| m_block_quota--; | |||
| break; | |||
| } else { | |||
| m_cv.wait(m_mutex); | |||
| continue; | |||
| } | |||
| } while (false); | |||
| } while (true); | |||
| ret->first_tid = m_new_block_first_tid; | |||
| m_new_block_first_tid += BLOCK_SIZE; | |||
| ret->prev = prev; | |||
| @@ -12,8 +12,8 @@ | |||
| #pragma once | |||
| #define MGB_MAJOR 8 | |||
| #define MGB_MINOR 9999 | |||
| #define MGB_PATCH 0 | |||
| #define MGB_MINOR 10 | |||
| #define MGB_PATCH 1 | |||
| //! whether it is development version | |||
| #ifndef MGB_IS_DEV | |||
| #define MGB_IS_DEV 0 | |||
| @@ -1565,7 +1565,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||
| if (new_inp[i]->shape()[1] % 4 != 0) { | |||
| can_exec_cd4 = false; | |||
| } | |||
| //! cd4 elemwise with scalar is supported | |||
| //! cd4 elemwise with scalar is unsupported | |||
| } else if (!new_inp[i]->shape().is_scalar()) { | |||
| can_exec_cd4 = false; | |||
| } | |||
| @@ -1627,6 +1627,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||
| replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw; | |||
| replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw; | |||
| replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw; | |||
| replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | |||
| replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; | |||
| replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||
| replace_warp_perspective_opr; | |||
| @@ -1265,6 +1265,55 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { | |||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||
| } | |||
| TEST(TestGoptInference, ConvertFormatNHWCD4TypeCvt) { | |||
| NaiveMegDNNHandleScope naive_megdnn_handle; | |||
| HostTensorGenerator<> gen; | |||
| auto cn = CompNode::load("cpu0"); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
| return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
| .rename(name); | |||
| }; | |||
| auto host_x = gen({8, 8, 8, 8}, cn); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
| opr::Convolution::Param param; | |||
| param.pad_h = param.pad_w = 0; | |||
| auto w1 = mkcvar("w1", {8, 8, 3, 3}), | |||
| conv1 = opr::Convolution::make(x, w1, param), | |||
| tcvt1 = opr::TypeCvt::make(conv1, dtype::Float16()); | |||
| auto w2 = mkcvar("w2", {8, 8, 3, 3}), | |||
| conv2 = opr::Convolution::make(x, w2, param), | |||
| tcvt2 = opr::TypeCvt::make(conv2, dtype::Float16()); | |||
| auto y = opr::Elemwise::make({tcvt1, tcvt2}, opr::Elemwise::Param::Mode::ADD); | |||
| SymbolVar y_opt; | |||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||
| options.enable_nhwcd4(); | |||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, | |||
| find_opr<opr::Convolution>(y_opt).param().format); | |||
| graph->compile({{y_opt, {}}}) | |||
| ->to_json() | |||
| ->writeto_fpath(output_file( | |||
| "TestGoptInference.ConvertFormatNHWCD4TypeCvt.json")); | |||
| HostTensorND host_y_opt, host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y), | |||
| make_callback_copy(y_opt, host_y_opt)}); | |||
| func->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||
| *host_x = *gen({8, 8, 16, 16}, cn); | |||
| func->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||
| } | |||
| TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) { | |||
| // hwcd4 is only supported in naive handle | |||
| NaiveMegDNNHandleScope naive_megdnn_handle; | |||
| @@ -1707,8 +1756,8 @@ TEST(TestGoptInference, FastProfileCache) { | |||
| using S = opr::Convolution::ExecutionPolicy::Strategy; | |||
| ASSERT_EQ(S::HEURISTIC, conv.execution_policy_transient().strategy); | |||
| gopt::modify_opr_algo_strategy_inplace({z + 2.3f}, | |||
| S::PROFILE | S::OPTMIZED); | |||
| ASSERT_EQ(S::PROFILE | S::OPTMIZED, conv.execution_policy().strategy); | |||
| S::PROFILE | S::OPTIMIZED); | |||
| ASSERT_EQ(S::PROFILE | S::OPTIMIZED, conv.execution_policy().strategy); | |||
| } | |||
| TEST(TestGoptInference, AlgoWorkspaceLimit) { | |||
| @@ -6,7 +6,7 @@ decl_opr('Convolution', | |||
| 'convolution kernel in ' | |||
| '(out channel, in channel, kern row, kern col) format')], | |||
| params=[('param', 'ConvolutionV0'), | |||
| ('execution_polity', 'ExecutionPolicy')], | |||
| ('execution_polity', 'ExecutionPolicyV0')], | |||
| desc='batched convolution on channeled 2D images') | |||
| decl_opr('Convolution', | |||
| @@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData', | |||
| 'convolution kernel in ' | |||
| '(out channel, in channel, kern row, kern col) format')], | |||
| params=[('param', 'ConvolutionV0'), | |||
| ('execution_polity', 'ExecutionPolicy')], | |||
| ('execution_polity', 'ExecutionPolicyV0')], | |||
| body=[ | |||
| 'a, b = all_inputs', | |||
| 'all_inputs = [b, a]' | |||
| @@ -201,7 +201,7 @@ decl_opr('ConvBiasForward', | |||
| Doc('bias', 'bias'), | |||
| ], | |||
| params=[('param', 'ConvBiasV1'), | |||
| ('execution_policy', 'ExecutionPolicy')], | |||
| ('execution_policy', 'ExecutionPolicyV0')], | |||
| desc=('activation(convolution(src, filter) + bias) with specified ' | |||
| 'dtype'), | |||
| has_out_dtype=True) | |||
| @@ -283,7 +283,7 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||
| static bool algo_attribute_match_strategy(AlgoAttribute attribute, | |||
| ExecutionStrategy selected_strategy) { | |||
| bool ret = true; | |||
| if (selected_strategy & ExecutionStrategy::OPTMIZED) { | |||
| if (selected_strategy & ExecutionStrategy::OPTIMIZED) { | |||
| ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute)); | |||
| } else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) { | |||
| ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute); | |||
| @@ -357,7 +357,7 @@ TEST(TestOprDNN, ConvBiasExePolicy) { | |||
| #if MGB_ENABLE_FASTRUN | |||
| for (auto strategy : | |||
| SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { | |||
| #else | |||
| for (auto strategy : | |||
| SmallVector<S>{S::HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||
| @@ -444,7 +444,7 @@ TEST(TestOprDNN, ConvolutionExePolicy) { | |||
| #if MGB_ENABLE_FASTRUN | |||
| for (auto strategy : | |||
| SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { | |||
| #else | |||
| for (auto strategy : | |||
| SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||
| @@ -1717,7 +1717,7 @@ TEST(TestOprDNN, LocalShareForwardExecPolicy) { | |||
| #if MGB_ENABLE_FASTRUN | |||
| for (auto strategy : | |||
| SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { | |||
| #else | |||
| for (auto strategy : | |||
| SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||
| @@ -1828,7 +1828,7 @@ TEST(TestOprDNN, DeformableConvForward) { | |||
| #if MGB_ENABLE_FASTRUN | |||
| for (auto strategy : | |||
| SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { | |||
| #else | |||
| for (auto strategy : | |||
| SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||
| @@ -1997,7 +1997,7 @@ TEST(TestOprDNN, BatchConvBiasForward) { | |||
| #if MGB_ENABLE_FASTRUN | |||
| for (auto strategy : | |||
| SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTMIZED}) { | |||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { | |||
| #else | |||
| for (auto strategy : | |||
| SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||
| @@ -290,11 +290,8 @@ ExternCOprRunner::ExternCOprRunner(std::string& name, | |||
| m_dump_name{name}, | |||
| m_param{nullptr} { | |||
| mgb_assert(m_desc->size == sizeof(MGBOprDesc), | |||
| "invalid MGBOprDesc size: expect=%zu got=%u, may caused by " | |||
| "extern_c_opr.h mismatch, please confirm that the " | |||
| "extern_c_opr.h used when compiling the loader is consistent " | |||
| "with the runtime caller build used", | |||
| sizeof(MGBOprDesc), m_desc->size); | |||
| "invalid MGBOprDesc size: expect=%zu got=%u", sizeof(MGBOprDesc), | |||
| m_desc->size); | |||
| for (auto i : inputs) { | |||
| add_input({i}); | |||
| } | |||
| @@ -0,0 +1,8 @@ | |||
| add_custom_command( | |||
| OUTPUT link_sh | |||
| COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
| ${PROJECT_SOURCE_DIR}/tools/mlir/mgb-file-check/mgb-file-check.sh | |||
| ${PROJECT_BINARY_DIR}/tools/mlir/mgb-file-check/mgb-file-check | |||
| ) | |||
| add_custom_target(mgb-file-check DEPENDS link_sh) | |||
| @@ -0,0 +1,3 @@ | |||
| #!/bin/bash -e | |||
| FileCheck --enable-var-scope --dump-input=fail "$@" | |||
| @@ -0,0 +1,23 @@ | |||
| get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) | |||
| get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) | |||
| set(LIBS | |||
| ${dialect_libs} | |||
| ${conversion_libs} | |||
| LLVMSupport | |||
| MLIROptLib | |||
| MLIRIR | |||
| MLIRPass | |||
| MLIRSupport | |||
| ) | |||
| add_executable(mgb-opt mgb-opt.cpp) | |||
| target_include_directories( | |||
| mgb-opt | |||
| PRIVATE ${MLIR_LLVM_INCLUDE_DIR} ${PROJECT_SOURCE_DIR}/src/jit/include | |||
| ${PROJECT_BINARY_DIR}/src/jit/include) | |||
| add_dependencies(mgb-opt mgb_dialect) | |||
| target_link_libraries(mgb-opt PRIVATE ${LIBS} megbrain megdnn ${MGE_CUDA_LIBS}) | |||
| llvm_update_compile_flags(mgb-opt) | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * \file tools/mlir/mgb-opt/mgb-opt.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/jit/mlir/ir/dialect.h" | |||
| #include "megbrain/jit/mlir/ir/passes.h" | |||
| #include <llvm/Support/CommandLine.h> | |||
| #include <llvm/Support/InitLLVM.h> | |||
| #include <llvm/Support/PrettyStackTrace.h> | |||
| #include <llvm/Support/SourceMgr.h> | |||
| #include <llvm/Support/ToolOutputFile.h> | |||
| #include <mlir/Dialect/Affine/IR/AffineOps.h> | |||
| #include <mlir/Dialect/LLVMIR/LLVMDialect.h> | |||
| #include <mlir/IR/AsmState.h> | |||
| #include <mlir/InitAllDialects.h> | |||
| #include <mlir/InitAllPasses.h> | |||
| #include <mlir/Pass/Pass.h> | |||
| #include <mlir/Pass/PassManager.h> | |||
| #include <mlir/Support/FileUtilities.h> | |||
| #include <mlir/Support/MlirOptMain.h> | |||
| using namespace llvm; | |||
| using namespace mlir; | |||
| //! TODO: Implement a custom MlirOptMain that supports the following flags. | |||
| static cl::opt<bool> print_mlir{ | |||
| "print-mlir", | |||
| cl::desc("Prints MLIR IR after translation"), | |||
| cl::init(false), | |||
| }; | |||
| static cl::list<std::string> input_values{ | |||
| "input-value", | |||
| cl::desc("Input shapes and optional values"), | |||
| cl::ZeroOrMore, | |||
| }; | |||
| static cl::opt<std::string> input_values_file{ | |||
| "input-value-file", | |||
| cl::desc("Provides a file for input shapes and optional values (see " | |||
| "ParseToVariantListFromFile in vm_util.h for details)"), | |||
| cl::init(""), | |||
| }; | |||
| static cl::opt<bool> run{ | |||
| "run", | |||
| cl::desc("Runs the module (vs. just compiling and verifing)"), | |||
| cl::init(true), | |||
| }; | |||
| static cl::list<std::string> run_args{ | |||
| "run-arg", | |||
| cl::desc("Argument passed to the execution flag parser"), | |||
| cl::ZeroOrMore, | |||
| }; | |||
| namespace mgb { | |||
| namespace jit { | |||
| void register_test_mgb_to_affine_lowering_pass(); | |||
| void register_test_affine_to_llvm_lowering_pass(); | |||
| } // namespace jit | |||
| } // namespace mgb | |||
| int main(int argc, char** argv) { | |||
| mlir::registerAllPasses(); | |||
| mlir::DialectRegistry registry; | |||
| mlir::registerAllDialects(registry); | |||
| registry.insert<mgb::jit::MgbDialect>(); | |||
| mgb::jit::register_test_mgb_to_affine_lowering_pass(); | |||
| mgb::jit::register_test_affine_to_llvm_lowering_pass(); | |||
| return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver", registry)); | |||
| } | |||
| @@ -41,8 +41,14 @@ pdef('PersistentOutputStorage').add_fields( | |||
| Doc('REPRODUCIBLE', | |||
| 'when profile or heuristic algo selection is used, require the ' | |||
| 'algos to be reproducible'), | |||
| Doc('OPTMIZED', | |||
| 'profile require algos are optmized to achieve fast-profile')). | |||
| Doc('OPTIMIZED', | |||
| 'profile requires that algos are optimized to achieve fast-profile'), | |||
| default=('HEURISTIC',), | |||
| member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'), | |||
| (('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'), | |||
| (('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'), | |||
| (('OPTIMIZED',), 'OPTMIZED'), | |||
| ]). | |||
| add_fields('uint64', | |||
| Doc('workspace_limit', 'workspace limit in bytes'), | |||
| str(2**64-1)+'ull')) | |||