|
|
@@ -73,11 +73,21 @@ class member_defs: |
|
|
"""define an enum; the result would contain both an enum class def and its |
|
|
"""define an enum; the result would contain both an enum class def and its |
|
|
corresponding data field |
|
|
corresponding data field |
|
|
|
|
|
|
|
|
:param default: index of default member value |
|
|
|
|
|
|
|
|
:param default: |
|
|
|
|
|
for normal enum class: index of default member value |
|
|
|
|
|
for bit combined class: tuple of index of default member value |
|
|
|
|
|
|
|
|
|
|
|
For example, following representations of the default value for bit |
|
|
|
|
|
combined class are all equivalent: |
|
|
|
|
|
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...) |
|
|
|
|
|
Enum(members=('a', 'b', 'c'), default=(0, 1), ...) |
|
|
|
|
|
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...) |
|
|
|
|
|
|
|
|
:attr name_field: name of the data field of this enum in the param |
|
|
:attr name_field: name of the data field of this enum in the param |
|
|
struct |
|
|
struct |
|
|
:attr member_alias: list of (member, alias) pairs |
|
|
|
|
|
|
|
|
:attr member_alias: |
|
|
|
|
|
for normal enum class: list of (member, alias) pairs |
|
|
|
|
|
for bit combined class: list of (tuple of members, alias) paris |
|
|
""" |
|
|
""" |
|
|
__slots__ = ['name', 'name_field', 'members', 'default', |
|
|
__slots__ = ['name', 'name_field', 'members', 'default', |
|
|
'member_alias', 'combined'] |
|
|
'member_alias', 'combined'] |
|
|
@@ -90,17 +100,11 @@ class member_defs: |
|
|
name = member_defs.Doc.make(name) |
|
|
name = member_defs.Doc.make(name) |
|
|
assert name.id[0].isupper() |
|
|
assert name.id[0].isupper() |
|
|
members = tuple(map(member_defs.Doc.make, members)) |
|
|
members = tuple(map(member_defs.Doc.make, members)) |
|
|
if isinstance(default, str): |
|
|
|
|
|
if default not in name_field: |
|
|
|
|
|
raise ValueError( |
|
|
|
|
|
"Default value '{}' does not exist.".format(default)) |
|
|
|
|
|
default = name_field.index(default) |
|
|
|
|
|
assert isinstance(default, int) |
|
|
|
|
|
self.name = name |
|
|
self.name = name |
|
|
self.combined = combined |
|
|
self.combined = combined |
|
|
self.name_field = self.get_name_field(name.id, name_field) |
|
|
self.name_field = self.get_name_field(name.id, name_field) |
|
|
self.members = members |
|
|
self.members = members |
|
|
self.default = default |
|
|
|
|
|
|
|
|
self.default = self.normalize_enum_value(default) |
|
|
|
|
|
|
|
|
self.all_enums[(param_name, name.id)] = self |
|
|
self.all_enums[(param_name, name.id)] = self |
|
|
|
|
|
|
|
|
@@ -114,6 +118,43 @@ class member_defs: |
|
|
assert isinstance(name_field, str) |
|
|
assert isinstance(name_field, str) |
|
|
return name_field |
|
|
return name_field |
|
|
|
|
|
|
|
|
|
|
|
def normalize_enum_value(self, value): |
|
|
|
|
|
def normalize(v): |
|
|
|
|
|
if isinstance(v, str): |
|
|
|
|
|
if v not in self.members: |
|
|
|
|
|
raise ValueError( |
|
|
|
|
|
"enum member '{}' does not exist.".format(v)) |
|
|
|
|
|
v = self.members.index(v) |
|
|
|
|
|
assert isinstance(v, int) |
|
|
|
|
|
return v |
|
|
|
|
|
if self.combined: |
|
|
|
|
|
if isinstance(value, int): |
|
|
|
|
|
value = self.decompose_combined_enum(value) |
|
|
|
|
|
assert isinstance(value, tuple) |
|
|
|
|
|
value = tuple(normalize(i) for i in value) |
|
|
|
|
|
return value |
|
|
|
|
|
else: |
|
|
|
|
|
return normalize(value) |
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def decompose_combined_enum(v): |
|
|
|
|
|
"""Integer => tuple of the indexes of the enum members""" |
|
|
|
|
|
assert isinstance(v, int) |
|
|
|
|
|
idx = 0 |
|
|
|
|
|
members = [] |
|
|
|
|
|
while v > 0: |
|
|
|
|
|
if v & 1: |
|
|
|
|
|
members.append(idx) |
|
|
|
|
|
idx += 1 |
|
|
|
|
|
v >>= 1 |
|
|
|
|
|
return tuple(members) |
|
|
|
|
|
|
|
|
|
|
|
def compose_combined_enum(self, v): |
|
|
|
|
|
"""tuple of members => Integer""" |
|
|
|
|
|
assert self.combined and isinstance(v, tuple) |
|
|
|
|
|
norm_v = self.normalize_enum_value(v) |
|
|
|
|
|
return sum(1 << i for i in norm_v) |
|
|
|
|
|
|
|
|
class Field(Base): |
|
|
class Field(Base): |
|
|
"""define a normal data field""" |
|
|
"""define a normal data field""" |
|
|
__slots__ = ['name', 'dtype', 'default'] |
|
|
__slots__ = ['name', 'dtype', 'default'] |
|
|
@@ -146,6 +187,10 @@ class member_defs: |
|
|
src_name = name |
|
|
src_name = name |
|
|
self.src_name = src_name |
|
|
self.src_name = src_name |
|
|
self.default = default |
|
|
self.default = default |
|
|
|
|
|
# TODO: remove this assertion if needed; adding mock param_defs in |
|
|
|
|
|
# current testing framework is too complicated, and currently we |
|
|
|
|
|
# only allow aliasing of normal enum |
|
|
|
|
|
assert not self.src_enum.combined |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def src_enum(self): |
|
|
def src_enum(self): |
|
|
@@ -157,7 +202,7 @@ class member_defs: |
|
|
set""" |
|
|
set""" |
|
|
if self.default is None: |
|
|
if self.default is None: |
|
|
return self.src_enum.default |
|
|
return self.src_enum.default |
|
|
return self.default |
|
|
|
|
|
|
|
|
return self.src_enum.normalize_enum_value(self.default) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParamDef: |
|
|
class ParamDef: |
|
|
@@ -198,7 +243,7 @@ class ParamDef: |
|
|
self.name.id, name, name_field, members, default, member_alias)) |
|
|
self.name.id, name, name_field, members, default, member_alias)) |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
def add_bit_combination_enum(self, name, *members, default=0, |
|
|
|
|
|
|
|
|
def add_bit_combination_enum(self, name, *members, default=tuple(), |
|
|
name_field=None, member_alias=[]): |
|
|
name_field=None, member_alias=[]): |
|
|
self.members.append(member_defs.Enum( |
|
|
self.members.append(member_defs.Enum( |
|
|
self.name.id, name, name_field, members, default, member_alias, True)) |
|
|
self.name.id, name, name_field, members, default, member_alias, True)) |
|
|
@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase): |
|
|
' for idx, v in enumerate(pdata):\n' |
|
|
' for idx, v in enumerate(pdata):\n' |
|
|
' if isinstance(v, _EnumBase):\n' |
|
|
' if isinstance(v, _EnumBase):\n' |
|
|
' pdata[idx] = _enum_member2num[id(v)]\n' |
|
|
' pdata[idx] = _enum_member2num[id(v)]\n' |
|
|
|
|
|
' elif isinstance(v, _BitCombinedEnumBase):\n' |
|
|
|
|
|
' pdata[idx] = v._value_\n' |
|
|
' return tag + self._packer.pack(*pdata)\n' |
|
|
' return tag + self._packer.pack(*pdata)\n' |
|
|
'\n' |
|
|
'\n' |
|
|
) |
|
|
) |
|
|
self._write( |
|
|
|
|
|
'class _EnumBase(enum.Enum):\n' |
|
|
|
|
|
|
|
|
# it's hard to mix custom implemention into enum, just do copy-paste instead |
|
|
|
|
|
classbody = ( |
|
|
' @classmethod\n' |
|
|
' @classmethod\n' |
|
|
' def __normalize(cls, val):\n' |
|
|
' def __normalize(cls, val):\n' |
|
|
' if isinstance(val, str):\n' |
|
|
' if isinstance(val, str):\n' |
|
|
@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase): |
|
|
' return super()._missing_(value)\n' |
|
|
' return super()._missing_(value)\n' |
|
|
'\n' |
|
|
'\n' |
|
|
) |
|
|
) |
|
|
|
|
|
self._write( |
|
|
|
|
|
'class _EnumBase(enum.Enum):\n' + classbody |
|
|
|
|
|
) |
|
|
|
|
|
self._write( |
|
|
|
|
|
'class _BitCombinedEnumBase(enum.Flag):\n' + classbody |
|
|
|
|
|
) |
|
|
if not self._imperative: |
|
|
if not self._imperative: |
|
|
self._write( |
|
|
self._write( |
|
|
'def _as_dtype_num(dtype):\n' |
|
|
'def _as_dtype_num(dtype):\n' |
|
|
@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase): |
|
|
def _on_member_enum(self, e): |
|
|
def _on_member_enum(self, e): |
|
|
qualname = '{}.{}'.format(self._cur_param_name, e.name) |
|
|
qualname = '{}.{}'.format(self._cur_param_name, e.name) |
|
|
|
|
|
|
|
|
self._write('class %s(_EnumBase):', e.name, indent=1) |
|
|
|
|
|
|
|
|
if e.combined: |
|
|
|
|
|
self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1) |
|
|
|
|
|
else: |
|
|
|
|
|
self._write('class %s(_EnumBase):', e.name, indent=1) |
|
|
|
|
|
|
|
|
self._write_doc(e.name) |
|
|
self._write_doc(e.name) |
|
|
|
|
|
|
|
|
for idx, emem in enumerate(e.members): |
|
|
for idx, emem in enumerate(e.members): |
|
|
self._write('%s = "%s"', emem, emem) |
|
|
|
|
|
self._write_doc(emem) |
|
|
|
|
|
if e.combined: |
|
|
if e.combined: |
|
|
self._enum_member2num.append('id({}.{}):{}'.format( |
|
|
|
|
|
qualname, emem, 1<<idx)) |
|
|
|
|
|
|
|
|
self._write('%s = 1 << %d', emem, idx) |
|
|
|
|
|
self._write_doc(emem) |
|
|
else: |
|
|
else: |
|
|
|
|
|
self._write('%s = "%s"', emem, emem) |
|
|
|
|
|
self._write_doc(emem) |
|
|
self._enum_member2num.append('id({}.{}):{}'.format( |
|
|
self._enum_member2num.append('id({}.{}):{}'.format( |
|
|
qualname, emem, idx)) |
|
|
qualname, emem, idx)) |
|
|
|
|
|
|
|
|
for emem, emem_alis in e.member_alias: |
|
|
|
|
|
self._write('%s = %s', emem_alis, emem) |
|
|
|
|
|
|
|
|
for emem, emem_alias in e.member_alias: |
|
|
|
|
|
if e.combined: |
|
|
|
|
|
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem)) |
|
|
|
|
|
else: |
|
|
|
|
|
self._write('%s = %s', emem_alias, emem) |
|
|
|
|
|
|
|
|
self._unindent() |
|
|
self._unindent() |
|
|
self._write('') |
|
|
self._write('') |
|
|
|
|
|
|
|
|
|
|
|
if e.combined: |
|
|
|
|
|
default = e.compose_combined_enum(e.default) |
|
|
|
|
|
else: |
|
|
|
|
|
default = "'{}'".format(e.members[e.default]) |
|
|
|
|
|
|
|
|
self._cur_fields.append(self.FieldDef( |
|
|
self._cur_fields.append(self.FieldDef( |
|
|
name=e.name_field, |
|
|
name=e.name_field, |
|
|
cvt='{}.convert({})'.format(qualname, e.name_field), |
|
|
cvt='{}.convert({})'.format(qualname, e.name_field), |
|
|
fmt='I', |
|
|
fmt='I', |
|
|
default="'{}'".format(e.members[e.default]), |
|
|
|
|
|
|
|
|
default=default, |
|
|
type=qualname, |
|
|
type=qualname, |
|
|
doc=None)) |
|
|
doc=None)) |
|
|
|
|
|
|
|
|
@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase): |
|
|
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) |
|
|
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) |
|
|
s = e.src_enum |
|
|
s = e.src_enum |
|
|
qualname = '{}.{}'.format(e.src_class, e.src_name) |
|
|
qualname = '{}.{}'.format(e.src_class, e.src_name) |
|
|
|
|
|
|
|
|
|
|
|
if s.combined: |
|
|
|
|
|
default = s.compose_combined_enum(e.get_default()) |
|
|
|
|
|
else: |
|
|
|
|
|
default = "'{}'".format(s.members[e.get_default()]) |
|
|
self._cur_fields.append(self.FieldDef( |
|
|
self._cur_fields.append(self.FieldDef( |
|
|
name=e.name_field, |
|
|
name=e.name_field, |
|
|
cvt='{}.convert({})'.format(qualname, e.name_field), |
|
|
cvt='{}.convert({})'.format(qualname, e.name_field), |
|
|
fmt='I', |
|
|
fmt='I', |
|
|
default="'{}'".format(s.members[e.get_default()]), |
|
|
|
|
|
|
|
|
default=default, |
|
|
type=qualname, |
|
|
type=qualname, |
|
|
doc=None)) |
|
|
doc=None)) |
|
|
|
|
|
|
|
|
@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase): |
|
|
v += ',' |
|
|
v += ',' |
|
|
self._write(v) |
|
|
self._write(v) |
|
|
for mem, alias in e.member_alias: |
|
|
for mem, alias in e.member_alias: |
|
|
self._write('%s = %s,', alias, mem) |
|
|
|
|
|
|
|
|
if e.combined: |
|
|
|
|
|
self._write('%s = %s,', alias, e.compose_combined_enum(mem)) |
|
|
|
|
|
else: |
|
|
|
|
|
self._write('%s = %s,', alias, mem) |
|
|
self._write('};', indent=-1) |
|
|
self._write('};', indent=-1) |
|
|
self._non_static_members.append(e) |
|
|
self._non_static_members.append(e) |
|
|
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', |
|
|
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', |
|
|
str(e.name).upper(), len(e.members)) |
|
|
str(e.name).upper(), len(e.members)) |
|
|
self._add_ctor_args(e.name, |
|
|
|
|
|
'{}::{}'.format(e.name, e.members[e.default]), |
|
|
|
|
|
e.name_field) |
|
|
|
|
|
|
|
|
if e.combined: |
|
|
|
|
|
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default)) |
|
|
|
|
|
else: |
|
|
|
|
|
default = '{}::{}'.format(e.name, e.members[e.default]) |
|
|
|
|
|
self._add_ctor_args(e.name, default, e.name_field) |
|
|
|
|
|
|
|
|
def _on_member_enum_alias(self, e): |
|
|
def _on_member_enum_alias(self, e): |
|
|
s = e.src_enum |
|
|
s = e.src_enum |
|
|
@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase): |
|
|
self._non_static_members.append(e) |
|
|
self._non_static_members.append(e) |
|
|
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', |
|
|
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', |
|
|
str(e.name).upper(), len(s.members)) |
|
|
str(e.name).upper(), len(s.members)) |
|
|
self._add_ctor_args(e.name, |
|
|
|
|
|
'{}::{}'.format(e.name, |
|
|
|
|
|
s.members[e.get_default()]), |
|
|
|
|
|
e.name_field) |
|
|
|
|
|
|
|
|
if s.combined: |
|
|
|
|
|
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) |
|
|
|
|
|
else: |
|
|
|
|
|
default = '{}::{}'.format(e.name, s.members[e.get_default()]) |
|
|
|
|
|
self._add_ctor_args(e.name, default, e.name_field) |
|
|
|
|
|
|
|
|
def _on_member_field(self, f): |
|
|
def _on_member_field(self, f): |
|
|
self._non_static_members.append(f) |
|
|
self._non_static_members.append(f) |
|
|
|