GitOrigin-RevId: d2f22ad5fe
tags/v1.2.0
| @@ -230,6 +230,10 @@ endif() | |||
| # FIXME At present, there are some conflicts between the LLVM that halide | |||
| # depends on and the LLVM that MLIR depends on. Should be fixed in subsequent | |||
| # versions. | |||
| if(MGE_BUILD_IMPERATIVE_RT) | |||
| set(MGE_WITH_HALIDE OFF) | |||
| message(WARNING "cannot use HALIDE when building IMPERATIVE_RT") | |||
| endif() | |||
| if(MGE_WITH_JIT_MLIR) | |||
| if(MGE_WITH_HALIDE) | |||
| message(FATAL_ERROR "please set MGE_WITH_HALIDE to OFF with MGE_WITH_JIT_MLIR enabled") | |||
| @@ -310,7 +314,7 @@ if(MGE_INFERENCE_ONLY) | |||
| set(MGE_BUILD_IMPERATIVE_RT OFF) | |||
| endif() | |||
| if(MGE_WITH_JIT_MLIR) | |||
| if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT) | |||
| include(cmake/llvm-project.cmake) | |||
| endif() | |||
| @@ -750,7 +754,7 @@ target_include_directories(mgb_opr_param_defs | |||
| add_dependencies(mgb_opr_param_defs _mgb_opr_param_defs) | |||
| install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS}) | |||
| if(MGE_WITH_JIT_MLIR) | |||
| if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT) | |||
| # generate param_defs.td | |||
| set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles) | |||
| set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir) | |||
| @@ -800,12 +804,6 @@ if(TARGET _imperative_rt) | |||
| COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
| ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}> | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}> | |||
| COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
| ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py | |||
| COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
| ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py | |||
| DEPENDS _imperative_rt | |||
| VERBATIM | |||
| ) | |||
| @@ -0,0 +1,150 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| import argparse | |||
| import collections | |||
| import textwrap | |||
| import os | |||
| import hashlib | |||
| import struct | |||
| import io | |||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
| class ConverterWriter(IndentWriterBase): | |||
| _skip_current_param = False | |||
| _last_param = None | |||
| _current_tparams = None | |||
| _packed = None | |||
| _const = None | |||
| def __call__(self, fout, defs): | |||
| super().__call__(fout) | |||
| self._write("// %s", self._get_header()) | |||
| self._write("#ifndef MGB_PARAM") | |||
| self._write("#define MGB_PARAM") | |||
| self._process(defs) | |||
| self._write("#endif // MGB_PARAM") | |||
| def _ctype2attr(self, ctype, value): | |||
| if ctype == 'uint32_t': | |||
| return 'MgbUI32Attr', value | |||
| if ctype == 'uint64_t': | |||
| return 'MgbUI64Attr', value | |||
| if ctype == 'int32_t': | |||
| return 'MgbI32Attr', value | |||
| if ctype == 'float': | |||
| return 'MgbF32Attr', value | |||
| if ctype == 'double': | |||
| return 'MgbF64Attr', value | |||
| if ctype == 'bool': | |||
| return 'MgbBoolAttr', value | |||
| if ctype == 'DTypeEnum': | |||
| self._packed = False | |||
| return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value) | |||
| raise RuntimeError("unknown ctype") | |||
| def _on_param_begin(self, p): | |||
| self._last_param = p | |||
| if p.is_legacy: | |||
| self._skip_current_param = True | |||
| return | |||
| self._packed = True | |||
| self._current_tparams = [] | |||
| self._const = set() | |||
| def _on_param_end(self, p): | |||
| if self._skip_current_param: | |||
| self._skip_current_param = False | |||
| return | |||
| if self._packed: | |||
| self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1) | |||
| else: | |||
| self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1) | |||
| self._write("let fields = (ins", indent=1) | |||
| self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) | |||
| self._write(");", indent=-1) | |||
| self._write("}\n", indent=-1) | |||
| if self._packed: | |||
| self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name)) | |||
| self._current_tparams = None | |||
| self._packed = None | |||
| self._const = None | |||
| def _wrapped_with_default_value(self, attr, default): | |||
| return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default) | |||
| def _on_member_enum(self, e): | |||
| p = self._last_param | |||
| # Note: always generate llvm Record def for enum attribute even it was not | |||
| # directly used by any operator, or other enum couldn't alias to this enum | |||
| td_class = "{}{}".format(p.name, e.name) | |||
| fullname = "::megdnn::param::{}".format(p.name) | |||
| enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) | |||
| def format(v): | |||
| return '\"{}\"'.format(str(v)) | |||
| enum_def += ','.join(format(i) for i in e.members) | |||
| enum_def += "]>" | |||
| self._write("def {} : {};".format(td_class, enum_def)) | |||
| if self._skip_current_param: | |||
| return | |||
| # wrapped with default value | |||
| default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default) | |||
| wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
| self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
| def _on_member_enum_alias(self, e): | |||
| p = self._last_param | |||
| if self._skip_current_param: | |||
| return | |||
| # write enum attr def | |||
| td_class = "{}{}".format(p.name, e.name) | |||
| fullname = "::megdnn::param::{}".format(p.name) | |||
| base_td_class = "{}{}".format(e.src_class, e.src_name) | |||
| enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class) | |||
| self._write("def {} : {};".format(td_class, enum_def)) | |||
| # wrapped with default value | |||
| default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default()) | |||
| wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
| self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
| def _on_member_field(self, f): | |||
| if self._skip_current_param: | |||
| return | |||
| attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) | |||
| if str(value) in self._const: | |||
| value = '::megdnn::param::{}::{}'.format(self._last_param.name, value) | |||
| wrapped = self._wrapped_with_default_value(attr, value) | |||
| self._current_tparams.append("{}:${}".format(wrapped, f.name)) | |||
| def _on_const_field(self, f): | |||
| self._const.add(str(f.name)) | |||
| def main(): | |||
| parser = argparse.ArgumentParser('generate op param tablegen file') | |||
| parser.add_argument('input') | |||
| parser.add_argument('output') | |||
| args = parser.parse_args() | |||
| with open(args.input) as fin: | |||
| inputs = fin.read() | |||
| exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
| input_hash = hashlib.sha256() | |||
| input_hash.update(inputs.encode(encoding='UTF-8')) | |||
| input_hash = input_hash.hexdigest() | |||
| writer = ConverterWriter() | |||
| with open(args.output, 'w') as fout: | |||
| writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -8,9 +8,7 @@ file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/sr | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | |||
| file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") | |||
| file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py) | |||
| list(REMOVE_ITEM PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/generated_ops.py ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/param_defs.py) | |||
| file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h | |||
| ${PROJECT_SOURCE_DIR}/src/core/include/* | |||
| ${PROJECT_SOURCE_DIR}/src/opr/include/* | |||
| @@ -19,33 +17,8 @@ file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h | |||
| ${PROJECT_SOURCE_DIR}/dnn/include/*) | |||
| set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/) | |||
| set(GEN_OPS_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal) | |||
| file(MAKE_DIRECTORY ${GEN_OPS_DIR}) | |||
| set(GEN_OPS_FILE ${GEN_OPS_DIR}/generated_ops.py) | |||
| set(GEN_OP_PARAMS_FILE ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal/param_defs.py) | |||
| set(GEN_OP_PARAMS_TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/ops.tpl.py) | |||
| ##################### generate python opr_param_defs.py ############## | |||
| file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) | |||
| file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS) | |||
| file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS}) | |||
| add_custom_command( | |||
| OUTPUT ${GEN_OPS_FILE} | |||
| COMMAND ${CMAKE_COMMAND} -E touch ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} | |||
| COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/${PACKAGE_NAME} ${MEGENGINE_DIR}/${PACKAGE_NAME} | |||
| COMMAND ${CMAKE_COMMAND} -E remove -f ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} | |||
| COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py ${OPR_DECL_SRCS} -o ${GEN_OPS_FILE} | |||
| COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${MEGENGINE_DIR}/${PACKAGE_NAME}/test | |||
| COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py --imperative ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${GEN_OP_PARAMS_FILE} | |||
| DEPENDS ${OPR_DECL_SRCS} ${PYTHON_SRCS} ${ALL_HEADERS} ${GEN_OP_PARAMS_TEMPLATE} | |||
| VERBATIM | |||
| ) | |||
| add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE}) | |||
| ##################### end of opdef generation ######################### | |||
| add_subdirectory(tablegen) | |||
| add_custom_target(_version_ld SOURCES ${MGE_VERSION_SCRIPT}) | |||
| @@ -73,7 +46,7 @@ else() | |||
| endif() | |||
| endif() | |||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||
| target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | |||
| target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | |||
| if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
| @@ -87,7 +60,7 @@ if (APPLE OR MSVC OR WIN32) | |||
| message(VERBOSE "overwriting SUFFIX at macos and windows before config by set_target_properties") | |||
| pybind11_extension(${MODULE_NAME}) | |||
| endif() | |||
| add_dependencies(${MODULE_NAME} gen_opr_py _version_ld) | |||
| add_dependencies(${MODULE_NAME} mgb_opdef _version_ld) | |||
| if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | |||
| add_subdirectory(test) | |||
| @@ -19,7 +19,6 @@ from ..ops.builtin import ( | |||
| IndexingMultiAxisVec, | |||
| IndexingSetMultiAxisVec, | |||
| OpDef, | |||
| OprAttr, | |||
| Reduce, | |||
| Reshape, | |||
| SetSubtensor, | |||
| @@ -31,8 +30,6 @@ from ..tensor.function import Function | |||
| from ..tensor.tensor import Tensor | |||
| from ..tensor.tensor_wrapper import TensorWrapper | |||
| _reduce_sum_param = Reduce(mode="SUM").to_c().param[0] | |||
| @functools.singledispatch | |||
| def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): | |||
| @@ -41,17 +38,18 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): | |||
| @builtin_op_get_backward_fn.register(OpDef) | |||
| def _(op: OpDef, inputs, outputs, input_requires_grad): | |||
| if isinstance(op, OprAttr): | |||
| grad_fn = _oprAttr_grad_fn.get(op.type, None) | |||
| if grad_fn is None: | |||
| if op.type == Reduce.name and op.param[0] == _reduce_sum_param: | |||
| grad_fn = reduce_sum_grad_fn | |||
| else: | |||
| grad_fn = default_grad_fn | |||
| if isinstance(op, Reshape): | |||
| grad_fn = reshape_grad_fn | |||
| elif isinstance(op, Subtensor): | |||
| grad_fn = subtensor_grad_fn | |||
| elif isinstance(op, IndexingMultiAxisVec): | |||
| grad_fn = indexingMultiAxisVec_grad_fn | |||
| elif isinstance(op, Broadcast) or ( | |||
| isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | |||
| ): | |||
| grad_fn = elemwise_add_grad_fn | |||
| elif isinstance(op, Reduce) and op.mode.name == "SUM": | |||
| grad_fn = reduce_sum_grad_fn | |||
| else: | |||
| grad_fn = default_grad_fn | |||
| return grad_fn(op, inputs, outputs, input_requires_grad) | |||
| @@ -152,9 +150,7 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad): | |||
| # override for Subtensor | |||
| def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): | |||
| grad_op = OprAttr() | |||
| grad_op.type = SetSubtensor.name | |||
| grad_op.param = op.param | |||
| grad_op = SetSubtensor(op.items) | |||
| input_shape = get_shape(inputs[0]) | |||
| params = inputs[1:] | |||
| @@ -175,9 +171,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): | |||
| # override for IndexingMultiAxisVec | |||
| def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad): | |||
| grad_op = OprAttr() | |||
| grad_op.type = IndexingSetMultiAxisVec.name | |||
| grad_op.param = op.param | |||
| grad_op = IndexingSetMultiAxisVec(op.items) | |||
| input_shape = get_shape(inputs[0]) | |||
| params = inputs[1:] | |||
| @@ -209,10 +203,3 @@ def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad): | |||
| return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,) | |||
| return backward, [True] | |||
| _oprAttr_grad_fn = { | |||
| Reshape.name: reshape_grad_fn, | |||
| Subtensor.name: subtensor_grad_fn, | |||
| IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn, | |||
| } | |||
| @@ -1,8 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -1,10 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .generated_ops import * | |||
| from .misc_ops import * | |||
| @@ -1,939 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import sys | |||
| from functools import reduce | |||
| from operator import or_ as _or_ | |||
| from types import DynamicClassAttribute, MappingProxyType | |||
| # try _collections first to reduce startup cost | |||
| try: | |||
| from _collections import OrderedDict | |||
| except ImportError: | |||
| from collections import OrderedDict | |||
| __all__ = [ | |||
| "EnumMeta", | |||
| "Enum", | |||
| "IntEnum", | |||
| "Flag", | |||
| "IntFlag", | |||
| "auto", | |||
| "unique", | |||
| ] | |||
| def _is_descriptor(obj): | |||
| """Returns True if obj is a descriptor, False otherwise.""" | |||
| return ( | |||
| hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") | |||
| ) | |||
| def _is_dunder(name): | |||
| """Returns True if a __dunder__ name, False otherwise.""" | |||
| return ( | |||
| name[:2] == name[-2:] == "__" | |||
| and name[2:3] != "_" | |||
| and name[-3:-2] != "_" | |||
| and len(name) > 4 | |||
| ) | |||
| def _is_sunder(name): | |||
| """Returns True if a _sunder_ name, False otherwise.""" | |||
| return ( | |||
| name[0] == name[-1] == "_" | |||
| and name[1:2] != "_" | |||
| and name[-2:-1] != "_" | |||
| and len(name) > 2 | |||
| ) | |||
| def _make_class_unpicklable(cls): | |||
| """Make the given class un-picklable.""" | |||
| def _break_on_call_reduce(self, proto): | |||
| raise TypeError("%r cannot be pickled" % self) | |||
| cls.__reduce_ex__ = _break_on_call_reduce | |||
| cls.__module__ = "<unknown>" | |||
| _auto_null = object() | |||
| class auto: | |||
| """ | |||
| Instances are replaced with an appropriate value in Enum class suites. | |||
| """ | |||
| value = _auto_null | |||
| class _EnumDict(dict): | |||
| """ | |||
| Track enum member order and ensure member names are not reused. | |||
| EnumMeta will use the names found in self._member_names as the | |||
| enumeration member names. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self._member_names = [] | |||
| self._last_values = [] | |||
| def __setitem__(self, key, value): | |||
| """ | |||
| Changes anything not dundered or not a descriptor. | |||
| If an enum member name is used twice, an error is raised; duplicate | |||
| values are not checked for. | |||
| Single underscore (sunder) names are reserved. | |||
| """ | |||
| if _is_sunder(key): | |||
| if key not in ( | |||
| "_order_", | |||
| "_create_pseudo_member_", | |||
| "_generate_next_value_", | |||
| "_missing_", | |||
| ): | |||
| raise ValueError("_names_ are reserved for future Enum use") | |||
| if key == "_generate_next_value_": | |||
| setattr(self, "_generate_next_value", value) | |||
| elif _is_dunder(key): | |||
| if key == "__order__": | |||
| key = "_order_" | |||
| elif key in self._member_names: | |||
| # descriptor overwriting an enum? | |||
| raise TypeError("Attempted to reuse key: %r" % key) | |||
| elif not _is_descriptor(value): | |||
| if key in self: | |||
| # enum overwriting a descriptor? | |||
| raise TypeError("%r already defined as: %r" % (key, self[key])) | |||
| if isinstance(value, auto): | |||
| if value.value == _auto_null: | |||
| value.value = self._generate_next_value( | |||
| key, 1, len(self._member_names), self._last_values[:] | |||
| ) | |||
| value = value.value | |||
| self._member_names.append(key) | |||
| self._last_values.append(value) | |||
| super().__setitem__(key, value) | |||
| # Dummy value for Enum as EnumMeta explicitly checks for it, but of course | |||
| # until EnumMeta finishes running the first time the Enum class doesn't exist. | |||
| # This is also why there are checks in EnumMeta like `if Enum is not None` | |||
| Enum = None | |||
| class EnumMeta(type): | |||
| """Metaclass for Enum""" | |||
| @classmethod | |||
| def __prepare__(metacls, cls, bases): | |||
| # create the namespace dict | |||
| enum_dict = _EnumDict() | |||
| # inherit previous flags and _generate_next_value_ function | |||
| member_type, first_enum = metacls._get_mixins_(bases) | |||
| if first_enum is not None: | |||
| enum_dict["_generate_next_value_"] = getattr( | |||
| first_enum, "_generate_next_value_", None | |||
| ) | |||
| return enum_dict | |||
| def __new__(metacls, cls, bases, classdict): | |||
| # an Enum class is final once enumeration items have been defined; it | |||
| # cannot be mixed with other types (int, float, etc.) if it has an | |||
| # inherited __new__ unless a new __new__ is defined (or the resulting | |||
| # class will fail). | |||
| member_type, first_enum = metacls._get_mixins_(bases) | |||
| __new__, save_new, use_args = metacls._find_new_( | |||
| classdict, member_type, first_enum | |||
| ) | |||
| # save enum items into separate mapping so they don't get baked into | |||
| # the new class | |||
| enum_members = {k: classdict[k] for k in classdict._member_names} | |||
| for name in classdict._member_names: | |||
| del classdict[name] | |||
| # adjust the sunders | |||
| _order_ = classdict.pop("_order_", None) | |||
| # check for illegal enum names (any others?) | |||
| invalid_names = set(enum_members) & { | |||
| "mro", | |||
| } | |||
| if invalid_names: | |||
| raise ValueError( | |||
| "Invalid enum member name: {0}".format(",".join(invalid_names)) | |||
| ) | |||
| # create a default docstring if one has not been provided | |||
| if "__doc__" not in classdict: | |||
| classdict["__doc__"] = "An enumeration." | |||
| # create our new Enum type | |||
| enum_class = super().__new__(metacls, cls, bases, classdict) | |||
| enum_class._member_names_ = [] # names in definition order | |||
| enum_class._member_map_ = OrderedDict() # name->value map | |||
| enum_class._member_type_ = member_type | |||
| # save attributes from super classes so we know if we can take | |||
| # the shortcut of storing members in the class dict | |||
| base_attributes = {a for b in enum_class.mro() for a in b.__dict__} | |||
| # Reverse value->name map for hashable values. | |||
| enum_class._value2member_map_ = {} | |||
| # If a custom type is mixed into the Enum, and it does not know how | |||
| # to pickle itself, pickle.dumps will succeed but pickle.loads will | |||
| # fail. Rather than have the error show up later and possibly far | |||
| # from the source, sabotage the pickle protocol for this class so | |||
| # that pickle.dumps also fails. | |||
| # | |||
| # However, if the new class implements its own __reduce_ex__, do not | |||
| # sabotage -- it's on them to make sure it works correctly. We use | |||
| # __reduce_ex__ instead of any of the others as it is preferred by | |||
| # pickle over __reduce__, and it handles all pickle protocols. | |||
| if "__reduce_ex__" not in classdict: | |||
| if member_type is not object: | |||
| methods = ( | |||
| "__getnewargs_ex__", | |||
| "__getnewargs__", | |||
| "__reduce_ex__", | |||
| "__reduce__", | |||
| ) | |||
| if not any(m in member_type.__dict__ for m in methods): | |||
| _make_class_unpicklable(enum_class) | |||
| # instantiate them, checking for duplicates as we go | |||
| # we instantiate first instead of checking for duplicates first in case | |||
| # a custom __new__ is doing something funky with the values -- such as | |||
| # auto-numbering ;) | |||
| for member_name in classdict._member_names: | |||
| value = enum_members[member_name] | |||
| if not isinstance(value, tuple): | |||
| args = (value,) | |||
| else: | |||
| args = value | |||
| if member_type is tuple: # special case for tuple enums | |||
| args = (args,) # wrap it one more time | |||
| if not use_args: | |||
| enum_member = __new__(enum_class) | |||
| if not hasattr(enum_member, "_value_"): | |||
| enum_member._value_ = value | |||
| else: | |||
| enum_member = __new__(enum_class, *args) | |||
| if not hasattr(enum_member, "_value_"): | |||
| if member_type is object: | |||
| enum_member._value_ = value | |||
| else: | |||
| enum_member._value_ = member_type(*args) | |||
| value = enum_member._value_ | |||
| enum_member._name_ = member_name | |||
| enum_member.__objclass__ = enum_class | |||
| enum_member.__init__(*args) | |||
| # If another member with the same value was already defined, the | |||
| # new member becomes an alias to the existing one. | |||
| for name, canonical_member in enum_class._member_map_.items(): | |||
| if canonical_member._value_ == enum_member._value_: | |||
| enum_member = canonical_member | |||
| break | |||
| else: | |||
| # Aliases don't appear in member names (only in __members__). | |||
| enum_class._member_names_.append(member_name) | |||
| # performance boost for any member that would not shadow | |||
| # a DynamicClassAttribute | |||
| if member_name not in base_attributes: | |||
| setattr(enum_class, member_name, enum_member) | |||
| # now add to _member_map_ | |||
| enum_class._member_map_[member_name] = enum_member | |||
| try: | |||
| # This may fail if value is not hashable. We can't add the value | |||
| # to the map, and by-value lookups for this value will be | |||
| # linear. | |||
| enum_class._value2member_map_[value] = enum_member | |||
| except TypeError: | |||
| pass | |||
| # double check that repr and friends are not the mixin's or various | |||
| # things break (such as pickle) | |||
| for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"): | |||
| class_method = getattr(enum_class, name) | |||
| obj_method = getattr(member_type, name, None) | |||
| enum_method = getattr(first_enum, name, None) | |||
| if obj_method is not None and obj_method is class_method: | |||
| setattr(enum_class, name, enum_method) | |||
| # replace any other __new__ with our own (as long as Enum is not None, | |||
| # anyway) -- again, this is to support pickle | |||
| if Enum is not None: | |||
| # if the user defined their own __new__, save it before it gets | |||
| # clobbered in case they subclass later | |||
| if save_new: | |||
| enum_class.__new_member__ = __new__ | |||
| enum_class.__new__ = Enum.__new__ | |||
| # py3 support for definition order (helps keep py2/py3 code in sync) | |||
| if _order_ is not None: | |||
| if isinstance(_order_, str): | |||
| _order_ = _order_.replace(",", " ").split() | |||
| if _order_ != enum_class._member_names_: | |||
| raise TypeError("member order does not match _order_") | |||
| return enum_class | |||
| def __bool__(self): | |||
| """ | |||
| classes/types should always be True. | |||
| """ | |||
| return True | |||
| def __call__( | |||
| cls, value, names=None, *, module=None, qualname=None, type=None, start=1 | |||
| ): | |||
| """ | |||
| Either returns an existing member, or creates a new enum class. | |||
| This method is used both when an enum class is given a value to match | |||
| to an enumeration member (i.e. Color(3)) and for the functional API | |||
| (i.e. Color = Enum('Color', names='RED GREEN BLUE')). | |||
| When used for the functional API: | |||
| `value` will be the name of the new class. | |||
| `names` should be either a string of white-space/comma delimited names | |||
| (values will start at `start`), or an iterator/mapping of name, value pairs. | |||
| `module` should be set to the module this class is being created in; | |||
| if it is not set, an attempt to find that module will be made, but if | |||
| it fails the class will not be picklable. | |||
| `qualname` should be set to the actual location this class can be found | |||
| at in its module; by default it is set to the global scope. If this is | |||
| not correct, unpickling will fail in some circumstances. | |||
| `type`, if set, will be mixed in as the first base class. | |||
| """ | |||
| if names is None: # simple value lookup | |||
| return cls.__new__(cls, value) | |||
| # otherwise, functional API: we're creating a new Enum type | |||
| return cls._create_( | |||
| value, names, module=module, qualname=qualname, type=type, start=start | |||
| ) | |||
| def __contains__(cls, member): | |||
| return isinstance(member, cls) and member._name_ in cls._member_map_ | |||
| def __delattr__(cls, attr): | |||
| # nicer error message when someone tries to delete an attribute | |||
| # (see issue19025). | |||
| if attr in cls._member_map_: | |||
| raise AttributeError("%s: cannot delete Enum member." % cls.__name__) | |||
| super().__delattr__(attr) | |||
| def __dir__(self): | |||
| return [ | |||
| "__class__", | |||
| "__doc__", | |||
| "__members__", | |||
| "__module__", | |||
| ] + self._member_names_ | |||
| def __getattr__(cls, name): | |||
| """ | |||
| Return the enum member matching `name` | |||
| We use __getattr__ instead of descriptors or inserting into the enum | |||
| class' __dict__ in order to support `name` and `value` being both | |||
| properties for enum members (which live in the class' __dict__) and | |||
| enum members themselves. | |||
| """ | |||
| if _is_dunder(name): | |||
| raise AttributeError(name) | |||
| try: | |||
| return cls._member_map_[name] | |||
| except KeyError: | |||
| raise AttributeError(name) from None | |||
| def __getitem__(cls, name): | |||
| return cls._member_map_[name] | |||
| def __iter__(cls): | |||
| return (cls._member_map_[name] for name in cls._member_names_) | |||
| def __len__(cls): | |||
| return len(cls._member_names_) | |||
| @property | |||
| def __members__(cls): | |||
| """ | |||
| Returns a mapping of member name->value. | |||
| This mapping lists all enum members, including aliases. Note that this | |||
| is a read-only view of the internal mapping. | |||
| """ | |||
| return MappingProxyType(cls._member_map_) | |||
| def __repr__(cls): | |||
| return "<enum %r>" % cls.__name__ | |||
| def __reversed__(cls): | |||
| return (cls._member_map_[name] for name in reversed(cls._member_names_)) | |||
| def __setattr__(cls, name, value): | |||
| """ | |||
| Block attempts to reassign Enum members. | |||
| A simple assignment to the class namespace only changes one of the | |||
| several possible ways to get an Enum member from the Enum class, | |||
| resulting in an inconsistent Enumeration. | |||
| """ | |||
| member_map = cls.__dict__.get("_member_map_", {}) | |||
| if name in member_map: | |||
| raise AttributeError("Cannot reassign members.") | |||
| super().__setattr__(name, value) | |||
| def _create_( | |||
| cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1 | |||
| ): | |||
| """ | |||
| Convenience method to create a new Enum class. | |||
| `names` can be: | |||
| * A string containing member names, separated either with spaces or | |||
| commas. Values are incremented by 1 from `start`. | |||
| * An iterable of member names. Values are incremented by 1 from `start`. | |||
| * An iterable of (member name, value) pairs. | |||
| * A mapping of member name -> value pairs. | |||
| """ | |||
| metacls = cls.__class__ | |||
| bases = (cls,) if type is None else (type, cls) | |||
| _, first_enum = cls._get_mixins_(bases) | |||
| classdict = metacls.__prepare__(class_name, bases) | |||
| # special processing needed for names? | |||
| if isinstance(names, str): | |||
| names = names.replace(",", " ").split() | |||
| if isinstance(names, (tuple, list)) and names and isinstance(names[0], str): | |||
| original_names, names = names, [] | |||
| last_values = [] | |||
| for count, name in enumerate(original_names): | |||
| value = first_enum._generate_next_value_( | |||
| name, start, count, last_values[:] | |||
| ) | |||
| last_values.append(value) | |||
| names.append((name, value)) | |||
| # Here, names is either an iterable of (name, value) or a mapping. | |||
| for item in names: | |||
| if isinstance(item, str): | |||
| member_name, member_value = item, names[item] | |||
| else: | |||
| member_name, member_value = item | |||
| classdict[member_name] = member_value | |||
| enum_class = metacls.__new__(metacls, class_name, bases, classdict) | |||
| # TODO: replace the frame hack if a blessed way to know the calling | |||
| # module is ever developed | |||
| if module is None: | |||
| try: | |||
| module = sys._getframe(2).f_globals["__name__"] | |||
| except (AttributeError, ValueError) as exc: | |||
| pass | |||
| if module is None: | |||
| _make_class_unpicklable(enum_class) | |||
| else: | |||
| enum_class.__module__ = module | |||
| if qualname is not None: | |||
| enum_class.__qualname__ = qualname | |||
| return enum_class | |||
    @staticmethod
    def _get_mixins_(bases):
        """
        Returns the type for creating enum members, and the first inherited
        enum class.

        bases: the tuple of bases that was given to __new__
        """
        if not bases:
            return object, Enum
        # double check that we are not subclassing a class with existing
        # enumeration members; while we're at it, see if any other data
        # type has been mixed in so we can use the correct __new__
        member_type = first_enum = None
        for base in bases:
            if base is not Enum and issubclass(base, Enum) and base._member_names_:
                raise TypeError("Cannot extend enumerations")
        # base is now the last base in bases
        # (deliberately relies on the loop variable leaking out of the for)
        if not issubclass(base, Enum):
            raise TypeError(
                "new enumerations must be created as "
                "`ClassName([mixin_type,] enum_type)`"
            )
        # get correct mix-in type (either mix-in type of Enum subclass, or
        # first base if last base is Enum)
        if not issubclass(bases[0], Enum):
            member_type = bases[0]  # first data type
            first_enum = bases[-1]  # enum type
        else:
            # walk the MRO of the first base: the first non-Enum class found
            # is the data type, the first Enum subclass is the enum ancestor
            for base in bases[0].__mro__:
                # most common: (IntEnum, int, Enum, object)
                # possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>,
                #            <class 'int'>, <Enum 'Enum'>,
                #            <class 'object'>)
                if issubclass(base, Enum):
                    if first_enum is None:
                        first_enum = base
                else:
                    if member_type is None:
                        member_type = base
        return member_type, first_enum
    @staticmethod
    def _find_new_(classdict, member_type, first_enum):
        """
        Returns the __new__ to be used for creating the enum members.

        classdict: the class dictionary given to __new__
        member_type: the data type whose __new__ will be used by default
        first_enum: enumeration to check for an overriding __new__

        Returns (__new__, save_new, use_args) where save_new is True when the
        user defined __new__ directly in the class body, and use_args tells
        the caller whether member values should be passed to __new__.
        """
        # now find the correct __new__, checking to see if one was defined
        # by the user; also check earlier enum classes in case a __new__ was
        # saved as __new_member__
        __new__ = classdict.get("__new__", None)
        # should __new__ be saved as __new_member__ later?
        save_new = __new__ is not None
        if __new__ is None:
            # check all possibles for __new_member__ before falling back to
            # __new__
            for method in ("__new_member__", "__new__"):
                for possible in (member_type, first_enum):
                    target = getattr(possible, method, None)
                    # skip the default constructors: they carry no
                    # member-specific behavior worth forwarding to
                    if target not in {
                        None,
                        None.__new__,
                        object.__new__,
                        Enum.__new__,
                    }:
                        __new__ = target
                        break
                if __new__ is not None:
                    break
            else:
                __new__ = object.__new__
        # if a non-object.__new__ is used then whatever value/tuple was
        # assigned to the enum member name will be passed to __new__ and to the
        # new enum member's __init__
        if __new__ is object.__new__:
            use_args = False
        else:
            use_args = True
        return __new__, save_new, use_args
class Enum(metaclass=EnumMeta):
    """
    Generic enumeration.

    Derive from this class to define new enumerations.
    """

    def __new__(cls, value):
        # all enum instances are actually created during class construction
        # without calling this method; this method is called by the metaclass'
        # __call__ (i.e. Color(3) ), and by pickle
        if type(value) is cls:
            # For lookups like Color(Color.RED)
            return value
        # by-value search for a matching enum member
        # see if it's in the reverse mapping (for hashable values)
        try:
            if value in cls._value2member_map_:
                return cls._value2member_map_[value]
        except TypeError:
            # unhashable value: do the long search -- O(n) behavior
            for member in cls._member_map_.values():
                if member._value_ == value:
                    return member
        # still not found -- try _missing_ hook
        return cls._missing_(value)

    def _generate_next_value_(name, start, count, last_values):
        # Default auto-value policy: previous value + 1, scanning backwards
        # for the first value that supports addition; falls back to `start`.
        for last_value in reversed(last_values):
            try:
                return last_value + 1
            except TypeError:
                pass
        else:
            return start

    @classmethod
    def _missing_(cls, value):
        # Hook for subclasses (e.g. Flag) to resolve unknown values;
        # the default is to reject them.
        raise ValueError("%r is not a valid %s" % (value, cls.__name__))

    def __repr__(self):
        return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_)

    def __str__(self):
        return "%s.%s" % (self.__class__.__name__, self._name_)

    def __dir__(self):
        # expose only public, non-member attributes plus a minimal dunder set
        added_behavior = [
            m
            for cls in self.__class__.mro()
            for m in cls.__dict__
            if m[0] != "_" and m not in self._member_map_
        ]
        return ["__class__", "__doc__", "__module__"] + added_behavior

    def __format__(self, format_spec):
        # mixed-in Enums should use the mixed-in type's __format__, otherwise
        # we can get strange results with the Enum name showing up instead of
        # the value
        # pure Enum branch
        if self._member_type_ is object:
            cls = str
            val = str(self)
        # mix-in branch
        else:
            cls = self._member_type_
            val = self._value_
        return cls.__format__(val, format_spec)

    def __hash__(self):
        return hash(self._name_)

    def __reduce_ex__(self, proto):
        # pickle by value; calling the class on load resolves the member
        return self.__class__, (self._value_,)

    # DynamicClassAttribute is used to provide access to the `name` and
    # `value` properties of enum members while keeping some measure of
    # protection from modification, while still allowing for an enumeration
    # to have members named `name` and `value`.  This works because enumeration
    # members are not set directly on the enum class -- __getattr__ is
    # used to look them up.
    @DynamicClassAttribute
    def name(self):
        """The name of the Enum member."""
        return self._name_

    @DynamicClassAttribute
    def value(self):
        """The value of the Enum member."""
        return self._value_

    @classmethod
    def _convert(cls, name, module, filter, source=None):
        """
        Create a new Enum subclass that replaces a collection of global constants
        """
        # NOTE: the `filter` parameter shadows the builtin of the same name.
        # convert all constants from source (or module) that pass filter() to
        # a new Enum called name, and export the enum and its members back to
        # module;
        # also, replace the __reduce_ex__ method so unpickling works in
        # previous Python versions
        module_globals = vars(sys.modules[module])
        if source:
            source = vars(source)
        else:
            source = module_globals
        # We use an OrderedDict of sorted source keys so that the
        # _value2member_map is populated in the same order every time
        # for a consistent reverse mapping of number to name when there
        # are multiple names for the same number rather than varying
        # between runs due to hash randomization of the module dictionary.
        members = [(name, source[name]) for name in source.keys() if filter(name)]
        try:
            # sort by value
            members.sort(key=lambda t: (t[1], t[0]))
        except TypeError:
            # unless some values aren't comparable, in which case sort by name
            members.sort(key=lambda t: t[0])
        cls = cls(name, members, module=module)
        cls.__reduce_ex__ = _reduce_ex_by_name
        module_globals.update(cls.__members__)
        module_globals[name] = cls
        return cls
# int mix-in: the `int` base supplies comparison/arithmetic behavior, so
# members can be used wherever a plain int is expected.
class IntEnum(int, Enum):
    """Enum where members are also (and must be) ints"""
| def _reduce_ex_by_name(self, proto): | |||
| return self.name | |||
class Flag(Enum):
    """Support for flags"""

    def _generate_next_value_(name, start, count, last_values):
        """
        Generate the next value when not given.

        name: the name of the member
        start: the initial start value or None
        count: the number of existing members
        last_values: the list of values already assigned
        """
        if not count:
            return start if start is not None else 1
        # next free power of two above the most recent usable value
        for last_value in reversed(last_values):
            try:
                high_bit = _high_bit(last_value)
                break
            except Exception:
                raise TypeError("Invalid Flag value: %r" % last_value) from None
        return 2 ** (high_bit + 1)

    @classmethod
    def _missing_(cls, value):
        # Compose a pseudo-member for unnamed bit combinations; a negative
        # value is resolved by composing its complement and inverting back.
        original_value = value
        if value < 0:
            value = ~value
        possible_member = cls._create_pseudo_member_(value)
        if original_value < 0:
            possible_member = ~possible_member
        return possible_member

    @classmethod
    def _create_pseudo_member_(cls, value):
        """
        Create a composite member iff value contains only members.
        """
        pseudo_member = cls._value2member_map_.get(value, None)
        if pseudo_member is None:
            # verify all bits are accounted for
            _, extra_flags = _decompose(cls, value)
            if extra_flags:
                raise ValueError("%r is not a valid %s" % (value, cls.__name__))
            # construct a singleton enum pseudo-member
            pseudo_member = object.__new__(cls)
            pseudo_member._name_ = None
            pseudo_member._value_ = value
            # use setdefault in case another thread already created a composite
            # with this value
            pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
        return pseudo_member

    def __contains__(self, other):
        # NOTE(review): returns NotImplemented (not a bool) for foreign
        # types, matching the upstream enum of this vintage — confirm before
        # changing, as `in` treats that as truthy.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return other._value_ & self._value_ == other._value_

    def __repr__(self):
        cls = self.__class__
        if self._name_ is not None:
            return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_)
        # unnamed composite: render as a |-joined breakdown
        members, uncovered = _decompose(cls, self._value_)
        return "<%s.%s: %r>" % (
            cls.__name__,
            "|".join([str(m._name_ or m._value_) for m in members]),
            self._value_,
        )

    def __str__(self):
        cls = self.__class__
        if self._name_ is not None:
            return "%s.%s" % (cls.__name__, self._name_)
        members, uncovered = _decompose(cls, self._value_)
        if len(members) == 1 and members[0]._name_ is None:
            return "%s.%r" % (cls.__name__, members[0]._value_)
        else:
            return "%s.%s" % (
                cls.__name__,
                "|".join([str(m._name_ or m._value_) for m in members]),
            )

    def __bool__(self):
        return bool(self._value_)

    def __or__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ | other._value_)

    def __and__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ & other._value_)

    def __xor__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ ^ other._value_)

    def __invert__(self):
        # complement = OR of all members that neither appear in this value's
        # decomposition nor overlap its bits
        members, uncovered = _decompose(self.__class__, self._value_)
        inverted_members = [
            m
            for m in self.__class__
            if m not in members and not m._value_ & self._value_
        ]
        inverted = reduce(_or_, inverted_members, self.__class__(0))
        return self.__class__(inverted)
class IntFlag(int, Flag):
    """Support for integer-based Flags"""

    @classmethod
    def _missing_(cls, value):
        # Unlike Flag, any int (including negatives and bits with no named
        # member) is accepted; non-int values are rejected outright.
        if not isinstance(value, int):
            raise ValueError("%r is not a valid %s" % (value, cls.__name__))
        new_member = cls._create_pseudo_member_(value)
        return new_member

    @classmethod
    def _create_pseudo_member_(cls, value):
        # Build (and cache) a pseudo-member for `value`; single-bit
        # pseudo-members are first created for any bits not covered by
        # existing members so the composite can be decomposed later.
        pseudo_member = cls._value2member_map_.get(value, None)
        if pseudo_member is None:
            need_to_create = [value]
            # get unaccounted-for bits
            _, extra_flags = _decompose(cls, value)
            while extra_flags:
                bit = _high_bit(extra_flags)
                flag_value = 2 ** bit
                if (
                    flag_value not in cls._value2member_map_
                    and flag_value not in need_to_create
                ):
                    need_to_create.append(flag_value)
                if extra_flags == -flag_value:
                    # only the sign-extension pattern remains; stop
                    extra_flags = 0
                else:
                    extra_flags ^= flag_value
            for value in reversed(need_to_create):
                # construct singleton pseudo-members
                pseudo_member = int.__new__(cls, value)
                pseudo_member._name_ = None
                pseudo_member._value_ = value
                # use setdefault in case another thread already created a composite
                # with this value
                pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
        return pseudo_member

    def __or__(self, other):
        # plain ints are accepted and coerced through the flag constructor
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        result = self.__class__(self._value_ | self.__class__(other)._value_)
        return result

    def __and__(self, other):
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        return self.__class__(self._value_ & self.__class__(other)._value_)

    def __xor__(self, other):
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        return self.__class__(self._value_ ^ self.__class__(other)._value_)

    # reflected operators share the same implementations
    __ror__ = __or__
    __rand__ = __and__
    __rxor__ = __xor__

    def __invert__(self):
        result = self.__class__(~self._value_)
        return result
| def _high_bit(value): | |||
| """returns index of highest bit, or -1 if value is zero or negative""" | |||
| return value.bit_length() - 1 | |||
def unique(enumeration):
    """Class decorator for enumerations ensuring unique member values."""
    # an alias is any entry whose dict key differs from the canonical name
    aliases = [
        (entry_name, member.name)
        for entry_name, member in enumeration.__members__.items()
        if entry_name != member.name
    ]
    if aliases:
        alias_details = ", ".join("%s -> %s" % pair for pair in aliases)
        raise ValueError(
            "duplicate values found in %r: %s" % (enumeration, alias_details)
        )
    return enumeration
def _decompose(flag, value):
    """Extract all members from the value.

    Returns (members, not_covered): the flag members whose bits are set in
    `value` (largest value first), and the residue of bits matched by no
    member.
    """
    # _decompose is only called if the value is not named
    not_covered = value
    negative = value < 0
    # issue29167: wrap accesses to _value2member_map_ in a list to avoid race
    # conditions between iterating over it and having more pseudo-
    # members added to it
    if negative:
        # only check for named flags
        flags_to_check = [
            (m, v)
            for v, m in list(flag._value2member_map_.items())
            if m.name is not None
        ]
    else:
        # check for named flags and powers-of-two flags
        flags_to_check = [
            (m, v)
            for v, m in list(flag._value2member_map_.items())
            if m.name is not None or _power_of_two(v)
        ]
    members = []
    for member, member_value in flags_to_check:
        # member contributes iff all of its (non-zero) bits are set in value
        if member_value and member_value & value == member_value:
            members.append(member)
            not_covered &= ~member_value
    if not members and value in flag._value2member_map_:
        members.append(flag._value2member_map_[value])
    # largest flag first for a stable, readable breakdown
    members.sort(key=lambda m: m._value_, reverse=True)
    if len(members) > 1 and members[0].value == value:
        # we have the breakdown, don't need the value member itself
        members.pop(0)
    return members, not_covered
def _power_of_two(value):
    # A positive power of two equals 2 raised to the index of its own
    # highest set bit; anything below 1 is rejected up front.
    if value >= 1:
        return value == 2 ** _high_bit(value)
    return False
| @@ -1,94 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import warnings | |||
| from ..._imperative_rt.ops import OprAttr | |||
| from . import param_defs | |||
def make_param(param, ptype, kwargs):
    """Coerce `param` (or entries popped from `kwargs`) into a `ptype`.

    When `param` is given it is either returned as-is (already a `ptype`)
    or treated as the positional value(s) for `ptype`'s slots.  Otherwise
    the slot values are popped out of `kwargs` (mutating it) and passed as
    keyword arguments.
    """
    if param is not None:
        if isinstance(param, ptype):
            return param
        values = [param]
        assert len(values) == len(
            ptype.__slots__
        ), "{} needs {} params, but {} are provided".format(
            ptype, len(ptype.__slots__), len(values)
        )
        return ptype(*values)
    _missing = object()  # sentinel: distinguishes "absent" from any real value
    collected = {}
    for field in ptype.__slots__:
        candidate = kwargs.pop(field, _missing)
        if candidate is not _missing:
            collected[field] = candidate
    return ptype(**collected)
class PodOpVisitor:
    """
    Base class for operators described by plain-old-data MegBrain params.

    Subclasses set `name` (the serialized op type string) and `param_names`
    (the attribute names serialized, in that order, by to_c()).
    """

    # registry: op name -> most recently defined subclass handling it
    __name2subclass = {}
    # per-instance cache for the OprAttr built by to_c()
    __c = None

    name = None
    param_names = []
    config = None

    def __init__(self, config, **params):
        self.config = config
        # every declared param must be provided, and nothing extra
        assert set(params) == set(self.param_names)
        self.__dict__.update(params)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)  # python 3.5 does not have this
        name = cls.name
        if name in cls.__name2subclass:
            if not issubclass(cls, cls.__name2subclass[name]):
                # fix: corrected typo "bultin" -> "builtin" in the warning text
                warnings.warn("Multiple subclasses for builtin op: %s" % name)
        cls.__name2subclass[name] = cls

    def to_c(self):
        """Build (and cache) the OprAttr describing this op."""
        # NOTE(review): cache check relies on a populated OprAttr being
        # truthy — confirm OprAttr defines no falsy __bool__.
        if self.__c:
            return self.__c
        op = OprAttr()
        op.type = self.name
        if self.config is not None:
            op.config = self.config
        # first 4 bytes is TAG, has to remove them currently
        op.param = b"".join(self.__dict__[k].serialize()[4:] for k in self.param_names)
        self.__c = op
        return op

    def __eq__(self, rhs):
        # equality is defined by the serialized C representation
        return self.to_c() == rhs.to_c()

    def __repr__(self):
        name = self.__class__.__name__
        if self.__c:
            # already serialized; params are no longer introspectable
            return "{}(<binary data>)".format(name)
        kwargs = {}
        for i in self.param_names:
            p = self.__dict__[i]
            if isinstance(p, param_defs._ParamDefBase):
                # flatten param-struct fields, rendering enum fields by name
                for k in p.__slots__:
                    v = getattr(p, k)
                    if isinstance(v, param_defs._EnumBase):
                        v = v.name
                    kwargs[k] = repr(v)
            else:
                kwargs[i] = repr(p)
        if self.config:
            if len(self.config.comp_node_arr) == 1:
                kwargs["device"] = "'%s'" % self.config.comp_node
        return "{}({})".format(
            name, ", ".join("{}={}".format(k, v) for k, v in kwargs.items())
        )
| @@ -1,194 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import ctypes | |||
| from ..._imperative_rt import OperatorNodeConfig as Config | |||
| from . import param_defs | |||
| from .helper import PodOpVisitor, make_param | |||
| __all__ = ["ConvolutionBackwardData", "Dimshuffle", "Reshape", "AxisAddRemove"] | |||
class TensorShape:
    # Sizes the fixed-length ctypes arrays below.
    # NOTE(review): presumably mirrors megbrain's TensorShape::MAX_NDIM on
    # the C++ side — confirm before changing.
    MAX_NDIM = 7
class ConvolutionBackwardData(PodOpVisitor):
    # Serialization order used by PodOpVisitor.to_c().
    # NOTE(review): "execution_polity" is a misspelling of "execution_policy"
    # but is part of the attribute/kwarg interface — do not rename silently.
    param_names = (
        "param",
        "execution_polity",
    )
    name = "ConvolutionBackwardDataV1"

    def __init__(
        self,
        *,
        param=None,
        execution_polity=None,
        name=None,
        comp_node=None,
        config=None,
        dtype=None,
        **kwargs
    ):
        """Build the op; leftover **kwargs feed the param structs via
        make_param and must be fully consumed."""
        config = config or Config()
        if name:
            config.name = name
        if comp_node:
            config.comp_node = comp_node
        if dtype:
            config.dtype = dtype
        self.config = config
        self.param = make_param(param, param_defs.Convolution, kwargs)
        self.execution_polity = make_param(
            execution_polity, param_defs.ExecutionPolicy, kwargs
        )
        # anything not consumed by the param structs is a caller error
        assert not kwargs, "extra kwargs: {}".format(kwargs)
class Dimshuffle(PodOpVisitor):
    name = "Dimshuffle"
    param_names = ("pattern",)

    class Pattern(ctypes.Structure):
        # fixed-size pattern buffer; length records how many slots are used
        Pattern_Array = ctypes.c_int32 * TensorShape.MAX_NDIM
        _fields_ = [
            ("length", ctypes.c_uint32),
            ("pattern", Pattern_Array),
            ("ndim", ctypes.c_uint32),
        ]

        def serialize(self):
            # prepend a zero TAG word (stripped again by PodOpVisitor.to_c)
            return bytes(ctypes.c_uint32(0)) + bytes(self)

    def __init__(self, pattern, ndim=0):
        # `pattern` entries are axis indices; the string "x" is encoded as -1
        # (presumably marking a broadcast axis — confirm against C++ op)
        assert isinstance(pattern, collections.abc.Iterable)
        assert len(pattern) <= TensorShape.MAX_NDIM
        pattern_array = Dimshuffle.Pattern.Pattern_Array()
        for idx, v in enumerate(pattern):
            pattern_array[idx] = ctypes.c_int32(-1 if v == "x" else int(v))
        self.pattern = Dimshuffle.Pattern(len(pattern), pattern_array, ndim)
class Reshape(PodOpVisitor):
    name = "ReshapeV1"
    param_names = ("unspec_axis",)

    def __init__(self, unspec_axis=None):
        # None -> empty OptionalAxisV1 (no unspecified axis); otherwise the
        # axis index is forwarded to the param struct
        if unspec_axis is None:
            self.unspec_axis = param_defs.OptionalAxisV1()
        else:
            self.unspec_axis = param_defs.OptionalAxisV1(unspec_axis)
class AxisNum(ctypes.Structure):
    # single-int POD; NOTE(review): presumably mirrors the C++ AxisNum
    # layout — confirm field width against megbrain headers.
    _fields_ = [
        ("m_num", ctypes.c_int),
    ]
class AxisDesc(ctypes.Structure):
    class Method(ctypes.c_int):
        # operation kind for one descriptor entry
        ADD_1 = 0
        REMOVE = 1

    _fields_ = [
        ("method", Method),
        ("axis", AxisNum),
    ]

    @classmethod
    def make_add(cls, axis):
        """Descriptor that inserts a size-1 axis at `axis`."""
        return cls(cls.Method.ADD_1, AxisNum(axis))

    @classmethod
    def make_remove(cls, axis):
        """Descriptor that removes the axis at `axis`."""
        return cls(cls.Method.REMOVE, AxisNum(axis))
class AxisAddRemove(PodOpVisitor):
    name = "AxisAddRemove"
    param_names = ("param",)
    # re-export so callers can write AxisAddRemove.AxisDesc
    AxisDesc = AxisDesc

    class Param(ctypes.Structure):
        # descriptor buffer sized MAX_NDIM * 2 — TODO confirm the intended
        # bound (room for paired add/remove entries?) against the C++ side
        MAX_DESC_SIZE = TensorShape.MAX_NDIM * 2
        _fields_ = [("nr_desc", ctypes.c_uint32), ("desc", AxisDesc * MAX_DESC_SIZE)]

        def __init__(self, *args):
            super().__init__()
            self.nr_desc = len(args)
            for i, a in enumerate(args):
                self.desc[i] = a

        def serialize(self):
            # prepend a zero TAG word (stripped again by PodOpVisitor.to_c)
            return bytes(ctypes.c_uint32(0)) + bytes(self)

    def __init__(self, param):
        assert isinstance(param, self.Param)
        self.param = param


# the module-level AxisDesc was only needed while defining the classes above
del AxisDesc
class IndexingOpBase(PodOpVisitor):
    # Shared implementation for the subtensor/indexing ops; concrete
    # subclasses are generated by _gen_indexing_defs below and differ
    # only in `name`.
    param_names = ("index_desc",)

    class IndexDescMaskDump(ctypes.Structure):
        class Item(ctypes.Structure):
            # one indexing entry: target axis plus presence flags for the
            # begin/end/step/idx components
            _fields_ = [
                ("axis", ctypes.c_int8),
                ("begin", ctypes.c_bool),
                ("end", ctypes.c_bool),
                ("step", ctypes.c_bool),
                ("idx", ctypes.c_bool),
            ]

        Item_Array = Item * TensorShape.MAX_NDIM
        _fields_ = [("nr_item", ctypes.c_uint8), ("items", Item_Array)]

        def serialize(self):
            # prepend a zero TAG word (stripped again by PodOpVisitor.to_c)
            return bytes(ctypes.c_uint32(0)) + bytes(self)

    def __init__(self, items):
        # items: sequence of 5-tuples matching Item's fields
        nr_item = len(items)
        assert nr_item <= TensorShape.MAX_NDIM
        item_array = IndexingOpBase.IndexDescMaskDump.Item_Array()
        for idx, item in enumerate(items):
            assert isinstance(item, (tuple, list)) and len(item) == 5
            item_array[idx] = IndexingOpBase.IndexDescMaskDump.Item(*item)
        self.index_desc = IndexingOpBase.IndexDescMaskDump(nr_item, item_array)
def _gen_indexing_defs(*names):
    # Stamp out one IndexingOpBase subclass per op name, bind it in the
    # module namespace, and export it via __all__.
    for name in names:
        globals()[name] = type(name, (IndexingOpBase,), dict(name=name))
        __all__.append(name)


# generate the concrete indexing op classes
_gen_indexing_defs(
    "Subtensor",
    "SetSubtensor",
    "IncrSubtensor",
    "IndexingMultiAxisVec",
    "IndexingSetMultiAxisVec",
    "IndexingIncrMultiAxisVec",
    "MeshIndexing",
    "IncrMeshIndexing",
    "SetMeshIndexing",
    "BatchedMeshIndexing",
    "BatchedIncrMeshIndexing",
    "BatchedSetMeshIndexing",
)
| @@ -11,25 +11,12 @@ from typing import Union | |||
| from ..._imperative_rt import OpDef, ops | |||
| from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
| from .._internal import all_ops | |||
| from .._internal.helper import PodOpVisitor | |||
| # register OpDef as a "virtual subclass" of OpBase, so any of registered | |||
| # apply(OpBase, ...) rules could work well on OpDef | |||
| OpBase.register(OpDef) | |||
| # forward to apply(OpDef, ...) | |||
| @apply.register() | |||
| def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): | |||
| return apply(op.to_c(), *args) | |||
| __all__ = ["OpDef", "PodOpVisitor"] | |||
| for k, v in all_ops.__dict__.items(): | |||
| if isinstance(v, type) and issubclass(v, PodOpVisitor): | |||
| globals()[k] = v | |||
| __all__.append(k) | |||
| __all__ = ["OpDef"] | |||
| for k, v in ops.__dict__.items(): | |||
| if isinstance(v, type) and issubclass(v, OpDef): | |||
| @@ -90,7 +90,7 @@ def _reshape(x, shape): | |||
| if unspec_axis is None: | |||
| op = builtin.Reshape() | |||
| else: | |||
| op = builtin.Reshape(unspec_axis=unspec_axis) | |||
| op = builtin.Reshape(axis=unspec_axis) | |||
| (x,) = apply(op, x, shape) | |||
| return x | |||
| @@ -144,8 +144,6 @@ def _logical_binary_elwise(mode, rev=False): | |||
| def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
| Param = builtin.AxisAddRemove.Param | |||
| def get_axes(): | |||
| if axis is None: | |||
| return [i for i, s in enumerate(inp.shape) if s == 1] | |||
| @@ -159,8 +157,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
| axis = sorted(i + inp.ndim if i < 0 else i for i in axis) | |||
| axis = [a - i for i, a in enumerate(axis)] | |||
| param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis)) | |||
| op = builtin.AxisAddRemove(param=param) | |||
| op = builtin.RemoveAxis(axis=axis) | |||
| (result,) = apply(op, inp) | |||
| if len(axis) == inp.ndim: | |||
| setscalar(result) | |||
| @@ -134,7 +134,7 @@ def astype(x, dtype): | |||
| dtype = np.dtype(dtype) | |||
| if not is_equal(x.dtype, dtype): | |||
| isscalar = x.__wrapped__._data._isscalar | |||
| (x,) = apply(builtin.TypeCvt(param=dtype), x) | |||
| (x,) = apply(builtin.TypeCvt(dtype=dtype), x) | |||
| x.__wrapped__._data._isscalar = isscalar | |||
| return x | |||
| @@ -8,7 +8,6 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Optional, Tuple | |||
| from ..core._imperative_rt.ops import CollectiveCommMode | |||
| from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||
| from ..core.autodiff.grad import ( | |||
| Tracer, | |||
| @@ -110,17 +109,20 @@ def collective_comm(inp, mode, group, device): | |||
| assert isinstance(group, Group) | |||
| if group is None: | |||
| return inp | |||
| op = CollectiveComm() | |||
| op.key = group.key | |||
| op.nr_devices = group.size | |||
| op.rank = group.rank | |||
| op.is_root = op.rank == 0 | |||
| op.local_grad = False | |||
| op.addr, op.port = get_mm_server_addr() | |||
| op.mode = mode | |||
| op.dtype = inp.dtype | |||
| op.backend = get_backend() | |||
| op.comp_node = device | |||
| addr, port = get_mm_server_addr() | |||
| op = CollectiveComm( | |||
| key=group.key, | |||
| nr_devices=group.size, | |||
| rank=group.rank, | |||
| is_root=(group.rank == 0), | |||
| local_grad=False, | |||
| addr=addr, | |||
| port=port, | |||
| mode=mode, | |||
| dtype=inp.dtype, | |||
| backend=get_backend(), | |||
| comp_node=device, | |||
| ) | |||
| return apply(op, inp)[0] | |||
| @@ -134,7 +136,7 @@ def reduce_sum( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.REDUCE_SUM | |||
| mode = CollectiveComm.Mode.REDUCE_SUM | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -148,7 +150,7 @@ def broadcast( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.BROADCAST | |||
| mode = CollectiveComm.Mode.BROADCAST | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -162,7 +164,7 @@ def all_gather( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.ALL_GATHER | |||
| mode = CollectiveComm.Mode.ALL_GATHER | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -176,7 +178,7 @@ def reduce_scatter_sum( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.REDUCE_SCATTER_SUM | |||
| mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -190,7 +192,7 @@ def all_reduce_sum( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.ALL_REDUCE_SUM | |||
| mode = CollectiveComm.Mode.ALL_REDUCE_SUM | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -204,7 +206,7 @@ def all_reduce_max( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.ALL_REDUCE_MAX | |||
| mode = CollectiveComm.Mode.ALL_REDUCE_MAX | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -218,7 +220,7 @@ def all_reduce_min( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.ALL_REDUCE_MIN | |||
| mode = CollectiveComm.Mode.ALL_REDUCE_MIN | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -232,7 +234,7 @@ def gather( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.GATHER | |||
| mode = CollectiveComm.Mode.GATHER | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -246,7 +248,7 @@ def scatter( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.SCATTER | |||
| mode = CollectiveComm.Mode.SCATTER | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -260,7 +262,7 @@ def all_to_all( | |||
| :param group: communication group. | |||
| :param device: execution device. | |||
| """ | |||
| mode = CollectiveCommMode.ALL_TO_ALL | |||
| mode = CollectiveComm.Mode.ALL_TO_ALL | |||
| return collective_comm(inp, mode, group, device) | |||
| @@ -73,27 +73,7 @@ __all__ = [ | |||
| ] | |||
| class _ElemwiseMode(Elemwise.Mode): | |||
| @classmethod | |||
| def __normalize(cls, val): | |||
| if isinstance(val, str): | |||
| if not hasattr(cls, "__member_upper_dict__"): | |||
| cls.__member_upper_dict__ = { | |||
| k.upper(): v for k, v in cls.__members__.items() | |||
| } | |||
| val = cls.__member_upper_dict__.get(val.upper(), val) | |||
| return val | |||
| @classmethod | |||
| def convert(cls, val): | |||
| val = cls.__normalize(val) | |||
| if isinstance(val, cls): | |||
| return val | |||
| return cls(val) | |||
| def _elwise(*args, mode): | |||
| mode = _ElemwiseMode.convert(mode) | |||
| op = builtin.Elemwise(mode) | |||
| tensor_args = list( | |||
| filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | |||
| @@ -13,7 +13,6 @@ import numbers | |||
| from typing import Optional, Sequence, Tuple, Union | |||
| from ..core.ops import builtin | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import utils | |||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||
| @@ -601,9 +600,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor: | |||
| """ | |||
| assert len(inp.shape) <= 2, "Input should be 1d or 2d" | |||
| if descending: | |||
| order = P.Argsort.Order.DESCENDING | |||
| order = "DESCENDING" | |||
| else: | |||
| order = P.Argsort.Order.ASCENDING | |||
| order = "ASCENDING" | |||
| op = builtin.Argsort(order=order) | |||
| if len(inp.shape) == 1: | |||
| @@ -643,9 +642,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: | |||
| """ | |||
| assert len(inp.shape) <= 2, "Input should be 1d or 2d" | |||
| if descending: | |||
| order = P.Argsort.Order.DESCENDING | |||
| order = "DESCENDING" | |||
| else: | |||
| order = P.Argsort.Order.ASCENDING | |||
| order = "ASCENDING" | |||
| op = builtin.Argsort(order=order) | |||
| if len(inp.shape) == 1: | |||
| @@ -695,13 +694,12 @@ def topk( | |||
| if descending: | |||
| inp = -inp | |||
| Mode = P.TopK.Mode | |||
| if kth_only: | |||
| mode = Mode.KTH_ONLY | |||
| mode = "KTH_ONLY" | |||
| elif no_sort: | |||
| mode = Mode.VALUE_IDX_NOSORT | |||
| mode = "VALUE_IDX_NOSORT" | |||
| else: | |||
| mode = Mode.VALUE_IDX_SORTED | |||
| mode = "VALUE_IDX_SORTED" | |||
| op = builtin.TopK(mode=mode) | |||
| if not isinstance(k, (TensorBase, TensorWrapperBase)): | |||
| @@ -12,7 +12,6 @@ from typing import Optional, Sequence, Tuple, Union | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core._trace_option import use_symbolic_shape | |||
| from ..core.ops import builtin | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..core.ops.builtin import BatchNorm | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph, utils | |||
| @@ -121,11 +120,11 @@ def conv2d( | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
| and the shape of weight should be `(groups, out_channel // groups, | |||
| in_channels // groups, height, width)`. | |||
| :type conv_mode: string or :class:`P.Convolution.Mode` | |||
| :type conv_mode: string or :class:`Convolution.Mode` | |||
| :param conv_mode: supports "CROSS_CORRELATION". Default: | |||
| "CROSS_CORRELATION" | |||
| :type compute_mode: string or | |||
| :class:`P.Convolution.ComputeMode` | |||
| :class:`Convolution.ComputeMode` | |||
| :param compute_mode: when set to "DEFAULT", no special requirements will be | |||
| placed on the precision of intermediate results. When set to "FLOAT32", | |||
| "Float32" would be used for accumulator and intermediate result, but only | |||
| @@ -139,8 +138,8 @@ def conv2d( | |||
| pad_h, pad_w = expand_hw(padding) | |||
| dilate_h, dilate_w = expand_hw(dilation) | |||
| Sparse = P.Convolution.Sparse | |||
| sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP | |||
| Sparse = builtin.Convolution.Sparse | |||
| sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
| op = builtin.Convolution( | |||
| stride_h=stride_h, | |||
| stride_w=stride_w, | |||
| @@ -187,11 +186,11 @@ def conv_transpose2d( | |||
| ``in_channels`` and ``out_channels`` must be divisible by groups, | |||
| and the shape of weight should be `(groups, out_channel // groups, | |||
| in_channels // groups, height, width)`. Default: 1 | |||
| :type conv_mode: string or :class:`P.Convolution.Mode` | |||
| :type conv_mode: string or :class:`Convolution.Mode` | |||
| :param conv_mode: supports "CROSS_CORRELATION". Default: | |||
| "CROSS_CORRELATION" | |||
| :type compute_mode: string or | |||
| :class:`P.Convolution.ComputeMode` | |||
| :class:`Convolution.ComputeMode` | |||
| :param compute_mode: when set to "DEFAULT", no special requirements will be | |||
| placed on the precision of intermediate results. When set to "FLOAT32", | |||
| "Float32" would be used for accumulator and intermediate result, but only | |||
| @@ -240,8 +239,6 @@ def local_conv2d( | |||
| pad_h, pad_w = expand_hw(padding) | |||
| dilate_h, dilate_w = expand_hw(dilation) | |||
| Sparse = P.Convolution.Sparse | |||
| op = builtin.GroupLocal( | |||
| stride_h=stride_h, | |||
| stride_w=stride_w, | |||
| @@ -251,7 +248,7 @@ def local_conv2d( | |||
| dilate_w=dilate_w, | |||
| mode=conv_mode, | |||
| compute_mode="DEFAULT", | |||
| sparse=Sparse.DENSE, | |||
| sparse="DENSE", | |||
| ) | |||
| inp, weight = utils.convert_inputs(inp, weight) | |||
| (output,) = apply(op, inp, weight) | |||
| @@ -696,19 +693,14 @@ def batch_norm( | |||
| if not training: | |||
| op = builtin.BatchNorm( | |||
| BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0 | |||
| fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11" | |||
| ) | |||
| ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] | |||
| return ret | |||
| else: | |||
| op = builtin.BatchNorm( | |||
| BatchNorm.ParamDim.DIM_1C11, | |||
| BatchNorm.FwdMode.TRAINING, | |||
| eps, | |||
| 1.0 - momentum, | |||
| 1.0, | |||
| 0.0, | |||
| avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11" | |||
| ) | |||
| if has_mean or has_var: | |||
| running_mean = make_full_if_none(running_mean, 0) | |||
| @@ -1638,8 +1630,7 @@ def conv1d( | |||
| pad_h = padding | |||
| dilate_h = dilation | |||
| Sparse = P.Convolution.Sparse | |||
| sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP | |||
| sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
| op = builtin.Convolution( | |||
| stride_h=stride_h, | |||
| stride_w=1, | |||
| @@ -41,12 +41,12 @@ def conv_bias_activation( | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
| and the shape of weight should be `(groups, out_channel // groups, | |||
| in_channels // groups, height, width)`. | |||
| :type conv_mode: string or :class:`P.Convolution.Mode`. | |||
| :type conv_mode: string or :class:`Convolution.Mode`. | |||
| :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
| 'CROSS_CORRELATION' | |||
| :param dtype: support for ``np.dtype``, Default: np.int8 | |||
| :type compute_mode: string or | |||
| :class:`P.Convolution.ComputeMode`. | |||
| :class:`Convolution.ComputeMode`. | |||
| :param compute_mode: when set to "DEFAULT", no special requirements will be | |||
| placed on the precision of intermediate results. When set to "FLOAT32", | |||
| "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. | |||
| @@ -56,7 +56,7 @@ def conv_bias_activation( | |||
| sh, sw = _pair_nonzero(stride) | |||
| dh, dw = _pair_nonzero(dilation) | |||
| sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
| op = builtin.ConvBiasForward( | |||
| op = builtin.ConvBias( | |||
| stride_h=sh, | |||
| stride_w=sw, | |||
| pad_h=ph, | |||
| @@ -101,12 +101,12 @@ def batch_conv_bias_activation( | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
| and the shape of weight should be `(groups, out_channel // groups, | |||
| in_channels // groups, height, width)`. | |||
| :type conv_mode: string or :class:`P.Convolution.Mode`. | |||
| :type conv_mode: string or :class:`Convolution.Mode`. | |||
| :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
| 'CROSS_CORRELATION' | |||
| :param dtype: support for ``np.dtype``, Default: np.int8 | |||
| :type compute_mode: string or | |||
| :class:`P.Convolution.ComputeMode`. | |||
| :class:`Convolution.ComputeMode`. | |||
| :param compute_mode: when set to "DEFAULT", no special requirements will be | |||
| placed on the precision of intermediate results. When set to "FLOAT32", | |||
| "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. | |||
| @@ -116,7 +116,7 @@ def batch_conv_bias_activation( | |||
| sh, sw = _pair_nonzero(stride) | |||
| dh, dw = _pair_nonzero(dilation) | |||
| sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
| op = builtin.BatchConvBiasForward( | |||
| op = builtin.BatchConvBias( | |||
| stride_h=sh, | |||
| stride_w=sw, | |||
| pad_h=ph, | |||
| @@ -16,7 +16,6 @@ import numpy as np | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core._wrap import device as as_device | |||
| from ..core.ops import builtin | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||
| from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis | |||
| @@ -722,7 +721,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
| [1 0]] | |||
| """ | |||
| return inp.transpose(pattern) | |||
| return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern)) | |||
| def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | |||
| @@ -756,10 +755,6 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | |||
| return inp.reshape(target_shape) | |||
| AxisAddRemove = builtin.AxisAddRemove | |||
| AxisDesc = AxisAddRemove.AxisDesc | |||
| def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: | |||
| r""" | |||
| Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``. | |||
| @@ -826,7 +821,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
| (1, 2) | |||
| """ | |||
| Param = builtin.AxisAddRemove.Param | |||
| def get_axes(): | |||
| try: | |||
| @@ -839,8 +833,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
| ndim = inp.ndim + len(axis) | |||
| axis = sorted(i + ndim if i < 0 else i for i in axis) | |||
| param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_add, axis)) | |||
| op = builtin.AxisAddRemove(param=param) | |||
| op = builtin.AddAxis(axis=axis) | |||
| (result,) = apply(op, inp) | |||
| return result | |||
| @@ -21,9 +21,10 @@ import numpy as np | |||
| from ..core._imperative_rt import GraphProfiler | |||
| from ..core._imperative_rt.ops import ( | |||
| CollectiveComm, | |||
| OprAttr, | |||
| GaussianRNG, | |||
| RemoteRecv, | |||
| RemoteSend, | |||
| UniformRNG, | |||
| VirtualDep, | |||
| ) | |||
| from ..core._trace_option import set_symbolic_shape | |||
| @@ -182,14 +183,7 @@ class trace: | |||
| record = self._seq[self._pc] | |||
| op_, ihandles, ohandles = record | |||
| if op != op_: | |||
| # FIXME: will be removed once better rng implementation is done | |||
| if isinstance(op, OprAttr) and ( | |||
| op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type | |||
| ): | |||
| if op.param[8:] != op_.param[8:]: | |||
| raise TraceMismatchError("op different from last time") | |||
| else: | |||
| raise TraceMismatchError("op different from last time") | |||
| raise TraceMismatchError("op different from last time") | |||
| if len(ihandles) != len(args): | |||
| raise TraceMismatchError("op input size different from last time") | |||
| @@ -10,7 +10,6 @@ from typing import Tuple, Union | |||
| import numpy as np | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu | |||
| from ..functional.types import _pair, _pair_nonzero | |||
| from ..tensor import Parameter | |||
| @@ -156,8 +155,6 @@ class Conv1d(_ConvNd): | |||
| (2, 1, 2) | |||
| """ | |||
| _conv_mode_type = P.Convolution.Mode | |||
| _compute_mode_type = P.Convolution.ComputeMode | |||
| def __init__( | |||
| self, | |||
| @@ -176,8 +173,8 @@ class Conv1d(_ConvNd): | |||
| stride = stride | |||
| padding = padding | |||
| dilation = dilation | |||
| self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
| self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
| self.conv_mode = conv_mode | |||
| self.compute_mode = compute_mode | |||
| super().__init__( | |||
| in_channels, | |||
| out_channels, | |||
| @@ -302,9 +299,6 @@ class Conv2d(_ConvNd): | |||
| """ | |||
| _conv_mode_type = P.Convolution.Mode | |||
| _compute_mode_type = P.Convolution.ComputeMode | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| @@ -322,8 +316,8 @@ class Conv2d(_ConvNd): | |||
| stride = _pair_nonzero(stride) | |||
| padding = _pair(padding) | |||
| dilation = _pair_nonzero(dilation) | |||
| self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
| self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
| self.conv_mode = conv_mode | |||
| self.compute_mode = compute_mode | |||
| super().__init__( | |||
| in_channels, | |||
| out_channels, | |||
| @@ -414,9 +408,6 @@ class ConvTranspose2d(_ConvNd): | |||
| effective when input and output are of float16 dtype. | |||
| """ | |||
| _conv_mode_type = P.Convolution.Mode | |||
| _compute_mode_type = P.Convolution.ComputeMode | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| @@ -434,8 +425,8 @@ class ConvTranspose2d(_ConvNd): | |||
| stride = _pair_nonzero(stride) | |||
| padding = _pair(padding) | |||
| dilation = _pair_nonzero(dilation) | |||
| self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
| self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
| self.conv_mode = conv_mode | |||
| self.compute_mode = compute_mode | |||
| super().__init__( | |||
| in_channels, | |||
| out_channels, | |||
| @@ -509,8 +500,6 @@ class LocalConv2d(Conv2d): | |||
| in_channels // groups, *kernel_size, out_channels // groups)`. | |||
| """ | |||
| _conv_mode_type = P.Convolution.Mode | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| @@ -5,7 +5,6 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..functional.elemwise import _elwise | |||
| from ..tensor import Tensor | |||
| from .module import Module | |||
| @@ -41,8 +41,8 @@ class Conv2d(Float.Conv2d, QATModule): | |||
| float_module.dilation, | |||
| float_module.groups, | |||
| float_module.bias is not None, | |||
| float_module.conv_mode.name, | |||
| float_module.compute_mode.name, | |||
| float_module.conv_mode, | |||
| float_module.compute_mode, | |||
| ) | |||
| qat_module.weight = float_module.weight | |||
| qat_module.bias = float_module.bias | |||
| @@ -5,7 +5,6 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ...core.ops._internal import param_defs as P | |||
| from ...functional.elemwise import _elemwise_multi_type | |||
| from ...tensor import Tensor | |||
| from ..qat import elemwise as QAT | |||
| @@ -15,11 +14,9 @@ from .module import QuantizedModule | |||
| class Elemwise(QuantizedModule): | |||
| r"""Quantized version of :class:`~.qat.elemwise.Elemwise`.""" | |||
| _elemwise_multi_type_mode = P.ElemwiseMultiType.Mode | |||
| def __init__(self, method, dtype=None): | |||
| super().__init__() | |||
| self.method = self._elemwise_multi_type_mode.convert("Q" + method) | |||
| self.method = "Q" + method | |||
| self.output_dtype = dtype | |||
| def forward(self, *inps): | |||
| @@ -15,7 +15,7 @@ from typing import Iterable, List, Optional | |||
| from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry | |||
| from ..core._imperative_rt import ProfilerImpl as _Profiler | |||
| from ..core._imperative_rt.imperative import sync | |||
| from ..core._imperative_rt.ops import CollectiveCommMode | |||
| from ..core._imperative_rt.ops import CollectiveComm | |||
| def _make_dict(**kwargs): | |||
| @@ -194,7 +194,7 @@ class Profiler: | |||
| _type_map = { | |||
| OperatorNodeConfig: lambda x: _print_opnode_config(x), | |||
| bytes: lambda x: base64.encodebytes(x).decode("ascii"), | |||
| CollectiveCommMode: lambda x: str(x), | |||
| CollectiveComm.Mode: lambda x: str(x), | |||
| } | |||
| _dumper_map = { | |||
| @@ -421,9 +421,7 @@ void init_graph_rt(py::module m) { | |||
| common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) { | |||
| cg::VarNodeArray vinputs(inputs.begin(), inputs.end()); | |||
| auto opr = OpDef::apply_on_var_node(def, vinputs); | |||
| auto outputs = opr->usable_output(); | |||
| return to_tuple(outputs); | |||
| return to_tuple(OpDef::apply_on_var_node(def, vinputs)); | |||
| }, | |||
| py::arg(), py::arg(), py::arg("graph") = py::none()); | |||
| @@ -109,9 +109,6 @@ void init_imperative_rt(py::module m) { | |||
| py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef") | |||
| .def("ctype", [](const OpDef& opdef) { | |||
| if (auto attr = opdef.try_cast_final<OprAttr>()) { | |||
| return attr->type.c_str(); | |||
| } | |||
| return opdef.dyn_typeinfo()->name; | |||
| }) | |||
| .def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { | |||
| @@ -14,41 +14,29 @@ | |||
| #include "megbrain/imperative.h" | |||
| #include "megbrain/imperative/ops/backward_graph.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/ops/tensor_manip.h" | |||
| #include "megbrain/imperative/ops/collective_comm.h" | |||
| #include "megbrain/imperative/ops/io_remote.h" | |||
| #include "megbrain/imperative/ops/cond_take.h" | |||
| #include "megbrain/imperative/ops/nms.h" | |||
| #include "megbrain/imperative/ops/elemwise.h" | |||
| #include "megbrain/imperative/ops/batch_norm.h" | |||
| #include "megbrain/imperative/ops/broadcast.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| namespace py = pybind11; | |||
| namespace { | |||
| auto normalize_enum(const std::string& in) { | |||
| std::string ret; | |||
| for (auto&& c : in) { | |||
| ret += toupper(c); | |||
| } | |||
| return ret; | |||
| } | |||
| } // anonymous namespace | |||
| void init_ops(py::module m) { | |||
| using namespace mgb::imperative; | |||
| py::class_<OprAttr, std::shared_ptr<OprAttr>, OpDef>(m, "OprAttr") | |||
| .def(py::init<>()) | |||
| .def_readwrite("type", &OprAttr::type) | |||
| .def_readwrite("param", &OprAttr::param) | |||
| .def_readwrite("config", &OprAttr::config) | |||
| .def_property("param", | |||
| [](const OprAttr& attr) -> py::bytes { | |||
| return std::string(attr.param.begin(), attr.param.end()); | |||
| }, | |||
| [] (OprAttr& attr, py::bytes data) { | |||
| auto s = py::cast<std::string>(data); | |||
| attr.param.clear(); | |||
| attr.param.insert(attr.param.end(), s.begin(), s.end()); | |||
| }); | |||
| py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph") | |||
| .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, | |||
| const mgb::SmallVector<py::object>& inputs) { | |||
| auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
| return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs)); | |||
| return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||
| }; | |||
| auto c = [pyc](const TensorPtr& tensor) { | |||
| return pyc(tensor->dev_tensor()); | |||
| @@ -56,162 +44,8 @@ void init_ops(py::module m) { | |||
| return self.graph().interpret<py::object>(f, c, inputs); | |||
| }); | |||
| py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape") | |||
| .def(py::init()); | |||
| #define V(m) .value(#m, CollectiveComm::Mode::m) | |||
| py::enum_<CollectiveComm::Mode>(m, "CollectiveCommMode") | |||
| V(REDUCE_SUM) | |||
| V(BROADCAST) | |||
| V(ALL_GATHER) | |||
| V(REDUCE_SCATTER_SUM) | |||
| V(ALL_REDUCE_SUM) | |||
| V(ALL_REDUCE_MAX) | |||
| V(ALL_REDUCE_MIN) | |||
| V(ALL_REDUCE_PROD) | |||
| V(GATHER) | |||
| V(SCATTER) | |||
| V(ALL_TO_ALL); | |||
| #undef V | |||
| py::class_<CollectiveComm, std::shared_ptr<CollectiveComm>, OpDef>(m, "CollectiveComm") | |||
| .def(py::init<>()) | |||
| .def_readwrite("key", &CollectiveComm::key) | |||
| .def_readwrite("nr_devices", &CollectiveComm::nr_devices) | |||
| .def_readwrite("rank", &CollectiveComm::rank) | |||
| .def_readwrite("is_root", &CollectiveComm::is_root) | |||
| .def_readwrite("local_grad", &CollectiveComm::local_grad) | |||
| .def_readwrite("addr", &CollectiveComm::addr) | |||
| .def_readwrite("port", &CollectiveComm::port) | |||
| .def_readwrite("mode", &CollectiveComm::mode) | |||
| .def_readwrite("dtype", &CollectiveComm::dtype) | |||
| .def_readwrite("backend", &CollectiveComm::backend) | |||
| .def_readwrite("comp_node", &CollectiveComm::comp_node); | |||
| py::class_<RemoteSend, std::shared_ptr<RemoteSend>, OpDef>(m, "RemoteSend") | |||
| .def(py::init<>()) | |||
| .def_readwrite("key", &RemoteSend::key) | |||
| .def_readwrite("addr", &RemoteSend::addr) | |||
| .def_readwrite("port", &RemoteSend::port) | |||
| .def_readwrite("rank_to", &RemoteSend::rank_to); | |||
| py::class_<RemoteRecv, std::shared_ptr<RemoteRecv>, OpDef>(m, "RemoteRecv") | |||
| .def(py::init<>()) | |||
| .def_readwrite("key", &RemoteRecv::key) | |||
| .def_readwrite("addr", &RemoteRecv::addr) | |||
| .def_readwrite("port", &RemoteRecv::port) | |||
| .def_readwrite("rank_from", &RemoteRecv::rank_from) | |||
| .def_readwrite("shape", &RemoteRecv::shape) | |||
| .def_readwrite("cn", &RemoteRecv::cn) | |||
| .def_readwrite("dtype", &RemoteRecv::dtype); | |||
| py::class_<ParamPackSplit, std::shared_ptr<ParamPackSplit>, OpDef>(m, "ParamPackSplit") | |||
| .def(py::init<>()) | |||
| .def_readwrite("offsets", &ParamPackSplit::offsets) | |||
| .def_readwrite("shapes", &ParamPackSplit::shapes); | |||
| py::class_<ParamPackConcat, std::shared_ptr<ParamPackConcat>, OpDef>(m, "ParamPackConcat") | |||
| .def(py::init<>()) | |||
| .def_readwrite("offsets", &ParamPackConcat::offsets); | |||
| py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep") | |||
| .def(py::init<>()); | |||
| py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake") | |||
| .def(py::init<>()); | |||
| py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef>(m, "NMSKeep") | |||
| .def(py::init<float, uint32_t>()) | |||
| .def_readwrite("iou_thresh", &NMSKeep::iou_thresh) | |||
| .def_readwrite("max_output", &NMSKeep::max_output); | |||
| py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> elemwise(m, "Elemwise"); | |||
| elemwise.def(py::init<Elemwise::Mode>()) | |||
| .def_readwrite("mode", &Elemwise::mode); | |||
| #define V(m) .value(#m, Elemwise::Mode::m) | |||
| py::enum_<Elemwise::Mode>(elemwise, "Mode") | |||
| V(RELU) | |||
| V(ABS) | |||
| V(ACOS) | |||
| V(ASIN) | |||
| V(CEIL) | |||
| V(COS) | |||
| V(EXP) | |||
| V(EXPM1) | |||
| V(FLOOR) | |||
| V(LOG) | |||
| V(LOG1P) | |||
| V(NEGATE) | |||
| V(SIGMOID) | |||
| V(SIN) | |||
| V(TANH) | |||
| V(ABS_GRAD) | |||
| V(ADD) | |||
| V(FLOOR_DIV) | |||
| V(MAX) | |||
| V(MIN) | |||
| V(MOD) | |||
| V(MUL) | |||
| V(POW) | |||
| V(SIGMOID_GRAD) | |||
| V(SUB) | |||
| V(SWITCH_GT0) | |||
| V(TANH_GRAD) | |||
| V(TRUE_DIV) | |||
| V(LOG_SUM_EXP) | |||
| V(LT) | |||
| V(LEQ) | |||
| V(EQ) | |||
| V(SHL) | |||
| V(SHR) | |||
| V(COND_LEQ_MOV) | |||
| V(FUSE_MUL_ADD3) | |||
| V(FUSE_MUL_ADD4) | |||
| V(FUSE_ADD_RELU) | |||
| V(FUSE_ADD_SIGMOID) | |||
| V(FUSE_ADD_TANH) | |||
| V(FAST_TANH) | |||
| V(FAST_TANH_GRAD) | |||
| V(ROUND) | |||
| V(RMULH) | |||
| V(ATAN2) | |||
| V(ERF) | |||
| V(ERFINV) | |||
| V(ERFC) | |||
| V(ERFCINV) | |||
| V(H_SWISH) | |||
| V(H_SWISH_GRAD) | |||
| V(FUSE_ADD_H_SWISH) | |||
| V(NOT) | |||
| V(AND) | |||
| V(OR) | |||
| V(XOR); | |||
| #undef V | |||
| py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> batchnorm(m, "BatchNorm"); | |||
| batchnorm.def(py::init<const BatchNorm::Param::ParamDim&, const BatchNorm::Param::FwdMode&, double, double, float, float>()) | |||
| .def_readwrite("param_dim", &BatchNorm::param_dim) | |||
| .def_readwrite("fwd_mode", &BatchNorm::fwd_mode) | |||
| .def_readwrite("epsilon", &BatchNorm::epsilon) | |||
| .def_readwrite("avg_factor", &BatchNorm::avg_factor) | |||
| .def_readwrite("scale", &BatchNorm::scale) | |||
| .def_readwrite("bias", &BatchNorm::bias); | |||
| #define V(m) .value(#m, BatchNorm::Param::ParamDim::m) | |||
| py::enum_<BatchNorm::Param::ParamDim>(batchnorm, "ParamDim") | |||
| V(DIM_11HW) | |||
| V(DIM_1CHW) | |||
| V(DIM_1C11); | |||
| #undef V | |||
| #define V(m) .value(#m, BatchNorm::Param::FwdMode::m) | |||
| py::enum_<BatchNorm::Param::FwdMode>(batchnorm, "FwdMode") | |||
| V(TRAINING) | |||
| V(INFERENCE); | |||
| #undef V | |||
| py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast") | |||
| .def(py::init<>()); | |||
| #include "opdef.py.inl" | |||
| } | |||
| @@ -113,7 +113,7 @@ def test_quint8_typecvt(): | |||
| data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
| def typecvt(x, dt=None): | |||
| (y,) = apply(ops.TypeCvt(param=dt), x) | |||
| (y,) = apply(ops.TypeCvt(dtype=dt), x) | |||
| return y | |||
| # convert to quint8 | |||
| @@ -194,7 +194,7 @@ def test_quint4_typecvt(): | |||
| data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
| def typecvt(x, dt=None): | |||
| (y,) = apply(ops.TypeCvt(param=dt), x) | |||
| (y,) = apply(ops.TypeCvt(dtype=dt), x) | |||
| return y | |||
| # convert to quint4 | |||
| @@ -11,10 +11,9 @@ import collections | |||
| import numpy as np | |||
| import pytest | |||
| import megengine.core.ops.builtin | |||
| import megengine.core.tensor.raw_tensor | |||
| from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.ops._internal import all_ops | |||
| from megengine.core.ops import builtin | |||
| from megengine.core.tensor import Tensor | |||
| from megengine.core.tensor.core import apply | |||
| from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor | |||
| @@ -105,7 +104,7 @@ def canonize_inputs(inputs, *, config): | |||
| need_cvt = False | |||
| for i in old_inputs: | |||
| if isinstance(i, RawTensor): | |||
| get_comp_node = lambda cn=i.device.to_c(): cn | |||
| get_comp_node = lambda cn=i.device: cn | |||
| else: | |||
| need_cvt = True | |||
| inputs.append(i) | |||
| @@ -193,91 +192,91 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
| def transpose(*args, **kwargs): | |||
| op = all_ops.Dimshuffle(**kwargs).to_c() | |||
| op = builtin.Dimshuffle(**kwargs) | |||
| return invoke_op(op, args) | |||
| def broadcast(input, tshape): | |||
| op = all_ops.Broadcast().to_c() | |||
| op = builtin.Broadcast() | |||
| return invoke_op(op, (input, tshape), canonize_reshape) | |||
| def subtensor(input, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.Subtensor(items).to_c() | |||
| op = builtin.Subtensor(items) | |||
| return invoke_op(op, (input, *tensors)) | |||
| def set_subtensor(input, value, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.SetSubtensor(items).to_c() | |||
| op = builtin.SetSubtensor(items) | |||
| return invoke_op(op, (input, value, *tensors)) | |||
| def incr_subtensor(input, value, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.IncrSubtensor(items).to_c() | |||
| op = builtin.IncrSubtensor(items) | |||
| return invoke_op(op, (input, value, *tensors)) | |||
| def advance_indexing(input, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.IndexingMultiAxisVec(items).to_c() | |||
| op = builtin.IndexingMultiAxisVec(items) | |||
| return invoke_op(op, (input, *tensors)) | |||
| def set_advance_indexing(input, value, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.IndexingSetMultiAxisVec(items).to_c() | |||
| op = builtin.IndexingSetMultiAxisVec(items) | |||
| return invoke_op(op, (input, value, *tensors)) | |||
| def incr_advance_indexing(input, value, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.IndexingIncrMultiAxisVec(items).to_c() | |||
| op = builtin.IndexingIncrMultiAxisVec(items) | |||
| return invoke_op(op, (input, value, *tensors)) | |||
| def mesh_indexing(input, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.MeshIndexing(items).to_c() | |||
| op = builtin.MeshIndexing(items) | |||
| return invoke_op(op, (input, *tensors)) | |||
| def set_mesh_indexing(input, value, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.SetMeshIndexing(items).to_c() | |||
| op = builtin.SetMeshIndexing(items) | |||
| return invoke_op(op, (input, value, *tensors)) | |||
| def incr_mesh_indexing(input, value, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.IncrMeshIndexing(items).to_c() | |||
| op = builtin.IncrMeshIndexing(items) | |||
| return invoke_op(op, (input, value, *tensors)) | |||
| def batched_mesh_indexing(input, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.BatchedMeshIndexing(items).to_c() | |||
| op = builtin.BatchedMeshIndexing(items) | |||
| return invoke_op(op, (input, *tensors)) | |||
| def batched_set_mesh_indexing(input, value, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.BatchedSetMeshIndexing(items).to_c() | |||
| op = builtin.BatchedSetMeshIndexing(items) | |||
| return invoke_op(op, (input, value, *tensors)) | |||
| def batched_incr_mesh_indexing(input, value, tuple_val): | |||
| input, tensors, items = unpack_getitem(input, tuple_val) | |||
| op = all_ops.BatchedIncrMeshIndexing(items).to_c() | |||
| op = builtin.BatchedIncrMeshIndexing(items) | |||
| return invoke_op(op, (input, value, *tensors)) | |||
| def test_transpose(): | |||
| x = np.arange(10).reshape(2, 5).astype("int32") | |||
| xx = as_raw_tensor(x) | |||
| (yy,) = transpose(xx, pattern="1x0") | |||
| (yy,) = transpose(xx, pattern=[1, -1, 0]) | |||
| np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) | |||
| @@ -1,320 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from io import StringIO | |||
| import re | |||
| import argparse | |||
| import subprocess | |||
| import os | |||
| import textwrap | |||
| import inspect | |||
| def camel2underscore( | |||
| name, *, | |||
| first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'), | |||
| all_cap_re = re.compile('([a-z])([A-Z]+)')): | |||
| if name.isupper(): | |||
| return name.lower() | |||
| s1 = first_cap_re.sub(r'\1_\2', name) | |||
| return all_cap_re.sub(r'\1_\2', s1).lower() | |||
| def caller_lineno(level=1): | |||
| f = inspect.stack()[level+1] | |||
| return '%s:%d' % (f.filename, f.lineno) | |||
| class Doc: | |||
| """wrap an identifier and doc""" | |||
| _id = None | |||
| def __init__(self, id_, doc, typestr=None, default=None): | |||
| self._id = id_ | |||
| self.doc = doc | |||
| self.typestr = typestr | |||
| self.default = default | |||
| def __str__(self): | |||
| return self._id | |||
| class Context: | |||
| fout = None | |||
| def __init__(self): | |||
| self.fout = StringIO() | |||
| self.indent = 0 | |||
| self.skipped = [] | |||
| self.generated_signature = set() | |||
| self.generated_opr = dict() | |||
| def write(self, text, *fmt, indent=0): | |||
| text = textwrap.dedent(text) | |||
| text = textwrap.indent(text, ' '*4*(self.indent + indent)) | |||
| text = text % fmt | |||
| if not text.endswith('\n'): | |||
| text += '\n' | |||
| self.fout.write(text) | |||
| def _gen_signature(self, params, *, have_config=True, | |||
| has_out_dtype=False): | |||
| sig = ['self', '*'] | |||
| for i, _ in params: | |||
| sig.append('{}=None'.format(i)) | |||
| if have_config: | |||
| sig.extend(['name=None', 'comp_node=None', 'config=None']) | |||
| if has_out_dtype: | |||
| sig.append('dtype=None') | |||
| if params: | |||
| sig.append('**kwargs') | |||
| if sig[-1] == '*': | |||
| sig.pop() | |||
| return ', '.join(sig) | |||
| def _write_canonize_inputs(self, inputs, convert_inputs, | |||
| convert_inputs_args=None, | |||
| has_out_dtype=False): | |||
| self._write_gen_config(has_out_dtype) | |||
| inputs = list(map(str, inputs)) | |||
| if convert_inputs_args is None: | |||
| if inputs[0][0] == '*': | |||
| arg = inputs[0][1:] | |||
| else: | |||
| arg = '[{}]'.format(', '.join(inputs)) | |||
| else: | |||
| arg = convert_inputs_args | |||
| self.write('inputs = helper.%s(%s, config=config)', | |||
| convert_inputs, arg) | |||
| def _write_gen_config(self, has_out_dtype=False): | |||
| self.write('''\ | |||
| config = config or Config() | |||
| if name: | |||
| config.name = name | |||
| if comp_node: | |||
| config.comp_node = comp_node | |||
| ''') | |||
| if has_out_dtype: | |||
| self.write('''\ | |||
| if dtype: | |||
| config.dtype = dtype | |||
| ''') | |||
| self.write('self.config = config') | |||
| def _write_make_params(self, params): | |||
| for pname, ptype in params: | |||
| self.write('self.%s = helper.make_param(%s, param_defs.%s, kwargs)', | |||
| pname, pname, ptype) | |||
| self.write('assert not kwargs, "extra kwargs: {}".format(kwargs)') | |||
| def _write_doc(self, inputs, params, desc): | |||
| self.write('"""') | |||
| if isinstance(desc, Doc): | |||
| assert desc._id is None | |||
| self.write(desc.doc) | |||
| elif desc: | |||
| for i in textwrap.wrap(desc, 75): | |||
| self.write(i) | |||
| self.write('') | |||
| for i in inputs: | |||
| name = str(i) | |||
| typestr = ':class:`.Tensor`' | |||
| if name[0] == '*': | |||
| name = name[1:] | |||
| typestr = 'list of ' + typestr | |||
| if isinstance(i, Doc): | |||
| self.write(':param %s: %s', name, i.doc) | |||
| if i.typestr is not None: | |||
| typestr = i.typestr | |||
| if typestr: | |||
| if not isinstance(i, Doc): | |||
| self.write(':param %s: ', name) | |||
| self.write(':type %s: %s', name, typestr) | |||
| for pname, ptype in params: | |||
| self.write(':param %s: ', pname) | |||
| self.write(':type %s: :class:`~megbrain.opr_param_defs.%s`', | |||
| pname, ptype) | |||
| self.write(':param comp_node: see doc for *config*') | |||
| self.write(':param name: see doc for *config*') | |||
| self.write( | |||
| ':param config: give a :class:`.OperatorNodeConfig` object to set ' | |||
| 'operator name and comp node. This can also be achieved by passing ' | |||
| '*comp_node* and *name* separately.') | |||
| self.write('"""') | |||
| def _write_return(self, name, outputs): | |||
| self.write('opdef = helper.PodOpVisitor("%s", config, params)', name) | |||
| self.write('outputs = helper.create_op(opdef, inputs)') | |||
| if outputs: | |||
| self.write('outputs = [outputs[i] for i in %s]', | |||
| list(map(int, outputs))) | |||
| self.write('return helper.convert_outputs(outputs)') | |||
| def decl_opr(self, name, *, inputs, params, desc=None, pyname=None, | |||
| canonize_input_vars=None, | |||
| canonize_input_vars_args=None, body=None, | |||
| outputs=None, version=0, has_out_dtype=False): | |||
| """ | |||
| :param inputs: name of variable inputs; a name starting with `*' means | |||
| a list of vars | |||
| :type inputs: list of str | |||
| :param params: (param name, param type) pairs; it can be a single | |||
| string representing the param type, and param name defaults to | |||
| 'param' | |||
| :type params: list of pair of str, or str | |||
| :param pyname: python function name | |||
| :param body: extra statements to be placed before calling _create_opr | |||
| :param outputs: the indices of output vars to be selected from raw opr | |||
| result | |||
| """ | |||
| class OprItem: | |||
| def __init__(self, inputs, desc, params, version, has_out_dtype): | |||
| self.inputs = inputs | |||
| self.desc = desc | |||
| self.params = params | |||
| self.version = version | |||
| self.has_out_dtype = has_out_dtype | |||
| if body: | |||
| self.skipped.append(name) | |||
| return | |||
| signature = (name, params if isinstance(params, str) else frozenset(params), has_out_dtype, version) | |||
| if signature in self.generated_signature: | |||
| self.skipped.append(name) | |||
| return | |||
| else: | |||
| self.generated_signature.add(signature) | |||
| body = body or [] | |||
| if isinstance(params, str): | |||
| params = [('param', params)] | |||
| assert params | |||
| if name in self.generated_opr: | |||
| org_opr = self.generated_opr[name] | |||
| if version > org_opr.version: | |||
| def compare_doc(a, b): | |||
| if isinstance(a, str): | |||
| return a == b | |||
| else: | |||
| assert isinstance(a, Doc) | |||
| return a.doc == b.doc | |||
| assert compare_doc(desc, org_opr.desc) | |||
| assert len(inputs) == len(org_opr.inputs) | |||
| for i, j in zip(inputs, org_opr.inputs): | |||
| assert compare_doc(i, j) | |||
| self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype) | |||
| else: | |||
| self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype) | |||
| def write_generated_oprs(self): | |||
| for opr, opr_item in self.generated_opr.items(): | |||
| name = opr | |||
| params = opr_item.params | |||
| version = opr_item.version | |||
| has_out_dtype = opr_item.has_out_dtype | |||
| self.write('# %s', caller_lineno()) | |||
| self.write('class %s(PodOpVisitor):', name) | |||
| self.indent += 1 | |||
| param_names, _ = zip(*params) | |||
| self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names))) | |||
| self.write('name = "%s"', '{}V{}'.format(name, version) if version else name) | |||
| self.write('\n') | |||
| self.write('def __init__(%s):', | |||
| self._gen_signature(params, | |||
| has_out_dtype=has_out_dtype)) | |||
| self.indent += 1 | |||
| self._write_gen_config(has_out_dtype=has_out_dtype) | |||
| self.write('\n') | |||
| self._write_make_params(params) | |||
| self.write('\n') | |||
| self.indent -= 2 | |||
| def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None, | |||
| desc=None, local_defs=[], have_config=True, params=None, has_out_dtype=False): | |||
| self.skipped.append(name) | |||
| def get_str(self): | |||
| return self.fout.getvalue() | |||
| def all_list(self): | |||
| buf = StringIO() | |||
| print( | |||
| '[', | |||
| *(' "%s",' % i for i in self.generated_opr), | |||
| ']', | |||
| sep='\n', | |||
| file=buf | |||
| ) | |||
| return buf.getvalue() | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate operator function def code from decl file') | |||
| parser.add_argument('inputs', nargs='+') | |||
| parser.add_argument('--output', '-o') | |||
| args = parser.parse_args() | |||
| gen = Context() | |||
| exec_globals = { | |||
| 'decl_opr': gen.decl_opr, | |||
| 'decl_raw_opr': gen.decl_raw_opr, | |||
| 'Doc': Doc, | |||
| 'camel2underscore': camel2underscore, | |||
| } | |||
| for i in args.inputs: | |||
| print('generate ops from {}'.format(i)) | |||
| with open(i) as fin: | |||
| exec(compile(fin.read(), i, 'exec'), exec_globals) | |||
| gen.write_generated_oprs() | |||
| try: | |||
| git_commit = subprocess.check_output( | |||
| ['git', 'rev-parse', 'HEAD'], universal_newlines=True, | |||
| cwd=os.path.dirname(os.path.realpath(__file__))).strip() | |||
| except: | |||
| git_commit = 'NOT_A_GIT_REPO' | |||
| def relpath(*args): | |||
| d = os.path.dirname(__file__) | |||
| return os.path.join(d, *args) | |||
| with open(relpath('ops.tpl.py')) as fin: | |||
| with open(args.output, 'w') as fout: | |||
| fout.write(fin.read() | |||
| .replace('{%all%}', gen.all_list()) | |||
| .replace('{%body%}', gen.get_str()) | |||
| .replace('{%git_commit%}', git_commit)) | |||
| print('Skipped:') | |||
| print(*gen.skipped, sep='\n') | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -1,40 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| """This python module contains functions to apply the operators defined by | |||
| megbrain. | |||
| .. note:: | |||
| Most of the functions are automatically generated, and their signature have | |||
| the form contain a ``param`` argument (or more than one arguments such as | |||
| :func:`convolution` that has ``param`` and ``execution_polity``) and also | |||
| accept keyword arguments. In such case, it can be called by either | |||
| providing a param object of appropriate type, or by passing the arguments | |||
| needed by the constructor of param object to the keyword arguments. | |||
| Furthermore, for a param that needs an enumeration member, the enum name | |||
| can be used to refer to the enum object. | |||
| For example, the following statements are equivalent:: | |||
| elemwise([a, b], mode='max') | |||
| elemwise([a, b], mode=opr_param_defs.Elemwise.Mode.MAX) | |||
| elemwise([a, b], param=opr_param_defs.Elemwise('max')) | |||
| """ | |||
| __git_commit__ = "{%git_commit%}" | |||
| import collections | |||
| from . import helper | |||
| from .helper import PodOpVisitor | |||
| from . import param_defs | |||
| from ..._imperative_rt import OperatorNodeConfig as Config | |||
| __all__ = {%all%} | |||
| {%body%} | |||
| @@ -36,7 +36,7 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | |||
| return def.trait()->apply_on_physical_tensor(def, inputs); | |||
| } | |||
| cg::OperatorNodeBase* OpDef::apply_on_var_node( | |||
| VarNodeArray OpDef::apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| return def.trait()->apply_on_var_node(def, inputs); | |||
| @@ -56,6 +56,14 @@ BackwardGraphResult OpDef::make_backward_graph( | |||
| return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | |||
| } | |||
| size_t OpDef::hash() const { | |||
| return trait()->hash(*this); | |||
| } | |||
| bool OpDef::is_same_st(const Hashable& rhs) const { | |||
| return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs)); | |||
| } | |||
| const OpTrait* OpDef::trait() const { | |||
| if (!m_trait) { | |||
| m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo()); | |||
| @@ -23,7 +23,7 @@ namespace detail { | |||
| struct StaticData { | |||
| std::list<OpTrait> registries; | |||
| std::unordered_map<const char*, OpTrait*> name2reg; | |||
| std::unordered_map<std::string, OpTrait*> name2reg; | |||
| std::unordered_map<Typeinfo*, OpTrait*> type2reg; | |||
| }; | |||
| @@ -30,6 +30,32 @@ struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> { | |||
| return this->Base::operator ()(args...); | |||
| } | |||
| }; | |||
| template<typename T> | |||
| struct ToVarNodeArray: std::false_type {}; | |||
| template<> | |||
| struct ToVarNodeArray<SymbolVar>: std::true_type { | |||
| VarNodeArray operator()(const SymbolVar& inp) { | |||
| return {inp.node()}; | |||
| } | |||
| }; | |||
| template<> | |||
| struct ToVarNodeArray<SymbolVarArray>: std::true_type { | |||
| VarNodeArray operator()(const SymbolVarArray& inputs) { | |||
| return cg::to_var_node_array(inputs); | |||
| } | |||
| }; | |||
| template<size_t N> | |||
| struct ToVarNodeArray<std::array<SymbolVar, N>>: std::true_type { | |||
| VarNodeArray operator()(const std::array<SymbolVar, N>& inp) { | |||
| return cg::to_var_node_array({inp.begin(), inp.end()}); | |||
| } | |||
| }; | |||
| template<> | |||
| struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type { | |||
| VarNodeArray operator()(const cg::OperatorNodeBase* opr) { | |||
| return opr->usable_output(); | |||
| } | |||
| }; | |||
| } // detail | |||
| using OpDefMaker = detail::OpMeth< | |||
| @@ -42,6 +68,8 @@ using InferOutputAttrsFallible = detail::OpMeth< | |||
| decltype(OpDef::infer_output_attrs_fallible)>; | |||
| using GradMaker = detail::OpMeth< | |||
| decltype(OpDef::make_backward_graph)>; | |||
| using HashFunc = detail::OpMeth<size_t(const OpDef&)>; | |||
| using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | |||
| struct OpTrait { | |||
| const char* name; | |||
| @@ -50,6 +78,8 @@ struct OpTrait { | |||
| ApplyOnVarNode apply_on_var_node; | |||
| InferOutputAttrsFallible infer_output_attrs_fallible; | |||
| GradMaker make_backward_graph; | |||
| HashFunc hash; | |||
| IsSame is_same_st; | |||
| OpTrait(const char* name); | |||
| static OpTrait* find_by_name(const char* name); | |||
| static OpTrait* find_by_typeinfo(Typeinfo* type); | |||
| @@ -61,7 +91,9 @@ struct OpTrait { | |||
| cb(apply_on_physical_tensor) \ | |||
| cb(apply_on_var_node) \ | |||
| cb(infer_output_attrs_fallible) \ | |||
| cb(make_backward_graph) | |||
| cb(make_backward_graph) \ | |||
| cb(hash) \ | |||
| cb(is_same_st) | |||
| struct OpTraitRegistry { | |||
| OpTrait* trait; | |||
| @@ -97,6 +129,15 @@ struct OpTraitRegistry { | |||
| void do_insert(Typeinfo* type); | |||
| static OpTraitRegistry do_insert(const char* name); | |||
| template<typename T, | |||
| typename To = detail::ToVarNodeArray<T>, | |||
| typename = std::enable_if_t<To::value>> | |||
| OpTraitRegistry& apply_on_var_node(T (*f)(const OpDef&, const VarNodeArray&)) { | |||
| return apply_on_var_node([=](const OpDef& opdef, const VarNodeArray& inputs) { | |||
| return To()(f(opdef, inputs)); | |||
| }); | |||
| } | |||
| }; | |||
| } // namespace imperative | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * \file imperative/src/impl/ops/autogen.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "../op_trait.h" | |||
| using namespace megdnn; | |||
| // FIXME: remove this when mgb::hash support tuple_hash | |||
| namespace mgb { | |||
| namespace { | |||
| template<typename T, size_t ...Ns> | |||
| auto tail(T t, std::index_sequence<Ns...>) { | |||
| return std::make_tuple(std::get<Ns+1>(t)...); | |||
| } | |||
| } // anonymous namespace | |||
| template<typename T, typename ...Args> | |||
| class HashTrait<std::tuple<T, Args...>> { | |||
| constexpr static size_t length = sizeof...(Args); | |||
| public: | |||
| static size_t eval(const std::tuple<T, Args...> &t) { | |||
| const T& val = std::get<0>(t); | |||
| if constexpr (!length) { | |||
| return mgb::hash(val); | |||
| } else { | |||
| return mgb::hash_pair_combine(mgb::hash(val), | |||
| mgb::hash(tail(t, std::make_index_sequence<length - 1>{}))); | |||
| } | |||
| } | |||
| }; | |||
| } // namespace mgb | |||
| namespace mgb::imperative { | |||
| #include "./opdef.cpp.inl" | |||
| } // namespace mgb::imperative | |||
| @@ -9,7 +9,8 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/batch_norm.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb { | |||
| @@ -19,9 +20,7 @@ namespace { | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| auto* node = &node_->cast_final_safe<opr::BatchNorm>(); | |||
| auto&& param = node->param(); | |||
| return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon, | |||
| param.avg_factor, param.scale, param.bias); | |||
| return BatchNorm::make(node->param()); | |||
| } | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| @@ -33,13 +32,11 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); | |||
| if (nr_inp == 3) { | |||
| return opr::BatchNorm::make( | |||
| inputs[0], inputs[1], inputs[2], | |||
| {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0] | |||
| inputs[0], inputs[1], inputs[2], bn_opr.param())[0] | |||
| .node()->owner_opr(); | |||
| } else { | |||
| return opr::BatchNorm::make( | |||
| inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], | |||
| {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0] | |||
| inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0] | |||
| .node()->owner_opr(); | |||
| } | |||
| } | |||
| @@ -52,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| mgb_assert(nr_inp == 3 ||nr_inp == 5, | |||
| "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); | |||
| // need running mean/variance | |||
| bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING; | |||
| bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING; | |||
| size_t nr_out = need_stat? 5 : 3; | |||
| SmallVector<LogicalTensorDesc> out_shapes(nr_out); | |||
| auto&& i0 = inputs[0]; | |||
| @@ -76,8 +73,6 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm); | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -9,7 +9,9 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/broadcast.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb { | |||
| @@ -87,8 +89,6 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast); | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -18,7 +18,7 @@ | |||
| #include "megbrain/utils/hash.h" | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| #include "megbrain/imperative/ops/collective_comm.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -61,8 +61,8 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node) { | |||
| auto [addr, port] = split_address(group_client->get_addr()); | |||
| auto comp_node = node->config().get_single_comp_node().to_string_logical(); | |||
| return std::make_shared<CollectiveComm>( | |||
| comm.key(), comm.nr_devices(), comm.rank(), comm.is_root(), | |||
| comm.local_grad(), addr, std::stoi(port), comm.param().mode, | |||
| comm.param().mode, comm.key(), comm.nr_devices(), comm.rank(), | |||
| comm.is_root(), comm.local_grad(), addr, std::stoi(port), | |||
| comm.dtype(), comm.backend(), comp_node); | |||
| } | |||
| @@ -73,35 +73,6 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) | |||
| } // anonymous namespace | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| bool CollectiveComm::is_same_st(const Hashable& another) const{ | |||
| auto* comm_opr = another.try_cast_final<CollectiveComm>(); | |||
| if(!comm_opr){ | |||
| return false; | |||
| } | |||
| return as_tuple() == comm_opr->as_tuple(); | |||
| } | |||
| size_t CollectiveComm::hash() const{ | |||
| XXHash xxhash{}; | |||
| auto append = [&xxhash](auto field){ | |||
| auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
| xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
| }; | |||
| append(key); | |||
| append(nr_devices); | |||
| append(rank); | |||
| append(is_root); | |||
| append(local_grad); | |||
| append(addr); | |||
| append(port); | |||
| append(mode); | |||
| append(backend); | |||
| append(comp_node); | |||
| return xxhash.digest(); | |||
| } | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -9,8 +9,7 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/cond_take.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/misc.h" | |||
| #include "../dnn_op_helper.h" | |||
| #include "../op_trait.h" | |||
| @@ -19,8 +18,6 @@ using namespace megdnn; | |||
| namespace mgb::imperative { | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); | |||
| namespace { | |||
| class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { | |||
| @@ -9,7 +9,9 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/elemwise.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb { | |||
| @@ -33,7 +35,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto&& op_def = def.cast_final_safe<Elemwise>(); | |||
| auto trait = Elemwise::ModeTrait::from_mode(op_def.mode); | |||
| auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); | |||
| mgb_assert(inputs.size() == trait.arity, | |||
| "%s expects %u inputs; got %zu actually", trait.name, | |||
| trait.arity, inputs.size()); | |||
| @@ -70,8 +72,6 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise); | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -18,7 +18,7 @@ | |||
| #include "megbrain/opr/mm_handler.h" | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| #include "megbrain/imperative/ops/io_remote.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -60,45 +60,5 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) | |||
| } // anonymous namespace | |||
| #endif // MGB_ENABLE_OPR_MM | |||
| bool RemoteSend::is_same_st(const Hashable& another) const{ | |||
| return as_tuple() == another.cast_final<RemoteSend>().as_tuple(); | |||
| } | |||
| size_t RemoteSend::hash() const{ | |||
| XXHash xxhash; | |||
| auto append = [&xxhash](auto field){ | |||
| auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
| xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
| }; | |||
| append(key); | |||
| append(addr); | |||
| append(port); | |||
| append(rank_to); | |||
| return xxhash.digest(); | |||
| } | |||
| bool RemoteRecv::is_same_st(const Hashable& another) const{ | |||
| return as_tuple() == another.cast_final<RemoteRecv>().as_tuple(); | |||
| } | |||
| size_t RemoteRecv::hash() const{ | |||
| XXHash xxhash; | |||
| auto append = [&xxhash](auto field){ | |||
| auto hash_val = HashTrait<decltype(field)>::eval(field); | |||
| xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); | |||
| }; | |||
| append(key); | |||
| append(addr); | |||
| append(port); | |||
| append(rank_from); | |||
| append(cn.to_string()); | |||
| append(dtype.handle()); | |||
| append(shape.to_string()); | |||
| return xxhash.digest(); | |||
| } | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -11,7 +11,7 @@ | |||
| #include "../op_trait.h" | |||
| #include "megbrain/imperative/ops/nms.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/standalone/nms_opr.h" | |||
| namespace mgb { | |||
| @@ -37,8 +37,6 @@ OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep); | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -0,0 +1,630 @@ | |||
| /** | |||
| * \file imperative/src/impl/ops/autogen.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // FIXME: split this file into separate files for each specialized op | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/roi_align.h" | |||
| #include "megbrain/opr/dnn/roi_pooling.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/blas.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/opr/indexing.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/misc.h" | |||
| #include "megbrain/opr/nn_int.h" | |||
| #include "megbrain/opr/rand.h" | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb::imperative { | |||
| namespace { namespace convolution { | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| auto* node = &node_->cast_final_safe<opr::Convolution>(); | |||
| return Convolution::make(node->param(), node->execution_policy()); | |||
| } | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const Convolution&>(def); | |||
| return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy()); | |||
| } | |||
| OP_TRAIT_REG(Convolution, Convolution, opr::Convolution) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // convolution | |||
| namespace { namespace convolution_backward_data { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const ConvolutionBackwardData&>(def); | |||
| cg::OperatorNodeConfig config; | |||
| if (inputs.size() == 2) { | |||
| return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } else { | |||
| mgb_assert(inputs.size() == 3); | |||
| return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // convolution_backward_data | |||
| namespace { namespace dimshuffle { | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| auto* node = &node_->cast_final_safe<opr::Dimshuffle>(); | |||
| std::vector<int> pattern(node->param().pattern_len); | |||
| for (size_t i = 0; i < node->param().pattern_len; ++ i) { | |||
| pattern[i] = node->param().pattern[i]; | |||
| } | |||
| return Dimshuffle::make(pattern); | |||
| } | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& ds = static_cast<const Dimshuffle&>(def); | |||
| return opr::Dimshuffle::make(inputs[0], ds.pattern); | |||
| } | |||
| OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // dimshuffle | |||
| namespace { namespace add_axis { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& add_axis = static_cast<const AddAxis&>(def); | |||
| using Desc = opr::AxisAddRemove::AxisDesc; | |||
| std::vector<Desc> param; | |||
| for (auto&& i : add_axis.axis) { | |||
| param.push_back(Desc::make_add(i)); | |||
| } | |||
| return opr::AxisAddRemove::make(inputs[0], param); | |||
| } | |||
| OP_TRAIT_REG(AddAxis, AddAxis) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // add_axis | |||
| namespace { namespace remove_axis { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& remove_axis = static_cast<const RemoveAxis&>(def); | |||
| using Desc = opr::AxisAddRemove::AxisDesc; | |||
| std::vector<Desc> param; | |||
| for (auto&& i : remove_axis.axis) { | |||
| param.push_back(Desc::make_remove(i)); | |||
| } | |||
| return opr::AxisAddRemove::make(inputs[0], param); | |||
| } | |||
| OP_TRAIT_REG(RemoveAxis, RemoveAxis) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // remove_axis | |||
| namespace { namespace top_k { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& topk = static_cast<const TopK&>(def); | |||
| return opr::TopK::make(inputs[0], inputs[1], topk.param())[0] | |||
| .node()->owner_opr(); | |||
| } | |||
| OP_TRAIT_REG(TopK, TopK) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // top_k | |||
| namespace { namespace reduce { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& reduce = static_cast<const Reduce&>(def); | |||
| if (inputs.size() > 1) { | |||
| return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]); | |||
| } else { | |||
| return opr::Reduce::make(inputs[0], reduce.param()); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(Reduce, Reduce) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // reduce | |||
| namespace { namespace adaptive_pooling { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& pool = static_cast<const AdaptivePooling&>(def); | |||
| return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param()); | |||
| } | |||
| OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // adaptive_pooling | |||
| namespace { namespace conv_bias { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const ConvBias&>(def); | |||
| cg::OperatorNodeConfig config{conv.dtype}; | |||
| if (inputs.size() == 2) { | |||
| return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } else if (inputs.size() == 3) { | |||
| return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
| } else if (inputs.size() == 4) { | |||
| return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||
| } | |||
| mgb_assert(0); | |||
| } | |||
| OP_TRAIT_REG(ConvBias, ConvBias) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // conv_bias | |||
| namespace { namespace batch_conv_bias { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const BatchConvBias&>(def); | |||
| cg::OperatorNodeConfig config{conv.dtype}; | |||
| if (inputs.size() == 2) { | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| } else if (inputs.size() == 3) { | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
| } else if (inputs.size() == 4) { | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||
| } | |||
| mgb_assert(0); | |||
| } | |||
| OP_TRAIT_REG(BatchConvBias, BatchConvBias) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // batch_conv_bias | |||
| namespace { namespace pooling { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& pool = static_cast<const Pooling&>(def); | |||
| return opr::Pooling::make(inputs[0], pool.param()); | |||
| } | |||
| OP_TRAIT_REG(Pooling, Pooling) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // pooling | |||
| namespace { namespace matrix_mul { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& matmul = static_cast<const MatrixMul&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param()); | |||
| } | |||
| OP_TRAIT_REG(MatrixMul, MatrixMul) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // matrix_mul | |||
| namespace { namespace batched_matrix_mul { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param()); | |||
| } | |||
| OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // batched_matrix_mul | |||
| namespace { namespace dot { | |||
| auto apply_on_var_node( | |||
| const OpDef&, | |||
| const VarNodeArray& inputs) { | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Dot::make(inputs[0], inputs[1]); | |||
| } | |||
| OP_TRAIT_REG(Dot, Dot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // dot | |||
| namespace { namespace argsort { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& argsort = static_cast<const Argsort&>(def); | |||
| return opr::Argsort::make(inputs[0], argsort.param()); | |||
| } | |||
| OP_TRAIT_REG(Argsort, Argsort) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // argsort | |||
| namespace { namespace argmax { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& argmax = static_cast<const Argmax&>(def); | |||
| return opr::Argmax::make(inputs[0], argmax.param()); | |||
| } | |||
| OP_TRAIT_REG(Argmax, Argmax) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // argmax | |||
| namespace { namespace argmin { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& argmin = static_cast<const Argmin&>(def); | |||
| return opr::Argmin::make(inputs[0], argmin.param()); | |||
| } | |||
| OP_TRAIT_REG(Argmin, Argmin) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // argmin | |||
| namespace { namespace warp_perspective { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& warp = static_cast<const WarpPerspective&>(def); | |||
| if (inputs.size() == 3) { | |||
| return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param()); | |||
| } else { | |||
| mgb_assert(inputs.size() == 4); | |||
| return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param()); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(WarpPerspective, WarpPerspective) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // warp_perspective | |||
| namespace { namespace group_local { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& local = static_cast<const GroupLocal&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::GroupLocal::make(inputs[0], inputs[1], local.param()); | |||
| } | |||
| OP_TRAIT_REG(GroupLocal, GroupLocal) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // group_local | |||
| namespace { namespace indexing_one_hot { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const IndexingOneHot&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param()); | |||
| } | |||
| OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // indexing_one_hot | |||
| namespace { namespace indexing_set_one_hot { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const IndexingSetOneHot&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
| } | |||
| OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // indexing_set_one_hot | |||
| namespace { namespace typecvt { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const TypeCvt&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::TypeCvt::make(inputs[0], op.dtype); | |||
| } | |||
| OP_TRAIT_REG(TypeCvt, TypeCvt) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // typecvt | |||
| namespace { namespace concat { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Concat&>(def); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| return opr::Concat::make(inputs, op.axis, config); | |||
| } | |||
| OP_TRAIT_REG(Concat, Concat) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // concat | |||
| namespace { namespace copy { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Copy&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| return opr::Copy::make(inputs[0], config); | |||
| } | |||
| OP_TRAIT_REG(Copy, Copy) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // copy | |||
| namespace { namespace identity { | |||
| auto apply_on_var_node( | |||
| const OpDef&, | |||
| const VarNodeArray& inputs) { | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::Identity::make(inputs[0]); | |||
| } | |||
| OP_TRAIT_REG(Identity, Identity) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // identity | |||
| namespace { namespace uniform_rng { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const UniformRNG&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::UniformRNG::make(inputs[0], op.param()); | |||
| } | |||
| OP_TRAIT_REG(UniformRNG, UniformRNG) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // uniform_rng | |||
| namespace { namespace gaussian_rng { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const GaussianRNG&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::GaussianRNG::make(inputs[0], op.param()); | |||
| } | |||
| OP_TRAIT_REG(GaussianRNG, GaussianRNG) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // gaussian_rng | |||
| namespace { namespace roi_align { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ROIAlign&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::ROIAlign::make(inputs[0], inputs[1], op.param()); | |||
| } | |||
| OP_TRAIT_REG(ROIAlign, ROIAlign) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // roi_align | |||
#if MGB_CUDA
namespace { namespace nvof {
// Build an NvOf (NVIDIA optical flow) operator; only available with CUDA builds.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    auto&& flow = static_cast<const NvOf&>(def);
    return opr::NvOf::make(inputs[0], flow.param());
}
OP_TRAIT_REG(NvOf, NvOf)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // nvof
#endif
| namespace { namespace linspace { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Linspace&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Linspace, Linspace) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // linspace | |||
| namespace { namespace eye { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Eye&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| opr::Eye::Param param{op.k, op.dtype.enumv()}; | |||
| return opr::Eye::make(inputs[0], param, config); | |||
| } | |||
| OP_TRAIT_REG(Eye, Eye) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // eye | |||
| namespace { namespace roi_pooling { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ROIPooling&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
| } | |||
| OP_TRAIT_REG(ROIPooling, ROIPooling) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // roi_pooling | |||
| namespace { namespace remap { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Remap&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Remap::make(inputs[0], inputs[1], op.param()); | |||
| } | |||
| OP_TRAIT_REG(Remap, Remap) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // remap | |||
| namespace { namespace reshape { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Reshape&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| return opr::Reshape::make(inputs[0], inputs[1], op.param()); | |||
| } | |||
| OP_TRAIT_REG(Reshape, Reshape) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // reshape | |||
namespace {
// Translate a per-axis indexing mask into an opr::Subtensor::IndexDesc,
// consuming index operand vars from `inputs` starting at position `vidx`.
// Each mask entry is (axis, has_begin, has_end, has_step, has_idx): when
// has_idx is set the axis takes a single index var; otherwise each set
// begin/end/step flag consumes one var, in that order.
auto get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++ i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            // A slice entry must specify at least one of begin/end/step.
            mgb_assert(begin || end || step);
            if (begin) ret[i].begin = inputs[vidx++];
            if (end) ret[i].end = inputs[vidx++];
            if (step) ret[i].step = inputs[vidx++];
        }
    }
    // Every input var beyond the data inputs must be consumed as an index operand.
    mgb_assert(vidx == inputs.size());
    return ret;
}
// IN1/IN2 expand to the leading data input(s) of a fancy-indexing op; selected
// below via token pasting (IN##NR_INPUT).
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]
// Defines and registers apply_on_var_node for one fancy-indexing op NAME.
// NR_INPUT is the number of leading data inputs (1 for read ops, 2 for
// set/incr ops); the remaining inputs are index operands described by op.items.
// NOTE: no line comments inside the macro body — they would swallow the
// trailing backslashes.
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
namespace NAME##_impl { \
auto apply_on_var_node( \
        const OpDef& def, \
        const VarNodeArray& inputs) { \
    auto&& op = static_cast<const NAME&>(def); \
    return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
} \
OP_TRAIT_REG(NAME, NAME) \
    .apply_on_var_node(apply_on_var_node) \
    .fallback(); \
}
FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)
#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
} // anonymous namespace
| namespace { namespace fake_quant { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const FakeQuant&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
| } | |||
| OP_TRAIT_REG(FakeQuant, FakeQuant) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // fake_quant | |||
namespace { namespace elemwise_multi_type {
// Build an ElemwiseMultiType operator over all inputs; op.dtype fixes the
// output dtype via the node config.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{op.dtype};
    return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // elemwise_multi_type
| namespace { namespace svd { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const SVD&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| return opr::SVD::make(inputs[0], op.param()); | |||
| } | |||
| OP_TRAIT_REG(SVD, SVD) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // svd | |||
| } // namespace mgb::imperative | |||
| @@ -9,7 +9,7 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/tensor_manip.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "../op_trait.h" | |||
| @@ -140,8 +140,4 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) | |||
| .fallback(); | |||
| } // namespace | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(GetVarShape); | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat); | |||
| } // namespace mgb::imperative | |||
| @@ -130,7 +130,7 @@ void Profiler::start(uint32_t flags) { | |||
| // TODO: assign parent | |||
| entry.parent = 0; | |||
| // Record apply context and save to m_profile | |||
| entry.op = def.copy(); | |||
| entry.op = const_cast<OpDef&>(def).shared_from_this(); | |||
| for (auto&& input : inputs) { | |||
| entry.inputs.push_back({m_tensor_recorder.record_tensor(input), | |||
| shape2vector(input->layout()), | |||
| @@ -172,31 +172,31 @@ void Profiler::start(uint32_t flags) { | |||
| if (flags & PROFILE_FOOTPRINT) { | |||
| hook_apply_on_var_node->apply_hook( | |||
| [this](auto&& apply, const OpDef& def, | |||
| VarNodeArray inputs) -> cg::OperatorNodeBase* { | |||
| auto* operator_node = apply(def, std::move(inputs)); | |||
| VarNodeArray inputs) -> VarNodeArray { | |||
| auto vars = apply(def, std::move(inputs)); | |||
| std::remove_reference_t<decltype(m_entry_stack.top())> | |||
| top; | |||
| { | |||
| MGB_LOCK_GUARD(m_lock); | |||
| if (m_entry_stack.empty()) { | |||
| return operator_node; | |||
| return vars; | |||
| } | |||
| top = m_entry_stack.top(); | |||
| } | |||
| auto [current_op, current_entry, thread_id] = top; | |||
| if (current_op != &def || | |||
| thread_id != std::this_thread::get_id()) { | |||
| return operator_node; | |||
| return vars; | |||
| } | |||
| auto&& footprint_result = | |||
| footprint.calc_footprint(operator_node); | |||
| footprint.calc_footprint(vars[0]->owner_opr()); | |||
| current_entry->memory = footprint_result.memory; | |||
| current_entry->computation = | |||
| footprint_result.computation; | |||
| #if MGB_ENABLE_JSON | |||
| current_entry->param = footprint_result.param; | |||
| #endif | |||
| return operator_node; | |||
| return vars; | |||
| }); | |||
| } | |||
| m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor)); | |||
| @@ -590,7 +590,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr( | |||
| for (size_t i = 0; i < inputs.size(); ++ i) { | |||
| vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node(); | |||
| } | |||
| auto opr = OpDef::apply_on_var_node(opdef, vinputs); | |||
| auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr(); | |||
| mgb_assert(!opr->same_type<InputPlaceholder>()); | |||
| for (auto &&i : opr->input()) { | |||
| mgb_assert(i->owner_opr()->same_type<InputPlaceholder>()); | |||
| @@ -639,7 +639,7 @@ ProxyGraph::make_backward_graph( | |||
| return ret.first->second; | |||
| }; | |||
| auto inputs = make_input_place_holders(input_descs); | |||
| auto fwd = OpDef::apply_on_var_node(opdef, inputs); | |||
| auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr(); | |||
| auto&& outputs = fwd->usable_output(); | |||
| SmallVector<LogicalTensorDesc> output_descs; | |||
| for (auto&& i : outputs) { | |||
| @@ -799,7 +799,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| mgb_assert(!m_cur_opr); | |||
| auto vinputs = make_input_place_holders(inputs); | |||
| return OpDef::apply_on_var_node(opdef, vinputs); | |||
| return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr(); | |||
| } | |||
| VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTensorDesc>& inputs) { | |||
| @@ -26,13 +26,12 @@ struct BackwardGraphResult { | |||
| std::vector<bool> input_has_grad; | |||
| }; | |||
| class OpDef : public Hashable { | |||
| class OpDef : public Hashable, | |||
| public std::enable_shared_from_this<OpDef> { | |||
| mutable const OpTrait* m_trait = nullptr; | |||
| public: | |||
| virtual ~OpDef() = default; | |||
| virtual std::shared_ptr<OpDef> copy() const = 0; | |||
| static std::shared_ptr<OpDef> make_from_op_node( | |||
| cg::OperatorNodeBase* node); | |||
| @@ -40,7 +39,7 @@ public: | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs); | |||
| static cg::OperatorNodeBase* apply_on_var_node( | |||
| static cg::VarNodeArray apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs); | |||
| @@ -56,25 +55,17 @@ public: | |||
| const OpTrait* trait() const; | |||
| virtual size_t hash() const { | |||
| mgb_throw(MegBrainError, "not implemented"); | |||
| } | |||
| virtual size_t hash() const; | |||
| virtual bool is_same_st(const Hashable&) const { | |||
| mgb_throw(MegBrainError, "not implemented"); | |||
| } | |||
| virtual bool is_same_st(const Hashable&) const; | |||
| }; | |||
| template<typename T> | |||
| class OpDefImplBase : public OpDef { | |||
| public: | |||
| virtual std::shared_ptr<OpDef> copy() const override { | |||
| return std::shared_ptr<OpDef>(new T(this->cast_final_safe<T>())); | |||
| } | |||
| template<typename ...Args> | |||
| static std::shared_ptr<OpDef> make(const Args& ...args) { | |||
| return std::shared_ptr<OpDef>(new T(args...)); | |||
| static std::shared_ptr<OpDef> make(Args&& ...args) { | |||
| return std::make_shared<T>(std::forward<Args>(args)...); | |||
| } | |||
| }; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/cond_take.h | |||
| * \file imperative/src/include/megbrain/imperative/ops/autogen.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| @@ -12,22 +12,15 @@ | |||
| #pragma once | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| namespace mgb::imperative { | |||
| class CondTake : public OpDefImplBase<CondTake> { | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| public: | |||
| CondTake() = default; | |||
| #include "megbrain/utils/hash.h" | |||
| size_t hash() const override { | |||
| return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||
| } | |||
| bool is_same_st(const Hashable& rhs) const override { | |||
| return rhs.dyn_typeinfo() == dyn_typeinfo(); | |||
| } | |||
| namespace mgb::imperative { | |||
| }; | |||
| // TODO: split into separate files to avoid recompiling all | |||
| // impl/ops/*.cpp on each modification of ops.td | |||
| #include "./opdef.h.inl" | |||
| } // namespace mgb::imperative | |||
| @@ -1,70 +0,0 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/batch_norm.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megbrain/utils/hash.h" | |||
| namespace mgb::imperative { | |||
// Imperative op definition for BatchNorm; mirrors opr::BatchNorm's param
// fields so the op can be hashed and compared without building a graph node.
class BatchNorm : public OpDefImplBase<BatchNorm> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    using Param = opr::BatchNorm::Param;
    Param::ParamDim param_dim;
    Param::FwdMode fwd_mode;
    double epsilon;
    double avg_factor;
    float scale;
    float bias;
    BatchNorm() = default;
    BatchNorm(const Param::ParamDim& param_dim_, const Param::FwdMode& fwd_mode_,
              double epsilon_, double avg_factor_, float scale_, float bias_)
        : param_dim(param_dim_),
          fwd_mode(fwd_mode_),
          epsilon(epsilon_),
          avg_factor(avg_factor_),
          scale(scale_),
          bias(bias_) {}
    // Fold every field into one XXHash digest.
    size_t hash() const override {
        XXHash xxhash{};
        auto append = [&xxhash](auto field){
            auto hash_val = HashTrait<decltype(field)>::eval(field);
            xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
        };
        append(param_dim);
        append(fwd_mode);
        append(epsilon);
        append(avg_factor);
        append(scale);
        append(bias);
        return xxhash.digest();
    }
    // Field-wise equality; caller guarantees rhs_ is a BatchNorm.
    bool is_same_st(const Hashable& rhs_) const override {
        auto&& rhs = static_cast<const BatchNorm&>(rhs_);
        return rhs.param_dim == param_dim
            && rhs.fwd_mode == fwd_mode
            && rhs.epsilon == epsilon
            && rhs.avg_factor == avg_factor
            && rhs.scale == scale
            && rhs.bias == bias;
    }
};
| } // namespace mgb::imperative | |||
| @@ -1,35 +0,0 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/broadcast.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/op_def.h" | |||
| namespace mgb::imperative { | |||
// Imperative op definition for Broadcast. The op carries no state, so its
// hash is just the dynamic type id and every Broadcast compares equal.
class Broadcast : public OpDefImplBase<Broadcast> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    Broadcast() = default;
    size_t hash() const override {
        return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
    }
    // Stateless op: any two Broadcast instances are identical.
    bool is_same_st(const Hashable& rhs) const override {
        return true;
    }
};
| } // namespace mgb::imperative | |||
| @@ -1,69 +0,0 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/collective_comm.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megbrain/opr/param_defs.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
// Imperative op definition for a collective communication operation
// (all-reduce etc., selected by `mode`). hash()/is_same_st() are defined
// out of line, presumably over as_tuple() — verify in the .cpp.
class CollectiveComm : public OpDefImplBase<CollectiveComm> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    using Mode = megdnn::param::CollectiveComm::Mode;
    CollectiveComm() = default;
    CollectiveComm(const std::string& key_, size_t nr_devices_,
                   uint32_t rank_, bool is_root_, bool local_grad_,
                   const std::string& addr_, uint32_t port_,
                   const Mode& mode_,
                   const DType& dtype_, const std::string& backend_,
                   const std::string& comp_node_)
        : key(key_),
          nr_devices(nr_devices_),
          rank(rank_),
          is_root(is_root_),
          local_grad(local_grad_),
          addr(addr_),
          port(port_),
          mode(mode_),
          dtype(dtype_),
          backend(backend_),
          comp_node(comp_node_) {}
    std::string key;         // group key shared by all participants
    size_t nr_devices;       // number of participating devices
    uint32_t rank;           // this participant's rank
    bool is_root;
    bool local_grad;
    std::string addr;        // rendezvous server address
    uint32_t port;           // rendezvous server port
    Mode mode;
    DType dtype;
    std::string backend;
    std::string comp_node;
    size_t hash() const override;
    bool is_same_st(const Hashable& another) const override;
    // All identity-relevant fields as one tuple, for hashing/comparison.
    auto as_tuple() const{
        return std::tuple(key, nr_devices, rank, is_root,
                          local_grad, addr, port, mode, dtype,
                          backend, comp_node);
    }
};
| } // namespace imperative | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -1,42 +0,0 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/elemwise.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/imperative/op_def.h" | |||
| namespace mgb::imperative { | |||
// Imperative op definition for Elemwise; identity is fully determined by `mode`.
class Elemwise : public OpDefImplBase<Elemwise> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    using Mode = opr::Elemwise::Mode;
    using ModeTrait = megdnn::Elemwise::ModeTrait;
    Mode mode;
    Elemwise() = default;
    Elemwise(const Mode& mode_): mode(mode_) {}
    // Combine the mode hash with the dynamic type id.
    size_t hash() const override {
        return hash_pair_combine(mgb::hash(mode), reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
    }
    // Equal iff modes match; caller guarantees rhs_ is an Elemwise.
    bool is_same_st(const Hashable& rhs_) const override {
        auto&& rhs = static_cast<const Elemwise&>(rhs_);
        return rhs.mode == mode;
    }
};
| } // namespace mgb::imperative | |||
| @@ -1,77 +0,0 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/io_remote.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/op_def.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
// Imperative op definition for sending a tensor to a remote rank.
// hash()/is_same_st() are defined out of line, presumably over as_tuple().
class RemoteSend : public OpDefImplBase<RemoteSend> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    RemoteSend() = default;
    RemoteSend(const std::string& key_, const std::string& addr_,
               uint32_t port_, uint32_t rank_to_)
        : key(key_),
          addr(addr_),
          port(port_),
          rank_to(rank_to_) {}
    std::string key;    // channel key matching the peer RemoteRecv
    std::string addr;   // rendezvous server address
    uint32_t port;      // rendezvous server port
    uint32_t rank_to;   // destination rank
    size_t hash() const override;
    bool is_same_st(const Hashable& another) const override;
    // All identity-relevant fields as one tuple, for hashing/comparison.
    auto as_tuple() const{
        return std::tuple(key, addr, port, rank_to);
    }
};
// Imperative op definition for receiving a tensor from a remote rank; the
// receiver must know the incoming shape, dtype and target comp node up front.
class RemoteRecv : public OpDefImplBase<RemoteRecv> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    RemoteRecv() = default;
    RemoteRecv(const std::string& key_, const std::string& addr_,
               uint32_t port_, uint32_t rank_from_, TensorShape shape_,
               CompNode cn_, const DType& dtype_)
        : key(key_),
          addr(addr_),
          port(port_),
          rank_from(rank_from_),
          cn(cn_),
          shape(shape_),
          dtype(dtype_) {}
    std::string key;     // channel key matching the peer RemoteSend
    std::string addr;    // rendezvous server address
    uint32_t port;       // rendezvous server port
    uint32_t rank_from;  // source rank
    CompNode cn;         // comp node the received tensor lands on
    TensorShape shape;   // expected shape of the incoming tensor
    DType dtype;         // expected dtype of the incoming tensor
    size_t hash() const override;
    bool is_same_st(const Hashable& another) const override;
    // Identity tuple; shape is stringified since TensorShape itself is not
    // directly tuple-comparable here.
    auto as_tuple() const{
        return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string());
    }
};
| } // namespace imperative | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -1,41 +0,0 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/nms.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/op_def.h" | |||
| namespace mgb::imperative { | |||
| class NMSKeep : public OpDefImplBase<NMSKeep> { | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| public: | |||
| float iou_thresh; //!< IoU threshold for overlapping | |||
| uint32_t max_output; //!< max number of output boxes per batch | |||
| NMSKeep() = default; | |||
| NMSKeep(float iou_thresh_, uint32_t max_output_): | |||
| iou_thresh(iou_thresh_), max_output(max_output_) {} | |||
| size_t hash() const override { | |||
| return hash_pair_combine( | |||
| hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)), | |||
| reinterpret_cast<std::uintptr_t>(dyn_typeinfo())); | |||
| } | |||
| bool is_same_st(const Hashable& rhs_) const override { | |||
| auto&& rhs = static_cast<const NMSKeep&>(rhs_); | |||
| return rhs.iou_thresh == iou_thresh | |||
| && rhs.max_output == max_output; | |||
| } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -1,99 +0,0 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/ops/tensor_manip.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megbrain/utils/hash.h" | |||
| namespace mgb::imperative { | |||
| class GetVarShape : public OpDefImplBase<GetVarShape> { | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| public: | |||
| GetVarShape() = default; | |||
| size_t hash() const override { | |||
| return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||
| } | |||
| bool is_same_st(const Hashable& rhs) const override { | |||
| return rhs.dyn_typeinfo() == dyn_typeinfo(); | |||
| } | |||
| }; | |||
| class ParamPackSplit : public OpDefImplBase<ParamPackSplit> { | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| public: | |||
| ParamPackSplit() = default; | |||
| ParamPackSplit(std::vector<dt_int32>& offsets_, | |||
| std::vector<std::vector<size_t>>& shapes_) | |||
| : offsets(offsets_), shapes(shapes_) {} | |||
| std::vector<dt_int32> offsets; | |||
| std::vector<std::vector<size_t>> shapes; | |||
| size_t hash() const override { | |||
| XXHash builder; | |||
| for (auto&& offset : offsets) { | |||
| builder.update(&offset, sizeof(offset)); | |||
| } | |||
| auto&& offset_cnt = offsets.size(); | |||
| builder.update(&offset_cnt, sizeof(offset_cnt)); | |||
| for (auto&& shape : shapes) { | |||
| for (auto&& dim_len : shape) { | |||
| builder.update(&dim_len, sizeof(dim_len)); | |||
| } | |||
| auto&& dim_cnt = shape.size(); | |||
| builder.update(&dim_cnt, sizeof(dim_cnt)); | |||
| } | |||
| auto&& shape_cnt = shapes.size(); | |||
| builder.update(&shape_cnt, sizeof(shape_cnt)); | |||
| return builder.digest(); | |||
| } | |||
| bool is_same_st(const Hashable& rhs) const override { | |||
| auto&& pps = rhs.cast_final_safe<ParamPackSplit>(); | |||
| return offsets == pps.offsets && shapes == pps.shapes; | |||
| } | |||
| }; | |||
| class ParamPackConcat : public OpDefImplBase<ParamPackConcat> { | |||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
| public: | |||
| ParamPackConcat() = default; | |||
| ParamPackConcat(std::vector<dt_int32>& offsets_) | |||
| : offsets(offsets_) {} | |||
| std::vector<dt_int32> offsets; | |||
| size_t hash() const override { | |||
| XXHash builder; | |||
| for (auto&& offset : offsets) { | |||
| builder.update(&offset, sizeof(offset)); | |||
| } | |||
| auto&& offset_cnt = offsets.size(); | |||
| builder.update(&offset_cnt, sizeof(offset_cnt)); | |||
| return builder.digest(); | |||
| } | |||
| bool is_same_st(const Hashable& rhs) const override { | |||
| auto&& ppc = rhs.cast_final_safe<ParamPackConcat>(); | |||
| return offsets == ppc.offsets; | |||
| } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -29,18 +29,18 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| using Param = opr::Elemwise::Param; | |||
| Param param{Param::Mode::MUL}; | |||
| OprAttr attr{"Elemwise", {}, {}}; | |||
| attr.param.write_pod(param); | |||
| auto attr = OprAttr::make("Elemwise"); | |||
| attr->cast_final_safe<OprAttr>().param.write_pod(param); | |||
| SmallVector<LogicalTensorDesc> input_descs; | |||
| for (auto&& i : inputs) { | |||
| input_descs.push_back({i->layout(), i->comp_node()}); | |||
| } | |||
| auto result = OpDef::make_backward_graph(attr, input_descs, {true, true}, {true}); | |||
| auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true}); | |||
| auto&& save_for_backward = result.save_for_backward; | |||
| auto&& input_has_grad = result.input_has_grad; | |||
| auto outputs = OpDef::apply_on_physical_tensor(attr, inputs); | |||
| auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | |||
| inputs.push_back(outputs[0]); | |||
| hvs.push_back(*gen({42})); | |||
| inputs.push_back(Tensor::make(hvs.back())); | |||
| @@ -82,16 +82,16 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| SmallVector<TensorPtr> inputs; | |||
| inputs.push_back(a); | |||
| OprAttr attr{"Identity", {}, {}}; | |||
| attr.param.write_pod<megdnn::param::Empty>({}); | |||
| auto attr = OprAttr::make("Identity"); | |||
| attr->cast_final_safe<OprAttr>().param.write_pod<megdnn::param::Empty>({}); | |||
| SmallVector<LogicalTensorDesc> input_descs; | |||
| input_descs.push_back({a->layout(), a->comp_node()}); | |||
| auto result = OpDef::make_backward_graph(attr, input_descs, {true}, {true}); | |||
| auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | |||
| auto&& save_for_backward = result.save_for_backward; | |||
| auto&& input_has_grad = result.input_has_grad; | |||
| auto outputs = OpDef::apply_on_physical_tensor(attr, inputs); | |||
| auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | |||
| inputs.push_back(outputs[0]); | |||
| inputs.push_back(dc); | |||
| mgb_assert(save_for_backward.size() == inputs.size()); | |||
| @@ -10,7 +10,7 @@ | |||
| */ | |||
| #include "./helper.h" | |||
| #include "megbrain/imperative/ops/collective_comm.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/mm_handler.h" | |||
| using namespace mgb; | |||
| @@ -32,12 +32,13 @@ TEST(TestImperative, AllReduceBasic) { | |||
| } | |||
| auto run = [&](std::shared_ptr<HostTensorND> hnd, uint32_t idx) { | |||
| imperative::CollectiveComm | |||
| def{"all_reduce", 2, idx, idx==0, false, server_addr, port, | |||
| auto def = | |||
| imperative::CollectiveComm::make( | |||
| megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM, | |||
| dtype::Float32(), "nccl", ""}; | |||
| "all_reduce", 2, idx, idx==0, false, server_addr, port, | |||
| dtype::Float32(), "nccl", ""); | |||
| auto inp = Tensor::make(*hnd); | |||
| auto oup = OpDef::apply_on_physical_tensor(def, {inp}); | |||
| auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | |||
| HostTensorND host_v; | |||
| host_v.copy_from(oup[0]->dev_tensor()).sync(); | |||
| MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6); | |||
| @@ -10,7 +10,7 @@ | |||
| */ | |||
| #include "./helper.h" | |||
| #include "megbrain/imperative/ops/cond_take.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| using namespace mgb; | |||
| using namespace imperative; | |||
| @@ -119,7 +119,7 @@ void OprChecker::run(std::vector<InputSpec> inp_keys) { | |||
| }, inp_keys[i]); | |||
| sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node(); | |||
| } | |||
| auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp)->usable_output(); | |||
| auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp); | |||
| size_t nr_oups = sym_oup.size(); | |||
| ComputingGraph::OutputSpec oup_spec(nr_oups); | |||
| SmallVector<HostTensorND> host_sym_oup(nr_oups); | |||
| @@ -10,7 +10,7 @@ | |||
| */ | |||
| #include "./helper.h" | |||
| #include "megbrain/imperative/ops/io_remote.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/mm_handler.h" | |||
| using namespace mgb; | |||
| @@ -33,24 +33,19 @@ TEST(TestImperative, IORemote) { | |||
| } | |||
| auto run_send = [&](std::shared_ptr<HostTensorND> hnd) { | |||
| imperative::RemoteSend def{"io_remote_test", server_addr, port, 1}; | |||
| auto def = imperative::RemoteSend::make( | |||
| "io_remote_test", server_addr, port, 1); | |||
| auto inp = Tensor::make(*hnd); | |||
| auto oup = OpDef::apply_on_physical_tensor(def, {inp}); | |||
| auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | |||
| }; | |||
| auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) { | |||
| // auto&& shape = std::initializer_list{vector_size}; | |||
| imperative::RemoteRecv def{"io_remote_test", | |||
| server_addr, | |||
| port, | |||
| 0, | |||
| { | |||
| vector_size, | |||
| }, | |||
| CompNode::load("gpu1"), | |||
| dtype::Float32()}; | |||
| auto def = imperative::RemoteRecv::make( | |||
| "io_remote_test", server_addr, port, 0, | |||
| CompNode::load("gpu1"), TensorShape{vector_size}, | |||
| dtype::Float32()); | |||
| auto inp = Tensor::make(*hnd); | |||
| auto oup = OpDef::apply_on_physical_tensor(def, {inp}); | |||
| auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | |||
| HostTensorND host_v; | |||
| host_v.copy_from(oup[0]->dev_tensor()).sync(); | |||
| MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6); | |||
| @@ -0,0 +1,14 @@ | |||
| # mgb tablegen executable | |||
| set(TABLE_TARGET mgb-mlir-autogen) | |||
| add_executable(${TABLE_TARGET} autogen.cpp) | |||
| target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR}) | |||
| target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport) | |||
| set(MGB_TABLEGEN_EXE ${TABLE_TARGET}) | |||
| # generate megbrain opdef c header and python bindings | |||
| set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td) | |||
| tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") | |||
| tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") | |||
| tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") | |||
| add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen) | |||
| set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) | |||
| @@ -0,0 +1,383 @@ | |||
| #include <iostream> | |||
| #include <unordered_map> | |||
| #include <functional> | |||
| #include "./helper.h" | |||
| using llvm::raw_ostream; | |||
| using llvm::RecordKeeper; | |||
| enum ActionType { | |||
| None, | |||
| CppHeader, | |||
| CppBody, | |||
| Pybind | |||
| }; | |||
| // NOLINTNEXTLINE | |||
| llvm::cl::opt<ActionType> action( | |||
| llvm::cl::desc("Action to perform:"), | |||
| llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header", | |||
| "Generate operator cpp header"), | |||
| clEnumValN(CppBody, "gen-cpp-body", | |||
| "Generate operator cpp body"), | |||
| clEnumValN(Pybind, "gen-python-binding", | |||
| "Generate pybind11 python bindings"))); | |||
| using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; | |||
| using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; | |||
| using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; | |||
| using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; | |||
| using MgbOp = mlir::tblgen::MgbOpBase; | |||
| using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; | |||
| llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { | |||
| // Note: we have already registered the corresponding attr wrappers | |||
| // for following basic ctypes so we needn't handle them here | |||
| /* auto&& attr_type_name = attr.getAttrDefName(); | |||
| if (attr_type_name == "UI32Attr") { | |||
| return "uint32_t"; | |||
| } | |||
| if (attr_type_name == "UI64Attr") { | |||
| return "uint64_t"; | |||
| } | |||
| if (attr_type_name == "I32Attr") { | |||
| return "int32_t"; | |||
| } | |||
| if (attr_type_name == "F32Attr") { | |||
| return "float"; | |||
| } | |||
| if (attr_type_name == "F64Attr") { | |||
| return "double"; | |||
| } | |||
| if (attr_type_name == "StrAttr") { | |||
| return "std::string"; | |||
| } | |||
| if (attr_type_name == "BoolAttr") { | |||
| return "bool"; | |||
| }*/ | |||
| auto&& attr = llvm::cast<MgbAttrWrapper>(attr_); | |||
| if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) { | |||
| return e->getEnumName(); | |||
| } | |||
| return attr.getUnderlyingType(); | |||
| } | |||
| static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
| os << formatv( | |||
| "class {0} : public OpDefImplBase<{0}> {{\n" | |||
| " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" | |||
| "public:\n", | |||
| op.getCppClassName() | |||
| ); | |||
| // handle enum alias | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| os << formatv( | |||
| " using {0} = {1};\n", | |||
| attr->getEnumName(), attr->getUnderlyingType() | |||
| ); | |||
| } | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| auto defaultValue = i.attr.getDefaultValue().str(); | |||
| if (!defaultValue.empty()) { | |||
| defaultValue = formatv(" = {0}", defaultValue); | |||
| } | |||
| os << formatv( | |||
| " {0} {1}{2};\n", | |||
| attr_to_ctype(i.attr), i.name, defaultValue | |||
| ); | |||
| } | |||
| auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { | |||
| os << formatv( | |||
| " {0}({1}){2}{3}\n", | |||
| op.getCppClassName(), paramList, memInitList, body | |||
| ); | |||
| }; | |||
| gen_ctor("", "", " = default;"); | |||
| if (!op.getMgbAttributes().empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| ": " + llvm::join(initList, ", "), | |||
| " {}"); | |||
| } | |||
| auto packedParams = op.getPackedParams(); | |||
| if (!packedParams.empty()) { | |||
| std::vector<std::string> paramList, initList; | |||
| for (auto &&p : packedParams) { | |||
| auto&& paramFields = p.getFields(); | |||
| auto&& paramType = p.getFullName(); | |||
| auto&& paramName = formatv("packed_param_{0}", paramList.size()); | |||
| paramList.push_back( | |||
| paramFields.empty() ? paramType.str() | |||
| : formatv("{0} {1}", paramType, paramName) | |||
| ); | |||
| for (auto&& i : paramFields) { | |||
| initList.push_back(formatv( | |||
| "{0}({1}.{0})", i.name, paramName | |||
| )); | |||
| } | |||
| } | |||
| for (auto&& i : op.getExtraArguments()) { | |||
| paramList.push_back(formatv( | |||
| "{0} {1}_", attr_to_ctype(i.attr), i.name | |||
| )); | |||
| initList.push_back(formatv( | |||
| "{0}({0}_)", i.name | |||
| )); | |||
| } | |||
| gen_ctor(llvm::join(paramList, ", "), | |||
| initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
| " {}"); | |||
| } | |||
| if (!packedParams.empty()) { | |||
| for (auto&& p : packedParams) { | |||
| auto accessor = p.getAccessor(); | |||
| if (!accessor.empty()) { | |||
| os << formatv( | |||
| " {0} {1}() const {{\n", | |||
| p.getFullName(), accessor | |||
| ); | |||
| std::vector<llvm::StringRef> fields; | |||
| for (auto&& i : p.getFields()) { | |||
| fields.push_back(i.name); | |||
| } | |||
| os << formatv( | |||
| " return {{{0}};\n", | |||
| llvm::join(fields, ", ") | |||
| ); | |||
| os << " }\n"; | |||
| } | |||
| } | |||
| } | |||
| if (auto decl = op.getExtraOpdefDecl()) { | |||
| os << decl.getValue(); | |||
| } | |||
| os << formatv( | |||
| "};\n\n" | |||
| ); | |||
| } | |||
| static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| auto&& className = op.getCppClassName(); | |||
| os << formatv( | |||
| "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className | |||
| ); | |||
| auto formatMethImpl = [&](auto&& meth) { | |||
| return formatv( | |||
| "{0}_{1}_impl", className, meth | |||
| ); | |||
| }; | |||
| std::vector<std::string> methods; | |||
| if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) { | |||
| os << "namespace {\n"; | |||
| // generate hash() | |||
| mlir::tblgen::FmtContext ctx; | |||
| os << formatv( | |||
| "size_t {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("hash") | |||
| ); | |||
| os << formatv( | |||
| " auto op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| // generate is_same_st() | |||
| os << formatv( | |||
| "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", | |||
| formatMethImpl("is_same_st") | |||
| ); | |||
| os << formatv( | |||
| " auto a_ = lhs_.cast_final_safe<{0}>(),\n" | |||
| " b_ = rhs_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(a_);\n" | |||
| " static_cast<void>(b_);\n", | |||
| className | |||
| ); | |||
| os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
| os << "}\n"; | |||
| os << "} // anonymous namespace\n"; | |||
| methods.push_back("hash"); | |||
| methods.push_back("is_same_st"); | |||
| } | |||
| if (!methods.empty()) { | |||
| os << formatv( | |||
| "OP_TRAIT_REG({0}, {0})", op.getCppClassName() | |||
| ); | |||
| for (auto&& i : methods) { | |||
| os << formatv( | |||
| "\n .{0}({1})", i, formatMethImpl(i) | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| } | |||
| struct PybindContext { | |||
| std::unordered_map<unsigned int, std::string> enumAlias; | |||
| }; | |||
| static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) { | |||
| auto class_name = op.getCppClassName(); | |||
| os << formatv( | |||
| "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", | |||
| class_name | |||
| ); | |||
| for (auto&& i : op.getMgbAttributes()) { | |||
| if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) { | |||
| unsigned int enumID; | |||
| if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { | |||
| auto&& aliasBase = alias->getAliasBase(); | |||
| enumID = | |||
| llvm::cast<MgbEnumAttr>(aliasBase) | |||
| .getBaseRecord()->getID(); | |||
| } else { | |||
| enumID = attr->getBaseRecord()->getID(); | |||
| } | |||
| auto&& enumAlias = ctx.enumAlias; | |||
| auto&& iter = enumAlias.find(enumID); | |||
| if (iter == enumAlias.end()) { | |||
| os << formatv( | |||
| "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", | |||
| class_name, attr->getEnumName() | |||
| ); | |||
| std::vector<std::string> body; | |||
| for (auto&& i: attr->getEnumMembers()) { | |||
| os << formatv( | |||
| "\n .value(\"{2}\", {0}::{1}::{2})", | |||
| class_name, attr->getEnumName(), i | |||
| ); | |||
| body.push_back(formatv( | |||
| "if (str == \"{2}\") return {0}::{1}::{2};", | |||
| class_name, attr->getEnumName(), i | |||
| )); | |||
| } | |||
| os << formatv( | |||
| "\n .def(py::init([](const std::string& in) {" | |||
| "\n auto&& str = normalize_enum(in);" | |||
| "\n {0}" | |||
| "\n throw py::cast_error(\"invalid enum value \" + in);" | |||
| "\n }));\n", | |||
| llvm::join(body, "\n ") | |||
| ); | |||
| os << formatv( | |||
| "py::implicitly_convertible<std::string, {0}::{1}>();\n\n", | |||
| class_name, attr->getEnumName() | |||
| ); | |||
| enumAlias.emplace(enumID, formatv( | |||
| "{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName() | |||
| )); | |||
| } else { | |||
| os << formatv( | |||
| "{0}Inst.attr(\"{1}\") = {2};\n\n", | |||
| class_name, attr->getEnumName(), iter->second | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| // generate op class binding | |||
| os << formatv("{0}Inst", class_name); | |||
| bool hasDefaultCtor = op.getMgbAttributes().empty(); | |||
| if (!hasDefaultCtor) { | |||
| os << "\n .def(py::init<"; | |||
| std::vector<llvm::StringRef> targs; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| targs.push_back(i.attr.getReturnType()); | |||
| } | |||
| os << llvm::join(targs, ", "); | |||
| os << ">()"; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv(", py::arg(\"{0}\")", i.name); | |||
| auto defaultValue = i.attr.getDefaultValue(); | |||
| if (!defaultValue.empty()) { | |||
| os << formatv(" = {0}", defaultValue); | |||
| } else { | |||
| hasDefaultCtor = true; | |||
| } | |||
| } | |||
| os << ")"; | |||
| } | |||
| if (hasDefaultCtor) { | |||
| os << "\n .def(py::init<>())"; | |||
| } | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| os << formatv( | |||
| "\n .def_readwrite(\"{0}\", &{1}::{0})", | |||
| i.name, class_name | |||
| ); | |||
| } | |||
| os << ";\n\n"; | |||
| } | |||
| static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, | |||
| std::function<void(raw_ostream&, MgbOp&)> callback) { | |||
| auto op_base_class = keeper.getClass("Op"); | |||
| ASSERT(op_base_class, "could not find base class Op"); | |||
| for (auto&& i: keeper.getDefs()) { | |||
| auto&& r = i.second; | |||
| if (r->isSubClassOf(op_base_class)) { | |||
| auto op = mlir::tblgen::Operator(r.get()); | |||
| if (op.getDialectName().str() == "mgb") { | |||
| std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; | |||
| callback(os, llvm::cast<MgbOp>(op)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { | |||
| for_each_operator(os, keeper, gen_op_def_c_header_single); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { | |||
| for_each_operator(os, keeper, gen_op_def_c_body_single); | |||
| return false; | |||
| } | |||
| static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { | |||
| PybindContext ctx; | |||
| using namespace std::placeholders; | |||
| for_each_operator(os, keeper, | |||
| std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); | |||
| return false; | |||
| } | |||
| int main(int argc, char **argv) { | |||
| llvm::InitLLVM y(argc, argv); | |||
| llvm::cl::ParseCommandLineOptions(argc, argv); | |||
| if (action == ActionType::CppHeader) { | |||
| return TableGenMain(argv[0], &gen_op_def_c_header); | |||
| } | |||
| if (action == ActionType::CppBody) { | |||
| return TableGenMain(argv[0], &gen_op_def_c_body); | |||
| } | |||
| if (action == ActionType::Pybind) { | |||
| return TableGenMain(argv[0], &gen_op_def_pybind11); | |||
| } | |||
| return -1; | |||
| } | |||
| @@ -0,0 +1,228 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "llvm/Support/CommandLine.h" | |||
| #include "llvm/Support/FormatVariadic.h" | |||
| #include "llvm/Support/InitLLVM.h" | |||
| #include "llvm/Support/Signals.h" | |||
| #include "llvm/TableGen/Main.h" | |||
| #include "llvm/TableGen/Record.h" | |||
| #include "llvm/TableGen/TableGenBackend.h" | |||
| #include "mlir/TableGen/Attribute.h" | |||
| #include "mlir/TableGen/Format.h" | |||
| #include "mlir/TableGen/Operator.h" | |||
| using llvm::formatv; | |||
| using llvm::StringRef; | |||
| using llvm::Record; | |||
| #define ASSERT(stmt, msg) \ | |||
| if (!(stmt)) { \ | |||
| std::cerr << "\033[1;31m" \ | |||
| << "tablegen autogen abort due to: " << msg \ | |||
| << "\033[0m" << std::endl; \ | |||
| exit(1); \ | |||
| } | |||
| namespace mlir { | |||
| namespace tblgen { | |||
| template<typename ConcreteType> | |||
| struct MgbInterface : public ConcreteType { | |||
| MgbInterface() = delete; | |||
| MgbInterface(const MgbInterface&) = delete; | |||
| MgbInterface(MgbInterface&&) = delete; | |||
| ~MgbInterface() = delete; | |||
| }; | |||
| struct MgbAttrWrapperBase : public MgbInterface<Attribute> { | |||
| private: | |||
| struct RecordVisitor : public MgbInterface<Constraint> { | |||
| public: | |||
| static bool classof(const Constraint*) { | |||
| return true; | |||
| } | |||
| const llvm::Record* getDef() const { | |||
| return def; | |||
| } | |||
| }; | |||
| public: | |||
| static bool classof(const Attribute* attr) { | |||
| return attr->isSubClassOf("MgbAttrWrapperBase"); | |||
| } | |||
| const llvm::Record* getBaseRecord() const { | |||
| auto baseAttr = getBaseAttr(); | |||
| return llvm::cast<RecordVisitor>(baseAttr).getDef(); | |||
| } | |||
| llvm::StringRef getUnderlyingType() const { | |||
| return def->getValueAsString("underlyingType"); | |||
| } | |||
| }; | |||
| struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||
| static bool classof(const Attribute* attr) { | |||
| return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin"); | |||
| } | |||
| llvm::StringRef getParentNamespace() const { | |||
| return getBaseRecord()->getValueAsString("parentNamespace"); | |||
| } | |||
| llvm::StringRef getEnumName() const { | |||
| return getBaseRecord()->getValueAsString("enumName"); | |||
| } | |||
| std::vector<StringRef> getEnumMembers() const { | |||
| return getBaseRecord()->getValueAsListOfStrings("enumMembers"); | |||
| } | |||
| }; | |||
| struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | |||
| static bool classof(const Attribute* attr) { | |||
| return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin"); | |||
| } | |||
| llvm::StringRef getHashFunctionTemplate() const { | |||
| return getBaseRecord()->getValueAsString("hashFunction"); | |||
| } | |||
| llvm::StringRef getCmpFunctionTemplate() const { | |||
| return getBaseRecord()->getValueAsString("cmpFunction"); | |||
| } | |||
| }; | |||
| struct MgbAliasAttrMixin : public MgbAttrWrapperBase { | |||
| static bool classof(const Attribute* attr) { | |||
| return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin"); | |||
| } | |||
| Attribute getAliasBase() const { | |||
| return Attribute(getBaseRecord()->getValueAsDef("aliasBase")); | |||
| } | |||
| }; | |||
| class MgbPackedParam { | |||
| public: | |||
| MgbPackedParam(Record* def_): def(def_) { | |||
| auto&& dag = def->getValueAsDag("fields"); | |||
| for (size_t i = 0; i < dag->getNumArgs(); ++ i) { | |||
| fields.push_back({ | |||
| dag->getArgNameStr(i), | |||
| Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i))) | |||
| }); | |||
| } | |||
| } | |||
| llvm::StringRef getFullName() const { | |||
| return def->getValueAsString("fullName"); | |||
| } | |||
| std::vector<NamedAttribute> getFields() const { | |||
| return fields; | |||
| } | |||
| llvm::StringRef getAccessor() const { | |||
| return def->getValueAsString("paramAccessor"); | |||
| } | |||
| private: | |||
| std::vector<NamedAttribute> fields; | |||
| Record* def; | |||
| }; | |||
| struct MgbOpBase : public MgbInterface<Operator> { | |||
| static bool isPackedParam(Record* def) { | |||
| return def->isSubClassOf("MgbPackedParamBase"); | |||
| } | |||
| public: | |||
| static bool classof(const Operator* op) { | |||
| return op->getDef().isSubClassOf("MgbOp"); | |||
| } | |||
| std::vector<NamedAttribute> getMgbAttributes() const { | |||
| std::vector<NamedAttribute> ret; | |||
| for (auto&& i: getAttributes()) { | |||
| if (isa<MgbAttrWrapperBase>(i.attr)) { | |||
| ret.push_back(i); | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| std::vector<NamedAttribute> getExtraArguments() const { | |||
| std::vector<NamedAttribute> ret; | |||
| auto&& dag = getDef().getValueAsDag("extraArguments"); | |||
| for (size_t i = 0; i < dag->getNumArgs(); ++ i) { | |||
| ret.push_back({ | |||
| dag->getArgNameStr(i), | |||
| Attribute(llvm::cast<llvm::DefInit>(dag->getArg(i))) | |||
| }); | |||
| } | |||
| return ret; | |||
| } | |||
| llvm::Optional<StringRef> getExtraOpdefDecl() const { | |||
| return getDef().getValueAsOptionalString("extraOpdefDecl"); | |||
| } | |||
| std::vector<MgbPackedParam> getPackedParams() const { | |||
| std::vector<MgbPackedParam> ret; | |||
| for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) { | |||
| if (isPackedParam(i)) { | |||
| ret.emplace_back(i); | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| }; | |||
| struct MgbHashableOpMixin : public MgbOpBase { | |||
| private: | |||
| std::string getDefaultHashFunction() const { | |||
| std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n"; | |||
| if (!getMgbAttributes().empty()) { | |||
| auto getHashFunc = [&](auto&& iter) { | |||
| auto&& attr = llvm::cast<MgbHashableAttrMixin>(iter.attr); | |||
| return attr.getHashFunctionTemplate(); | |||
| }; | |||
| mlir::tblgen::FmtContext ctx; | |||
| for (auto&& it: getMgbAttributes()) { | |||
| body += formatv( | |||
| " val = mgb::hash_pair_combine(val, {0});\n", | |||
| mlir::tblgen::tgfmt(getHashFunc(it), &ctx, "$_self." + it.name) | |||
| ); | |||
| } | |||
| } | |||
| body += " return val;\n"; | |||
| return body; | |||
| } | |||
| std::string getDefaultCmpFunction() const { | |||
| std::string body; | |||
| if (!getMgbAttributes().empty()) { | |||
| mlir::tblgen::FmtContext ctx; | |||
| for (auto&& it : getMgbAttributes()) { | |||
| auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr); | |||
| body += formatv( | |||
| " if ({0}) return false;\n", | |||
| mlir::tblgen::tgfmt(attr.getCmpFunctionTemplate(), | |||
| &ctx, "$0." + it.name, "$1." + it.name) | |||
| ); | |||
| } | |||
| } | |||
| body += " return true;\n"; | |||
| return body; | |||
| } | |||
| public: | |||
| static bool classof(const Operator* op) { | |||
| return op->getDef().isSubClassOf("MgbHashableOpMixin"); | |||
| } | |||
| std::string getHashFunctionTemplate() const { | |||
| if (auto f = getDef().getValueAsOptionalString("hashFunction")) { | |||
| return f.getValue().str(); | |||
| } | |||
| return getDefaultHashFunction(); | |||
| } | |||
| std::string getCmpFunctionTemplate() const { | |||
| if (auto f = getDef().getValueAsOptionalString("cmpFunction")) { | |||
| return f.getValue().str(); | |||
| } | |||
| return getDefaultCmpFunction(); | |||
| } | |||
| }; | |||
| } // namespace tblgen | |||
| } // namespace mlir | |||
| @@ -11,7 +11,7 @@ endif() | |||
| # TODO: turn python binding into a static/object library | |||
| add_executable(imperative_test ${SOURCES} ${SRCS}) | |||
| target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include) | |||
| target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) | |||
| # Python binding | |||
| target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
| @@ -0,0 +1,257 @@ | |||
| /** | |||
| * \file src/core/include/megbrain/ir/base.td | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #ifndef MGB_BASE | |||
| #define MGB_BASE | |||
| include "mlir/IR/OpBase.td" | |||
| def Mgb_Dialect : Dialect { | |||
| let name = "mgb"; | |||
| let cppNamespace = "mgb::dialect"; | |||
| } | |||
| // -- mgb Attr mixin | |||
| class MgbAttrWrapperBase<string className> { | |||
| string underlyingType = className; | |||
| int recursionDepth = 0; | |||
| } | |||
| class MgbHashableAttrMixin { | |||
| string hashFunction = "mgb::hash($0)"; | |||
| // return 0 for eq, else for ne | |||
| string cmpFunction = "$0 != $1"; | |||
| } | |||
// Records the C++ namespace, enum name and member list of an enum-valued
// attribute, for use by the generated opdef code.
class MgbEnumAttrMixin<string namespace, string name, list<string> members> {
  string parentNamespace = namespace;
  string enumName = name;
  list<string> enumMembers = members;
}
// Forward declaration so mixins may reference the wrapper class.
class MgbAttrWrapper;
// Marks an attribute as an alias of another attribute definition.
class MgbAliasAttrMixin<Attr base> {
  Attr aliasBase = base;
}
// -- mgb custom Attr
// TODO: CPred and description
// Base of all mgb attributes: an MLIR Attr whose C++ return type is the
// wrapped underlying type. The CPred is a placeholder (always true).
class MgbAttrWrapper<string className>:
    Attr<CPred<"true">, "TODO">, MgbAttrWrapperBase<className> {
  let returnType = underlyingType;
}
// Attribute that additionally carries hash/compare templates.
class HashableAttr<string className>:
    MgbAttrWrapper<className>, MgbHashableAttrMixin;
// -- basic types
// Common base for integer attributes stored as mlir::IntegerAttr.
class MgbIntegerAttrBase<string CType> : HashableAttr<CType> {
  let storageType = "::mlir::IntegerAttr";
}
// Signless / signed / unsigned integer attribute wrappers.
// NOTE: mlir::Builder::getIntegerType takes the width in *bits*, while
// sizeof() yields bytes, so the factor must be 8. The previous factor of 4
// stored e.g. an int32_t in an i16 and an int8_t in an i4, silently
// narrowing the attribute's value range.
class MgbSignlessIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
  let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getInt())";
  let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 8), $0)";
}
class MgbSignedIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
  let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getSInt())";
  let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 8, true), $0)";
}
class MgbUnsignedIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
  let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getUInt())";
  let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 8, false), $0)";
}
// Concrete integer attribute instances.
def MgbI8Attr: MgbSignlessIntegerAttrBase<"int8_t">;
def MgbI32Attr: MgbSignlessIntegerAttrBase<"int32_t">;
def MgbI64Attr: MgbSignlessIntegerAttrBase<"int64_t">;
def MgbUI32Attr: MgbUnsignedIntegerAttrBase<"uint32_t">;
def MgbUI64Attr: MgbUnsignedIntegerAttrBase<"uint64_t">;
// NOTE(review): "Addr" looks like a typo for "Attr", but the name is
// referenced below (MgbTensorShapeAttr, ParamPackSplit) — kept as-is.
def MgbSizeTAddr: MgbUnsignedIntegerAttrBase<"size_t">;
// Floating-point attributes stored as mlir::FloatAttr; DType selects the
// MLIR builder type getter (getF32Type / getF64Type).
class MgbFloatAttrBase<string CType, string DType> : HashableAttr<CType> {
  let storageType = "::mlir::FloatAttr";
  let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getValueAsDouble())";
  let constBuilderCall = "$_builder.getFloatAttr($_builder.get" # DType # "Type(), $0)";
}
def MgbF32Attr : MgbFloatAttrBase<"float", "F32">;
def MgbF64Attr : MgbFloatAttrBase<"double", "F64">;
// bool attribute; relies on the inherited default convertFromStorage.
def MgbBoolAttr : HashableAttr<"bool"> {
  let storageType = "::mlir::BoolAttr";
  let constBuilderCall = "$_builder.getBoolAttr($0)";
}
// std::string attribute backed by mlir::StringAttr.
def MgbStringAttr : HashableAttr<"std::string"> {
  let storageType = "::mlir::StringAttr";
  let convertFromStorage = "$_self.getValue().str()";
  let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor
}
// std::vector<elem> attribute stored as mlir::ArrayAttr. The emitted C++
// lambdas convert element-wise by re-using the element attribute's own
// convertFromStorage / constBuilderCall templates via !subst.
// recursionDepth keeps the loop-local names ("ret<N>", "i<N>") unique when
// arrays are nested inside arrays.
class MgbArrayAttr<MgbAttrWrapper elem>:
    HashableAttr<"std::vector<" # elem.underlyingType # ">"> {
  let storageType = "::mlir::ArrayAttr";
  let recursionDepth = !add(elem.recursionDepth, 1);
  // ArrayAttr -> std::vector: cast each element to the element storage
  // type, then apply the element's own storage conversion.
  let convertFromStorage =
    "[&] {\n"
    " " # underlyingType # " ret" # recursionDepth # ";\n"
    " std::for_each($_self.begin(), $_self.end(), [&](auto&& i" # recursionDepth # ") {\n"
    " ret" # recursionDepth # ".push_back(\n"
    " " # !subst("$_self", "i" # recursionDepth # ".template cast<" # elem.storageType # ">()", "" # elem.convertFromStorage) # "\n"
    " );\n"
    " });\n"
    " return ret" # recursionDepth # ";}()";
  // std::vector -> ArrayAttr: build each element attribute, then wrap the
  // collected mlir::Attribute vector.
  let constBuilderCall =
    "[&] {\n"
    " std::vector<mlir::Attribute> ret" # recursionDepth # ";\n"
    " std::for_each($0.begin(), $0.end(), [&](auto&& i" # recursionDepth # ") {\n"
    " ret" # recursionDepth # ".push_back(\n"
    " " # !subst("$0", "i" # recursionDepth, "" # elem.constBuilderCall) # "\n"
    " );\n"
    " });\n"
    " return $_builder.getArrayAttr(ret" # recursionDepth # ");"
    "}()";
}
// Empty list<string> constant used as the !foldl seed below.
defvar EmptyStrList = !listsplat("", 0);
// StrListAppend<l, s>.r == l with s appended (TableGen has no direct
// single-element append operator).
class StrListAppend<list<string> l, string s> {
  list<string> r = !listconcat(l, !listsplat(s, 1));
}
// Rewrites attr.convertFromStorage so "$_self" reads element `idx` of the
// surrounding ArrayAttr (tuples are stored as arrays; see MgbTupleAttr).
class TupleConvertFromStorage<MgbAttrWrapper attr, int idx> {
  string r = !subst(
    "$_self",
    "$_self[" # !cast<string>(idx) # "].template cast<"# attr.storageType #">()",
    "" # attr.convertFromStorage);
}
// Rewrites attr.constBuilderCall so "$0" reads std::get<idx> of the tuple.
class TupleConstBuilderCall<MgbAttrWrapper attr, int idx> {
  string r = !subst(
    "$0",
    "std::get<" # !cast<string>(idx) # ">($0)",
    "" # attr.constBuilderCall);
}
// Applies TupleConvertFromStorage to each member attr; !size(l) supplies
// the running element index during the fold.
class ApplyTupleConvertFromStorage<list<MgbAttrWrapper> args> {
  list<string> r = !foldl(
    EmptyStrList, args, l, arg, StrListAppend<l, TupleConvertFromStorage<arg, !size(l)>.r>.r);
}
// Same fold, for the const-builder (value -> attribute) direction.
class ApplyTupleConstBuilderCall<list<MgbAttrWrapper> args> {
  list<string> r = !foldl(
    EmptyStrList, args, l, arg, StrListAppend<l, TupleConstBuilderCall<arg, !size(l)>.r>.r);
}
// std::tuple<...> attribute stored as an ArrayAttr of the member attrs.
// StrJoin comes from the included mlir/IR/OpBase.td.
class MgbTupleAttr<list<MgbAttrWrapper> args>:
    HashableAttr<"std::tuple<" # StrJoin<!foreach(i, args, i.underlyingType)>.result # ">"> {
  let storageType = "::mlir::ArrayAttr";
  let convertFromStorage = "std::make_tuple(" # StrJoin<ApplyTupleConvertFromStorage<args>.r>.result # ")";
  let constBuilderCall = "$_builder.getArrayAttr({" # StrJoin<ApplyTupleConstBuilderCall<args>.r>.result # "})";
}
// -- enum types
// Enum attribute stored as an i32 IntegerAttr holding the enum value;
// hashed through mgb::enumhash rather than the generic mgb::hash.
class MgbEnumAttr<string namespace, string enumName, list<string> members>:
    HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> {
  let storageType = "::mlir::IntegerAttr";
  let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
  let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
  let hashFunction = "mgb::enumhash()($0)";
}
// Enum attribute aliasing another enum attr (shares its member list).
class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>:
    MgbEnumAttr<namespace, enumName, base.enumMembers>, MgbAliasAttrMixin<base>;
// -- other types
// megdnn DType serialized as its DTypeEnum integer value; hashed by the
// underlying dtype handle.
def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> {
  let storageType = "::mlir::IntegerAttr";
  let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))";
  let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0.enumv()))";
  let hashFunction = "mgb::hash($0.handle())";
}
// CompNode round-tripped through its logical-locator string form.
def MgbCompNodeAttr: HashableAttr<"::mgb::CompNode"> {
  let storageType = "::mlir::StringAttr";
  let convertFromStorage = underlyingType # "::load($_self.getValue().str())";
  let constBuilderCall = "$_builder.getStringAttr($0.to_string_logical())";
}
// megdnn TensorShape stored as an ArrayAttr of size_t elements; the
// conversion appends each element while bumping ndim, and the builder
// emits one attribute per dimension. Hash/compare use shape-aware helpers
// instead of the generic templates.
def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> {
  let storageType = "::mlir::ArrayAttr";
  let hashFunction = "mgb::PODHash<size_t>::perform($0.shape, $0.ndim)";
  let cmpFunction = "!$0.eq_shape($1)";
  defvar elemInst = MgbSizeTAddr;
  let convertFromStorage =
    "[&] {\n"
    " " # underlyingType # " ret;\n"
    " std::for_each($_self.begin(), $_self.end(), [&ret](auto&& i) {\n"
    " ret[ret.ndim ++] = " # !subst("$_self", "i.template cast<"# elemInst.storageType #">()", "" # elemInst.convertFromStorage) # ";\n"
    " });\n"
    " return ret;}()";
  let constBuilderCall =
    "[&] {\n"
    " std::vector<mlir::Attribute> ret;\n"
    " for (size_t i = 0; i < $0.ndim; ++ i) {\n"
    " ret.push_back(\n"
    " " # !subst("$0", "$0[i]", "" # elemInst.constBuilderCall) # "\n"
    " );\n"
    " }\n"
    " return $_builder.getArrayAttr(ret);"
    "}()";
}
// DefaultValuedAttr variant that re-attaches the mgb wrapper metadata
// (underlyingType / recursionDepth) which the plain MLIR class lacks.
class MgbDefaultValuedAttr<MgbAttrWrapper attr, string value>:
    DefaultValuedAttr<attr, value>, MgbAttrWrapperBase<attr.underlyingType> {
  // Note: this class is similar to DefaultValuedAttr but with extra
  // meta information used by the mgb dialect tblgen backend, so it
  // has to be kept up to date with class MgbAttrWrapperMixin
  let recursionDepth = attr.recursionDepth;
}
// -- dnn params
// Binds an op to a ::megdnn::param struct; `fields` is filled in by the
// generated param definitions (param_defs.td).
class MgbParamBase<string className> {
  string paramType = className;
  string fullName = "::megdnn::param::" # paramType;
  dag fields = ?;
}
// Param whose fields are reached through a named C++ accessor.
class MgbPackedParamBase<string className, string accessor>:
    MgbParamBase<className> {
  string paramAccessor = accessor;
}
// -- mgb ops
// Optional raw C++ overrides for op hashing/comparison; left unset here
// and overridden per-op where needed.
class MgbHashableOpMixin {
  string hashFunction = ?;
  string cmpFunction = ?;
}
// Base mgb op: `arguments` is assembled as inputs, then the fields of
// every listed dnn param (in order), then extraArguments.
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
    Op<Mgb_Dialect, mnemonic, traits> {
  dag inputs = (ins);
  dag extraArguments = (ins);
  // TODO: remove it
  code extraOpdefDecl = ?;
  let arguments = !con(
    !foldl(inputs, params, args, param, !con(args, param.fields)),
    extraArguments);
  list<MgbParamBase> dnnParams = params;
}
// Hashable variant; the op definitions in ops.td derive from this.
class MgbHashableOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
    MgbOp<mnemonic, params, traits>, MgbHashableOpMixin;
| #endif // MGB_BASE | |||
| @@ -0,0 +1,240 @@ | |||
| /** | |||
| * \file src/core/include/megbrain/ir/ops.td | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #ifndef MGB_OPS | |||
| #define MGB_OPS | |||
| include "base.td" | |||
| include "param_defs.td" | |||
| include "mlir/Interfaces/SideEffectInterfaces.td" | |||
// Variadic elementwise op; its mode and other attributes come from
// ElemwiseParam.
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
  let inputs = (ins Variadic<AnyType>:$input);
  let results = (outs AnyType);
}
// Ops below draw all of their attributes from the listed megdnn param.
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
// Cast op; carries both the input dtype (idtype) and target dtype.
def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
  let inputs = (ins AnyType:$inputs);
  let extraArguments = (ins
    TypeAttr:$idtype,
    MgbDTypeAttr:$dtype
  );
  let results = (outs AnyType);
}
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam]>;
// NOTE(review): mnemonic "BatchedMatmul" differs from the def name
// BatchedMatrixMul — confirm the mnemonic spelling is intentional.
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam]>;
def Dot: MgbHashableOp<"Dot", [EmptyParam]>;
def SVD: MgbHashableOp<"SVD", [SVDParam]>;
def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;
def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>;
def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>;
def ROIPooling: MgbHashableOp<"ROIPooling", [ROIPoolingParam]>;
// Fused conv ops: carry an explicit output dtype on top of their params.
def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
  let extraArguments = (ins
    MgbDTypeAttr:$dtype
  );
}
def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
  let extraArguments = (ins
    MgbDTypeAttr:$dtype
  );
}
// Param-only ops (attributes entirely described by their megdnn param).
def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;
def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;
def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>;
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;
// Device-to-device copy; the destination comp_node is the only attribute.
def Copy: MgbHashableOp<"Copy"> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}
def Argsort: MgbHashableOp<"Argsort", [ArgsortParam]>;
def Argmax : MgbHashableOp<"Argmax", [AxisParam]>;
def Argmin : MgbHashableOp<"Argmin", [AxisParam]>;
def CondTake : MgbHashableOp<"CondTake">;
def TopK: MgbHashableOp<"TopK", [TopKParam]>;
def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;
// RNG ops override hashing: UniformRNG hashes only the op type (its param
// does not participate) and compares equal to any other instance;
// GaussianRNG hashes and compares mean/std.
def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
  let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
  let cmpFunction = [{return true;}];
}
def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
  }];
  let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
}
// Tensor-producing ops that need an explicit target comp_node.
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}
def Eye: MgbHashableOp<"Eye", [EyeParam]> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}
def GetVarShape : MgbHashableOp<"GetVarShape">;
def Concat: MgbHashableOp<"Concat", [AxisParam]> {
  let extraArguments = (ins
    MgbCompNodeAttr:$comp_node
  );
}
def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>;
def Identity: MgbHashableOp<"Identity">;
// Distributed collective op; the extra attributes identify the group
// (key/rank/nr_devices), the rendezvous server (addr/port), and the
// backend/dtype/placement used for the communication.
def CollectiveComm : MgbHashableOp<"CollectiveComm", [CollectiveCommParam]> {
  let extraArguments = (ins
    MgbStringAttr:$key,
    MgbUI32Attr:$nr_devices,
    MgbUI32Attr:$rank,
    MgbBoolAttr:$is_root,
    MgbBoolAttr:$local_grad,
    MgbStringAttr:$addr,
    MgbUI32Attr:$port,
    MgbDTypeAttr:$dtype,
    MgbStringAttr:$backend,
    MgbStringAttr:$comp_node
  );
}
// Point-to-point distributed ops. The receive side additionally carries
// the expected shape/dtype and target comp node, since nothing can be
// inferred from a local input.
def RemoteSend : MgbHashableOp<"RemoteSend"> {
  let extraArguments = (ins
    MgbStringAttr:$key,
    MgbStringAttr:$addr,
    MgbUI32Attr:$port,
    MgbUI32Attr:$rank_to
  );
}
def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
  let extraArguments = (ins
    MgbStringAttr:$key,
    MgbStringAttr:$addr,
    MgbUI32Attr:$port,
    MgbUI32Attr:$rank_from,
    MgbCompNodeAttr:$cn,
    MgbTensorShapeAttr:$shape,
    MgbDTypeAttr:$dtype
  );
}
// Non-maximum suppression keep-mask op.
def NMSKeep : MgbHashableOp<"NMSKeep"> {
  let extraArguments = (ins
    MgbF32Attr:$iou_thresh,
    MgbUI32Attr:$max_output
  );
}
// Splits a packed parameter blob; shapes is a nested size_t array
// (one shape per output tensor).
def ParamPackSplit : MgbHashableOp<"ParamPackSplit"> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$offsets,
    MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$shapes
  );
}
// Inverse of ParamPackSplit: concatenates tensors at the given offsets.
def ParamPackConcat : MgbHashableOp<"ParamPackConcat"> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$offsets
  );
}
// Axis permutation; unlike the other ops here it types its operand/result
// as memrefs.
def Dimshuffle: MgbHashableOp<"Dimshuffle"> {
  let inputs = (ins AnyMemRef:$input);
  let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$pattern);
  let results = (outs AnyMemRef);
}
def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>;
// TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?
def AddAxis: MgbHashableOp<"AddAxis"> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$axis
  );
}
def RemoveAxis: MgbHashableOp<"RemoveAxis"> {
  let extraArguments = (ins
    MgbArrayAttr<MgbI32Attr>:$axis
  );
}
// Shared base for subtensor/fancy-indexing ops: each item is an
// (int8, bool, bool, bool, bool) tuple.
// NOTE(review): presumably (axis, has_begin, has_end, has_step, has_idx)
// — confirm against megbrain's indexing helpers.
class FancyIndexingBase<string name>: MgbHashableOp<name> {
  let extraArguments = (ins
    MgbArrayAttr<MgbTupleAttr<
      [MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]>>:$items
  );
}
def Subtensor: FancyIndexingBase<"Subtensor">;
def SetSubtensor: FancyIndexingBase<"SetSubtensor">;
def IncrSubtensor: FancyIndexingBase<"IncrSubtensor">;
def IndexingMultiAxisVec: FancyIndexingBase<"IndexingMultiAxisVec">;
def IndexingSetMultiAxisVec: FancyIndexingBase<"IndexingSetMultiAxisVec">;
def IndexingIncrMultiAxisVec: FancyIndexingBase<"IndexingIncrMultiAxisVec">;
def MeshIndexing: FancyIndexingBase<"MeshIndexing">;
def IncrMeshIndexing: FancyIndexingBase<"IncrMeshIndexing">;
def SetMeshIndexing: FancyIndexingBase<"SetMeshIndexing">;
def BatchedMeshIndexing: FancyIndexingBase<"BatchedMeshIndexing">;
def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">;
def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">;
def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
// Mixed-dtype elementwise op; carries its output dtype explicitly.
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
  let extraArguments = (ins
    MgbDTypeAttr:$dtype
  );
}
| #endif // MGB_OPS | |||
| @@ -47,3 +47,4 @@ pushd MegRay/third_party >/dev/null | |||
popd >/dev/null
# Fetch the submodules required by the python binding and the compiler
# toolchain (llvm-project is needed for the MLIR/opdef build).
git submodule update --init pybind11
git submodule update --init llvm-project