GitOrigin-RevId: 11411b6964
tags/v1.0.0-rc1
| @@ -247,10 +247,6 @@ if(MGE_BUILD_IMPERATIVE_RT) | |||
| set(CMAKE_CXX_STANDARD 17) | |||
| endif() | |||
| if(MGE_BUILD_IMPERATIVE_RT) | |||
| set(MGE_BUILD_SDK OFF) | |||
| endif() | |||
| if(NOT MGE_WITH_CUDA) | |||
| message("-- Disable distributed support, as CUDA is not enabled.") | |||
| set(MGE_WITH_DISTRIBUTED OFF) | |||
| @@ -697,9 +693,7 @@ if(MGE_WITH_PYTHON_MODULE) | |||
| endif() | |||
| if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | |||
| if(NOT MGE_BUILD_IMPERATIVE_RT) | |||
| add_subdirectory(test) | |||
| endif() | |||
| add_subdirectory(test) | |||
| endif() | |||
| if(TARGET mgb) | |||
| @@ -66,9 +66,7 @@ if(MGE_WITH_CUDA) | |||
| endif() | |||
| if(MGE_WITH_TEST) | |||
| if(NOT MGE_BUILD_IMPERATIVE_RT) | |||
| add_subdirectory(test) | |||
| endif() | |||
| add_subdirectory(test) | |||
| endif() | |||
| add_subdirectory(src) | |||
| @@ -0,0 +1,5 @@ | |||
| Makefile | |||
| /test/imperative_test | |||
| *.so | |||
| /python/megengine/core/ops/_internal/generated_ops.py | |||
| /python/megengine/core/ops/_internal/param_defs.py | |||
| @@ -0,0 +1,110 @@ | |||
# NumPy is needed to build the Python extension module.
find_package(NumPy REQUIRED)

# Name of the Python package and of the native extension module.
# Both are also published to the parent directory scope.
set(PACKAGE_NAME megengine)
set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE)
set(MODULE_NAME _imperative_rt)
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE)

# C++ sources and headers compiled into the extension module.
file(GLOB_RECURSE SRCS
    src/impl/*.cpp
    src/include/*.h
    python/src/*.cpp
    python/src/*.h)

# NOTE: directory-scoped flag; it applies to every target defined below this
# point in this directory and its subdirectories.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1")

# Operator declarations consumed by the Python op generator.
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl")

# Hand-written Python sources; the two generated modules are excluded so they
# are not treated as inputs of the generation step.
file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py)
list(REMOVE_ITEM PYTHON_SRCS
    ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/generated_ops.py
    ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/param_defs.py)

# Headers listed as dependencies of the generation step.
file(GLOB_RECURSE ALL_HEADERS
    src/cpp/megbrain_pubapi.h
    ${PROJECT_SOURCE_DIR}/src/core/include/*
    ${PROJECT_SOURCE_DIR}/src/opr/include/*
    ${PROJECT_SOURCE_DIR}/src/serialization/include/*
    ${PROJECT_SOURCE_DIR}/src/plugin/include/*
    ${PROJECT_SOURCE_DIR}/dnn/include/*)

# Layout of the staged Python package inside the build tree.
set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/)
set(GEN_OPS_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal)
file(MAKE_DIRECTORY ${GEN_OPS_DIR})
set(GEN_OPS_FILE ${GEN_OPS_DIR}/generated_ops.py)
set(GEN_OP_PARAMS_FILE ${GEN_OPS_DIR}/param_defs.py)
set(GEN_OP_PARAMS_TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/ops.tpl.py)
##################### generate python opr_param_defs.py ##############
# Assemble <build dir>/opr_param_defs.py at configure time: start from the dnn
# parameter definitions and append the mgb ones.
file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS)
# Quote the expansion: unquoted, any ';' in the Python source would be treated
# as a list separator and silently dropped from the appended text.
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py "${CONTENTS}")

# Stage the Python package into the build tree and generate the two op modules.
# The touch/copy_directory/remove sequence makes sure stale copies of the
# extension module and of the generated files are gone before regeneration.
add_custom_command(
    # Both generated modules are outputs of this rule.
    OUTPUT ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE}
    COMMAND ${CMAKE_COMMAND} -E touch ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE}
    COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/${PACKAGE_NAME} ${MEGENGINE_DIR}/${PACKAGE_NAME}
    COMMAND ${CMAKE_COMMAND} -E remove -f ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE}
    COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py ${OPR_DECL_SRCS} -o ${GEN_OPS_FILE}
    COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${MEGENGINE_DIR}/${PACKAGE_NAME}/test
    COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py --imperative ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${GEN_OP_PARAMS_FILE}
    # The generator scripts and the assembled parameter definitions are inputs
    # too; list them so edits to them trigger regeneration.
    DEPENDS
        ${OPR_DECL_SRCS}
        ${PYTHON_SRCS}
        ${ALL_HEADERS}
        ${GEN_OP_PARAMS_TEMPLATE}
        ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py
        ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py
        ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py
    VERBATIM
)

add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE})
##################### generate opdef c header and python binding ##############
# Output locations of the generated C++ header and of the Python binding
# include file; both directories are created at configure time.
set(OP_DEF_HEADER_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/src/include)
file(MAKE_DIRECTORY ${OP_DEF_HEADER_OUT_DIR}/megbrain/imperative/opdef)
set(OP_DEF_HEADER ${OP_DEF_HEADER_OUT_DIR}/megbrain/imperative/opdef/all.h)
set(OP_DEF_PYTHON_BINDING_OUT_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/src)
file(MAKE_DIRECTORY ${OP_DEF_PYTHON_BINDING_OUT_DIR})
set(OP_DEF_PYTHON_BINDING ${OP_DEF_PYTHON_BINDING_OUT_DIR}/opdef.inl)
# Inputs of the generator: the assembled parameter definitions and the script.
set(OP_PARAM_DEF ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py)
set(GEN_OP_DEF_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_op_defs.py)

# The same script is run twice: once for the header, once with `-t py` for the
# Python binding.
add_custom_command(
    OUTPUT ${OP_DEF_HEADER} ${OP_DEF_PYTHON_BINDING}
    COMMAND ${PYTHON_EXECUTABLE} ${GEN_OP_DEF_SCRIPT} ${OP_PARAM_DEF} ${OP_DEF_HEADER}
    COMMAND ${PYTHON_EXECUTABLE} ${GEN_OP_DEF_SCRIPT} -t py ${OP_PARAM_DEF} ${OP_DEF_PYTHON_BINDING}
    DEPENDS ${GEN_OP_DEF_SCRIPT} ${OP_PARAM_DEF}
    VERBATIM
)

# gen_op_def is an INTERFACE library: linking it adds both generated output
# directories to the include path and orders the build after generation.
add_custom_target(gen_op_def_internal DEPENDS ${OP_DEF_HEADER} ${OP_DEF_PYTHON_BINDING})
add_library(gen_op_def INTERFACE)
target_include_directories(gen_op_def INTERFACE ${OP_DEF_HEADER_OUT_DIR} ${OP_DEF_PYTHON_BINDING_OUT_DIR})
add_dependencies(gen_op_def gen_op_def_internal)
##################### end of opdef generation #########################
# Linker version script passed to the extension module link step; the custom
# target only carries the file as a source.
set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld)
add_custom_target(_version_ld SOURCES ${VERSION_SCRIPT})

# Build the Python extension module with the bundled pybind11.
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/pybind11 ${PROJECT_BINARY_DIR}/third_party/pybind11)
pybind11_add_module(${MODULE_NAME} NO_EXTRAS ${SRCS})
target_link_libraries(${MODULE_NAME} PRIVATE gen_op_def megbrain megdnn -Wl,--version-script=${VERSION_SCRIPT})
if (MGE_WITH_DISTRIBUTED)
    message("Imperative configured to link megray")
    target_link_libraries(${MODULE_NAME} PRIVATE megray)
endif()
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR})
# The module name is also passed to the sources as a preprocessor definition.
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME})
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter)
if(CXX_SUPPORT_WCLASS_MEMACCESS)
    target_compile_options(${MODULE_NAME} PRIVATE "-Wno-class-memaccess")
endif()
# Place the built module inside the staged Python package (<package>/core).
set_target_properties(${MODULE_NAME} PROPERTIES
    SUFFIX ${CMAKE_SHARED_LIBRARY_SUFFIX}
    LIBRARY_OUTPUT_DIRECTORY ${MEGENGINE_DIR}/${PACKAGE_NAME}/core
)
# Generate the Python op modules before building the extension module.
add_dependencies(${MODULE_NAME} gen_opr_py _version_ld)

if(MGE_WITH_TEST AND MGE_ENABLE_RTTI)
    add_subdirectory(test)
endif()
# After the extension module is built, copy the license files to the top-level
# build directory and stage the Python package, tests and packaging metadata
# under <build dir>/python.
add_custom_command(
    TARGET ${MODULE_NAME} POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/LICENSE ${PROJECT_SOURCE_DIR}/ACKNOWLEDGMENTS ${PROJECT_BINARY_DIR}
    COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine ${CMAKE_CURRENT_BINARY_DIR}/python/megengine
    COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${CMAKE_CURRENT_BINARY_DIR}/python/test
    COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/setup.py ${CMAKE_CURRENT_BINARY_DIR}/python/setup.py
    COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires.txt
    COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires-style.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires-style.txt
    COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires-test.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires-test.txt
)
| @@ -0,0 +1,25 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import sys | |||
| from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func | |||
| from .device import * | |||
| from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||
| from .serialization import load, save | |||
| from .tensor import Tensor, tensor | |||
| from .tensor_nn import Buffer, Parameter | |||
| from .version import __version__ | |||
# Hand the runtime the current interpreter and the path of the helper script
# shipped in this package's "utils" directory.
# NOTE(review): presumably used to re-launch Python for timed functions — confirm.
_entry_script = os.path.join(
    os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"
)
_set_fork_exec_path_for_timed_func(sys.executable, _entry_script)

# Keep the package namespace clean: neither name is part of the public API.
del _set_fork_exec_path_for_timed_func, _entry_script
| @@ -0,0 +1,12 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import sys | |||
| from .tensor import Tensor | |||
| @@ -0,0 +1,46 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ._imperative_rt import CompNode | |||
class Device:
    """Thin Python wrapper around a ``CompNode``.

    Hashing and equality are both based on the string form of the wrapped
    ``CompNode``.
    """

    def __init__(self, device=None):
        """Wrap ``device``.

        :param device: ``None`` (default-constructed ``CompNode``), another
            ``Device`` (its ``CompNode`` is shared), a ``CompNode`` (used
            as-is), or anything else accepted by the ``CompNode`` constructor.
        """
        if isinstance(device, Device):
            self._cn = device._cn
        elif isinstance(device, CompNode):
            self._cn = device
        elif device is None:
            self._cn = CompNode()
        else:
            self._cn = CompNode(device)

    def to_c(self):
        """Return the wrapped ``CompNode``."""
        return self._cn

    def __repr__(self):
        return "%s(%s)" % (type(self).__qualname__, self)

    def __str__(self):
        return str(self._cn)

    def __hash__(self):
        # Consistent with __eq__, which also compares string forms.
        return hash(str(self._cn))

    def __eq__(self, rhs):
        # Anything that is not a Device is first converted to one.
        other = rhs if isinstance(rhs, Device) else Device(rhs)
        return str(self._cn) == str(other._cn)
def device(obj):
    """Return ``obj`` unchanged if it is already a ``Device``, else wrap it in one."""
    return obj if isinstance(obj, Device) else Device(obj)
| @@ -0,0 +1,8 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -0,0 +1,134 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import itertools | |||
| import numpy as np | |||
| from .._imperative_rt import TensorAttr, imperative | |||
| from ..ops.builtin import Elemwise, GetVarShape, OpDef, OprAttr, Reduce, Reshape | |||
| from ..tensor.core import apply | |||
| from ..tensor.function import Function | |||
@functools.singledispatch
def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
    """Return ``(backward_fn, output_need_grad)`` for ``op``.

    Dispatches on the type of ``op``; reaching this fallback means no overload
    is registered for that type.
    """
    assert 0


# Parameter of an "add" Elemwise op, used to recognise that op below.
_elemwise_add_param = Elemwise(mode="add").to_c().param


@builtin_op_get_backward_fn.register(OpDef)
def _(op: OpDef, inputs, outputs, input_requires_grad):
    """Pick a specialised gradient for elemwise-add and reshape, else the default."""
    grad_fn = default_grad_fn
    if isinstance(op, OprAttr):
        if op.type == "Elemwise" and op.param == _elemwise_add_param:
            grad_fn = elemwise_grad_fn
        elif op.type == Reshape.name:
            grad_fn = reshape_grad_fn
    return grad_fn(op, inputs, outputs, input_requires_grad)


@builtin_op_get_backward_fn.register(Function)
def _(op: Function, inputs, outputs, input_requires_grad):
    """A ``Function`` supplies its own backward; every output is marked as needing grad."""
    output_need_grad = [True] * len(outputs)
    return op.get_backward_fn(), output_need_grad
def default_grad_fn(op, inputs, outputs, input_requires_grad):
    """Build a backward function from ``imperative.make_backward_graph``.

    :return: ``(backward, output_grad_mask)``. When no backward graph is
        available, ``backward`` yields ``None`` for every input and the mask
        is all ``False``.
    """

    def describe(tensor):
        # Only dtype and comp node are passed on to make_backward_graph.
        attr = TensorAttr()
        attr.dtype = tensor.dtype
        attr.comp_node = tensor.device.to_c()
        return attr

    output_has_grads = [True] * len(outputs)
    input_attrs = [describe(t) for t in inputs]
    result = imperative.make_backward_graph(
        op, input_attrs, input_requires_grad, output_has_grads
    )

    if result is None:
        nr_inputs, nr_outputs = len(inputs), len(outputs)

        def backward(*args):
            return [None] * nr_inputs

        return backward, [False] * nr_outputs

    backward_graph, save_for_backward_mask, input_has_grad = result

    # The mask first covers inputs + outputs (what to keep for backward); the
    # remaining entries tell which output grads the backward graph consumes.
    forward_values = inputs + outputs
    nr_forward = len(forward_values)
    output_grad_mask = save_for_backward_mask[nr_forward:]
    save_for_backward = tuple(
        value
        for value, keep in zip(forward_values, save_for_backward_mask[:nr_forward])
        if keep
    )
    # Drop every local reference to forward values that were not kept.
    del inputs, outputs, forward_values

    def backward(*args):
        output_grads = tuple(
            grad for grad, used in zip(args, output_grad_mask) if used
        )
        assert None not in output_grads
        produced = iter(apply(backward_graph, *(save_for_backward + output_grads)))
        # Results are consumed in order, one per input flagged in input_has_grad.
        return tuple(next(produced) if has else None for has in input_has_grad)

    return backward, output_grad_mask
# override for elemwise
def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
    """Gradient for the binary elemwise op selected by the dispatcher.

    The shape of every input that requires grad is captured up front; backward
    applies ``Reduce`` to ``dy`` with that shape.
    """
    assert len(inputs) == len(input_requires_grad) == 2

    input_shapes = []
    for needs_grad, x in zip(input_requires_grad, inputs):
        if needs_grad:
            (shape,) = apply(GetVarShape(), x)
        else:
            shape = None
        input_shapes.append(shape)

    def backward(dy):
        grads = []
        for needs_grad, shape in zip(input_requires_grad, input_shapes):
            if needs_grad:
                (dx,) = apply(Reduce(), dy, shape)
            else:
                dx = None
            grads.append(dx)
        return tuple(grads)

    return backward, [True]
def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
    """Gradient for ``Reshape``: reshape ``dy`` back to each input's shape.

    Shapes are captured only for inputs that require grad; the others get
    ``None`` as their gradient.
    """
    assert len(inputs) == len(input_requires_grad) == 2

    input_shapes = []
    for needs_grad, x in zip(input_requires_grad, inputs):
        if needs_grad:
            (shape,) = apply(GetVarShape(), x)
        else:
            shape = None
        input_shapes.append(shape)

    def backward(dy):
        grads = []
        for needs_grad, shape in zip(input_requires_grad, input_shapes):
            if needs_grad:
                (dx,) = apply(Reshape(), dy, shape)
            else:
                dx = None
            grads.append(dx)
        return tuple(grads)

    return backward, [True]
| @@ -0,0 +1,390 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import heapq | |||
| import itertools | |||
| import typing | |||
| import weakref | |||
| import numpy as np | |||
| from ..ops.builtin import Elemwise, OpDef | |||
| from ..ops.special import Const | |||
| from ..tensor.core import TensorBase, TensorWrapperBase, apply | |||
| from ..tensor.function import Function | |||
| from ..tensor.tensor import Tensor, get_context | |||
| from . import builtin_op_utils | |||
| """ Some notes: | |||
| 1. Initialize the optimizer: | |||
| for each trainable parameter: | |||
| call wrt(param, callback) | |||
| Each parameter tensor will be associated with a Tracer object saved in Tensor._extra_data | |||
| 2. Tracer has one member: node, which is a VariableNode | |||
| 3. VariableNode has a OpNode member: opnode | |||
| 4. OpNode has the following members: | |||
| a. id | |||
| b. inputs, which is made of VariableNode | |||
| c. outputs, which are weakref's to VariableNode | |||
| d. backward: call back function | |||
| e. has_grad_fn: call has_grad_fn(opnode, reached) to check grad exist | |||
| f. backward_allow_noinput: whether backward allow noinput | |||
| """ | |||
# Counter used to build default names ("grad_<n>") for Grad managers.
_grad_count = 0
# name -> Grad manager; values are held weakly, so an entry disappears once
# its manager is garbage collected.
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    """Return a list of all Grad managers that are currently alive.

    ``values()`` is used rather than iterating keys and indexing: a manager
    may be collected between the two steps, and the lookup would then raise
    ``KeyError``. ``values()`` simply skips dead entries.
    """
    return list(_grad_manager_dict.values())
def add(a, b):
    """Return the single result of applying ``Elemwise(mode="add")`` to ``a`` and ``b``."""
    # Unpacking enforces that the op produced exactly one output.
    (total,) = apply(Elemwise(mode="add"), a, b)
    return total
def get_tensor(x):
    """Unwrap ``x`` through ``__wrapped__`` until a ``Tensor`` is reached.

    :raises TypeError: if an object in the chain is neither a ``Tensor`` nor
        has a ``__wrapped__`` attribute.
    """
    # Recursion is deliberate: a cyclic ``__wrapped__`` chain ends in a
    # RecursionError instead of looping forever.
    if not isinstance(x, Tensor):
        try:
            inner = x.__wrapped__
        except AttributeError:
            raise TypeError(type(x))
        return get_tensor(inner)
    return x
class Grad:
    """Gradient manager: records ops during forward and replays them backward.

    Every instance registers itself by name in ``_grad_manager_dict``. Tensors
    are attached with :meth:`wrt`; calling the instance runs the backward pass.
    """

    def __init__(self, name=None):
        if name is None:
            # Auto-generate a unique name from the module-level counter.
            global _grad_count
            self._name = "grad_" + str(_grad_count)
            _grad_count += 1
        else:
            self._name = name
        assert self._name not in _grad_manager_dict, "grad manager name duplicated"
        _grad_manager_dict[self._name] = self

        # list of all x in partial(y) / partial(x)
        self.xs = []

        # contains weak references to all OpNodes created during forward
        # OpNode contains inputs, outputs and its backward
        # ops forms the computational graph
        self.ops = []

        # Cleared by __exit__ and at the start of __call__.
        self._enabled = True

    @property
    def name(self):
        """Name under which this manager is registered."""
        return self._name

    def wrt(self, *args: Tensor, callback=None):
        """ Indicates the loss is a function of the input tensors (usually the net trainable parameters),
        i.e., d (loss) / d (Tensor) != 0

        callback is used to perform additional operations after gradient is obtained in backward.
        e.g., copy the grad to a particular place

        A VariableNode will be created and saved in the tensor's _extra_data slot
        (wrapped in a Tracer, keyed by this manager).

        :return: ``self``, so calls can be chained.
        """

        for x in map(get_tensor, args):
            v = self._new_variable(x, callback=callback)
            # A tensor may be registered with a given manager only once.
            assert self not in x._extra_data
            x._extra_data[self] = Tracer(v)
            self.xs.append(v)

        return self

    def _new_variable(self, owner, opnode=None, callback=None):
        """Create a VariableNode owned by ``owner`` and managed by this Grad."""
        return VariableNode(self, owner, opnode=opnode, callback=callback)

    def _new_opnode(self, inputs, outputs):
        """Create an OpNode for one forward op.

        :param inputs: iterable of VariableNode or None.
        :param outputs: iterable of Tensor produced by the op.
        :return: ``(opnode, tracers)`` with one Tracer per output.
        """
        inputs = tuple(inputs)
        for i in inputs:
            assert i is None or isinstance(i, VariableNode)
        o = OpNode()
        o.inputs = inputs
        o.outputs = []
        tracers = []
        for i in outputs:
            assert isinstance(i, Tensor)
            v = self._new_variable(i, o)
            # The op references its outputs weakly; each output VariableNode
            # references the op strongly through ``opnode``.
            o.outputs.append(weakref.ref(v))
            tracers.append(Tracer(v))
        self.ops.append(weakref.ref(o))
        return o, tracers

    def copy(self):
        raise NotImplementedError

    def __enter__(self):
        return self

    def __exit__(self, *_):
        """Disable the manager and clear every OpNode that is still alive."""
        self._enabled = False
        for o in self.ops:
            o = o()
            if o:
                o.clear()

    def __call__(self, ys, dys):
        """ Defines Grad(): runs the backward pass. May be called only once.

        :param ys: outputs of forward operators, e.g., the loss tensor
        :type ys: list of Tensor or TensorWrapperBase
        :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss,
            e.g., one for the loss itself
        :type dys: list of Tensor or TensorWrapperBase
        """
        assert self._enabled
        self._enabled = False

        def check_wrapper():
            # Returns the wrapper type of the first wrapped item found in
            # ``dys``, or None when nothing in ``dys`` is wrapped.
            if isinstance(dys, TensorWrapperBase):
                return type(dys)
            if isinstance(dys, TensorBase):
                return
            assert isinstance(dys, (tuple, list))
            for i in dys:
                if isinstance(i, TensorWrapperBase):
                    return type(i)

        # When set, gradients handed to callbacks are wrapped in this type.
        Wrapper = check_wrapper()

        def aslist(x):
            # Normalise to a list of raw Tensors, unwrapping wrappers.
            if isinstance(x, (Tensor, TensorWrapperBase)):
                x = [x]
            else:
                x = list(x)
            x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x]
            for i in x:
                assert isinstance(i, Tensor)
            return x

        ys = aslist(ys)
        dys = aslist(dys)
        assert len(ys) == len(dys)

        # ys is changed to a list of VariableNode which contains more information
        # such as OpNode, callback, etc.
        ys = [i._extra_data[self].node for i in ys]

        # NOTE: callback is called only if grad is not None

        # the OpNode sequence in backward
        op_seq = []

        # VariableNode -> (i, j), where i is time stamp in backward, j means jth input
        last_written_to = {}

        def schedule():
            # Walk recorded ops newest-first and keep those whose
            # ``has_grad_fn`` accepts the set of nodes reached so far.
            reached = set(ys)
            # i is the time stamp in backward
            i = 0
            for o in self.ops[::-1]:
                o = o()
                if o is None:
                    continue
                if not o.has_grad_fn(o, reached):
                    continue
                op_seq.append(o)
                for j, v in enumerate(o.inputs):
                    reached.add(v)
                    # Later entries overwrite earlier ones, so this ends up
                    # holding the last (op, input index) that writes to v.
                    last_written_to[v] = i, j
                i += 1

        schedule()

        # VariableNode -> Tensor
        cache = {}

        def initialize():
            for y, dy in zip(ys, dys):
                cache[y] = dy
                # No scheduled op writes to y, so its gradient is final now.
                if y not in last_written_to and y.callback:
                    y.callback(y.owner(), dy)

        initialize()

        # NOTE: None is used to mark a node has been consumed

        for seqno, opnode in enumerate(op_seq):
            # Read everything needed from the op, then clear it.
            input_nodes = opnode.inputs
            output_nodes = [i() for i in opnode.outputs]
            backward = opnode.backward
            backward_allow_noinput = opnode.backward_allow_noinput
            opnode.clear()

            output_grads = []
            for i in output_nodes:
                if i is not None:
                    if i in cache:
                        assert cache[i] is not None
                        output_grads.append(cache[i])
                    else:
                        output_grads.append(None)
                    # read by backward, mark consumed
                    cache[i] = None
                else:
                    # The output VariableNode is already gone.
                    output_grads.append(None)
            if (
                any([grad is not None for grad in output_grads])
                or backward_allow_noinput
            ):
                input_grads = backward(*output_grads)
            else:
                input_grads = [None] * len(input_nodes)

            assert len(input_nodes) == len(input_grads)
            for i, (v, g) in enumerate(zip(input_nodes, input_grads)):
                if v is None:
                    continue
                if v in cache:
                    assert cache[v]
                    # Accumulate into the gradient already stored for v.
                    if g is not None:
                        cache[v] = add(cache[v], g)
                elif g is not None:
                    cache[v] = g
                # This was the last write to v: its gradient is complete.
                if last_written_to[v] == (seqno, i):
                    if v.callback:
                        v.callback(
                            v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                        )
                    if v.opnode is None:
                        # won't read by backward, mark consumed
                        cache[v] = None

        # Every cached gradient must have been consumed by now.
        for v in cache.values():
            assert v is None
class clearable:
    """Mixin for objects that can be emptied; a cleared instance is falsy."""

    # Class-level default; ``clear`` shadows it with an instance attribute.
    __cleared = False

    def __bool__(self):
        if self.__cleared:
            return False
        return True

    def clear(self):
        """Drop every instance attribute and mark the object as cleared."""
        vars(self).clear()
        self.__cleared = True
class OpNode(clearable):
    """Node of the recorded graph describing one forward op.

    All fields start empty and are filled in by the code that records the op.
    """

    def __init__(self):
        # inputs / outputs hold VariableNodes once populated.
        self.id = self.inputs = self.outputs = None
        # backward callable and the predicate deciding whether it should run.
        self.backward = self.has_grad_fn = None
        self.backward_allow_noinput = False
class VariableNode(clearable):
    """ VariableNode saves OpNode and callback.

    ``manager`` and ``owner`` are stored as weak references, so the node keeps
    neither the Grad manager nor the tensor alive.
    """

    def __init__(self, manager, owner, opnode=None, callback=None):
        # manager is Grad type
        self.manager = weakref.ref(manager)
        # owner is Tensor type
        self.owner = weakref.ref(owner)
        # OpNode passed by the creator; None by default.
        self.opnode = opnode
        # Optional callable stored for later use by the manager.
        self.callback = callback
class Tracer(clearable, TensorBase):
    """Holder for the VariableNode attached to a tensor."""

    def __init__(self, node=None):
        """ type(node) is VariableNode
        """
        self.node = node
@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
    """Return whether ``op``'s backward may run with no output grads.

    This fallback answers ``False``; overloads can be registered per op type.
    """
    return False
@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
    """Return the ``has_grad_fn`` predicate for ``op``, dispatching on its type."""
    # Reached only for types that have no registered overload.
    assert 0


@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
    # OpDef uses the default predicate.
    return default_has_grad_fn


@get_op_has_grad_fn.register(Function)
def _(op: Function):
    # Function uses the default predicate as well.
    return default_has_grad_fn
def default_has_grad_fn(opnode, reached):
    """Return True if any output of ``opnode`` is in ``reached``.

    ``opnode.outputs`` holds callables (weak references) that are resolved
    before the membership test.
    """
    return any(ref() in reached for ref in opnode.outputs)
@apply.add
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
    """Record ``op`` in the Grad manager that traces its inputs.

    Returns ``None`` when no input is traced or the manager is disabled;
    otherwise a tuple with one entry per output: a Tracer, or ``None`` for
    outputs that do not need grad.
    """
    # Anything that is not a Tracer counts as "does not require grad".
    args = tuple(i if isinstance(i, Tracer) else None for i in args)
    input_requires_grad = list(map(bool, args))
    if not any(input_requires_grad):
        return

    ctx = get_context()
    manager = None
    assert len(ctx.inputs) == len(args)
    for i, j in zip(ctx.inputs, args):
        if j:
            j = j.node
            # The tracer must belong to the tensor in the same position.
            assert i is j.owner()
            if manager is None:
                manager = j.manager()
                assert manager
            else:
                # All traced inputs must share one manager.
                assert manager is j.manager()

    if not manager._enabled:
        return

    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)

    # register backward method
    # tuple of backward functions corresponding to dy / dx_i
    # None means y is not a function of x_i
    opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
        op, ctx.inputs, ctx.outputs, input_requires_grad
    )

    assert len(outputs) == len(output_need_grad)
    # Drop the tracers of outputs that do not need grad.
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]

    opnode.backward_allow_noinput = check_backward_allow_noinput(op)

    opnode.has_grad_fn = get_op_has_grad_fn(op)

    return tuple(outputs)


@apply.add
def _(op: Const, *_: typing.Optional[Tracer]):
    # Const ops are never recorded.
    return None
| @@ -0,0 +1,8 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -0,0 +1,8 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -0,0 +1,10 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .generated_ops import * | |||
| from .misc_ops import * | |||
| @@ -0,0 +1,929 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import sys | |||
| from functools import reduce | |||
| from operator import or_ as _or_ | |||
| from types import DynamicClassAttribute, MappingProxyType | |||
| # try _collections first to reduce startup cost | |||
| try: | |||
| from _collections import OrderedDict | |||
| except ImportError: | |||
| from collections import OrderedDict | |||
| __all__ = [ | |||
| "EnumMeta", | |||
| "Enum", | |||
| "IntEnum", | |||
| "Flag", | |||
| "IntFlag", | |||
| "auto", | |||
| "unique", | |||
| ] | |||
| def _is_descriptor(obj): | |||
| """Returns True if obj is a descriptor, False otherwise.""" | |||
| return ( | |||
| hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") | |||
| ) | |||
| def _is_dunder(name): | |||
| """Returns True if a __dunder__ name, False otherwise.""" | |||
| return ( | |||
| name[:2] == name[-2:] == "__" | |||
| and name[2:3] != "_" | |||
| and name[-3:-2] != "_" | |||
| and len(name) > 4 | |||
| ) | |||
| def _is_sunder(name): | |||
| """Returns True if a _sunder_ name, False otherwise.""" | |||
| return ( | |||
| name[0] == name[-1] == "_" | |||
| and name[1:2] != "_" | |||
| and name[-2:-1] != "_" | |||
| and len(name) > 2 | |||
| ) | |||
| def _make_class_unpicklable(cls): | |||
| """Make the given class un-picklable.""" | |||
| def _break_on_call_reduce(self, proto): | |||
| raise TypeError("%r cannot be pickled" % self) | |||
| cls.__reduce_ex__ = _break_on_call_reduce | |||
| cls.__module__ = "<unknown>" | |||
# Sentinel meaning "this auto() has not been given a value yet".
_auto_null = object()


class auto:
    """
    Instances are replaced with an appropriate value in Enum class suites.
    """

    # Holds the sentinel until a real value is assigned.
    value = _auto_null
class _EnumDict(dict):
    """Track enum member order and ensure member names are not reused.

    EnumMeta will use the names found in self._member_names as the
    enumeration member names.

    """

    def __init__(self):
        super().__init__()
        # Member names in definition order.
        self._member_names = []
        # Member values in definition order; a copy is passed to
        # _generate_next_value.
        self._last_values = []

    def __setitem__(self, key, value):
        """Changes anything not dundered or not a descriptor.

        If an enum member name is used twice, an error is raised; duplicate
        values are not checked for.

        Single underscore (sunder) names are reserved.

        """
        if _is_sunder(key):
            # Only these sunder names are allowed.
            if key not in (
                "_order_",
                "_create_pseudo_member_",
                "_generate_next_value_",
                "_missing_",
            ):
                raise ValueError("_names_ are reserved for future Enum use")
            if key == "_generate_next_value_":
                # Also kept as an attribute so auto() handling below can call it.
                setattr(self, "_generate_next_value", value)
        elif _is_dunder(key):
            # "__order__" is stored under the sunder spelling.
            if key == "__order__":
                key = "_order_"
        elif key in self._member_names:
            # descriptor overwriting an enum?
            raise TypeError("Attempted to reuse key: %r" % key)
        elif not _is_descriptor(value):
            if key in self:
                # enum overwriting a descriptor?
                raise TypeError("%r already defined as: %r" % (key, self[key]))
            if isinstance(value, auto):
                # Fill in an auto() that has no value yet, then store the
                # plain value instead of the auto instance.
                if value.value == _auto_null:
                    value.value = self._generate_next_value(
                        key, 1, len(self._member_names), self._last_values[:]
                    )
                value = value.value
            self._member_names.append(key)
            self._last_values.append(value)
        super().__setitem__(key, value)
# Dummy value for Enum: EnumMeta explicitly checks for it, but until EnumMeta
# has finished running for the first time the real Enum class does not exist.
# This is also why EnumMeta contains checks like `if Enum is not None`.
Enum = None
class EnumMeta(type):
    """Metaclass for Enum"""

    @classmethod
    def __prepare__(metacls, cls, bases):
        """Return the ``_EnumDict`` namespace in which the class body runs."""
        # create the namespace dict
        enum_dict = _EnumDict()
        # inherit previous flags and _generate_next_value_ function
        member_type, first_enum = metacls._get_mixins_(bases)
        if first_enum is not None:
            enum_dict["_generate_next_value_"] = getattr(
                first_enum, "_generate_next_value_", None
            )
        return enum_dict

    def __new__(metacls, cls, bases, classdict):
        """Build the enum class and instantiate every member found in ``classdict``.

        Raises ValueError for an illegal member name and TypeError when a
        supplied ``_order_`` does not match the definition order.
        """
        # an Enum class is final once enumeration items have been defined; it
        # cannot be mixed with other types (int, float, etc.) if it has an
        # inherited __new__ unless a new __new__ is defined (or the resulting
        # class will fail).
        member_type, first_enum = metacls._get_mixins_(bases)
        __new__, save_new, use_args = metacls._find_new_(
            classdict, member_type, first_enum
        )

        # save enum items into separate mapping so they don't get baked into
        # the new class
        enum_members = {k: classdict[k] for k in classdict._member_names}
        for name in classdict._member_names:
            del classdict[name]

        # adjust the sunders
        _order_ = classdict.pop("_order_", None)

        # check for illegal enum names (any others?)
        invalid_names = set(enum_members) & {
            "mro",
        }
        if invalid_names:
            raise ValueError(
                "Invalid enum member name: {0}".format(",".join(invalid_names))
            )

        # create a default docstring if one has not been provided
        if "__doc__" not in classdict:
            classdict["__doc__"] = "An enumeration."

        # create our new Enum type
        enum_class = super().__new__(metacls, cls, bases, classdict)
        enum_class._member_names_ = []  # names in definition order
        enum_class._member_map_ = OrderedDict()  # name->value map
        enum_class._member_type_ = member_type

        # save attributes from super classes so we know if we can take
        # the shortcut of storing members in the class dict
        base_attributes = {a for b in enum_class.mro() for a in b.__dict__}

        # Reverse value->name map for hashable values.
        enum_class._value2member_map_ = {}

        # If a custom type is mixed into the Enum, and it does not know how
        # to pickle itself, pickle.dumps will succeed but pickle.loads will
        # fail.  Rather than have the error show up later and possibly far
        # from the source, sabotage the pickle protocol for this class so
        # that pickle.dumps also fails.
        #
        # However, if the new class implements its own __reduce_ex__, do not
        # sabotage -- it's on them to make sure it works correctly.  We use
        # __reduce_ex__ instead of any of the others as it is preferred by
        # pickle over __reduce__, and it handles all pickle protocols.
        if "__reduce_ex__" not in classdict:
            if member_type is not object:
                methods = (
                    "__getnewargs_ex__",
                    "__getnewargs__",
                    "__reduce_ex__",
                    "__reduce__",
                )
                if not any(m in member_type.__dict__ for m in methods):
                    _make_class_unpicklable(enum_class)

        # instantiate them, checking for duplicates as we go
        # we instantiate first instead of checking for duplicates first in case
        # a custom __new__ is doing something funky with the values -- such as
        # auto-numbering ;)
        for member_name in classdict._member_names:
            value = enum_members[member_name]
            if not isinstance(value, tuple):
                args = (value,)
            else:
                args = value
            if member_type is tuple:  # special case for tuple enums
                args = (args,)  # wrap it one more time
            if not use_args:
                enum_member = __new__(enum_class)
                if not hasattr(enum_member, "_value_"):
                    enum_member._value_ = value
            else:
                enum_member = __new__(enum_class, *args)
                if not hasattr(enum_member, "_value_"):
                    if member_type is object:
                        enum_member._value_ = value
                    else:
                        enum_member._value_ = member_type(*args)
            value = enum_member._value_
            enum_member._name_ = member_name
            enum_member.__objclass__ = enum_class
            enum_member.__init__(*args)
            # If another member with the same value was already defined, the
            # new member becomes an alias to the existing one.
            for name, canonical_member in enum_class._member_map_.items():
                if canonical_member._value_ == enum_member._value_:
                    enum_member = canonical_member
                    break
            else:
                # Aliases don't appear in member names (only in __members__).
                enum_class._member_names_.append(member_name)
            # performance boost for any member that would not shadow
            # a DynamicClassAttribute
            if member_name not in base_attributes:
                setattr(enum_class, member_name, enum_member)
            # now add to _member_map_
            enum_class._member_map_[member_name] = enum_member
            try:
                # This may fail if value is not hashable. We can't add the value
                # to the map, and by-value lookups for this value will be
                # linear.
                enum_class._value2member_map_[value] = enum_member
            except TypeError:
                pass

        # double check that repr and friends are not the mixin's or various
        # things break (such as pickle)
        for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"):
            class_method = getattr(enum_class, name)
            obj_method = getattr(member_type, name, None)
            enum_method = getattr(first_enum, name, None)
            if obj_method is not None and obj_method is class_method:
                setattr(enum_class, name, enum_method)

        # replace any other __new__ with our own (as long as Enum is not None,
        # anyway) -- again, this is to support pickle
        if Enum is not None:
            # if the user defined their own __new__, save it before it gets
            # clobbered in case they subclass later
            if save_new:
                enum_class.__new_member__ = __new__
            enum_class.__new__ = Enum.__new__

        # py3 support for definition order (helps keep py2/py3 code in sync)
        if _order_ is not None:
            if isinstance(_order_, str):
                _order_ = _order_.replace(",", " ").split()
            if _order_ != enum_class._member_names_:
                raise TypeError("member order does not match _order_")

        return enum_class

    def __bool__(self):
        """
        classes/types should always be True.
        """
        return True

    def __call__(
        cls, value, names=None, *, module=None, qualname=None, type=None, start=1
    ):
        """Either returns an existing member, or creates a new enum class.

        This method is used both when an enum class is given a value to match
        to an enumeration member (i.e. Color(3)) and for the functional API
        (i.e. Color = Enum('Color', names='RED GREEN BLUE')).

        When used for the functional API:

        `value` will be the name of the new class.

        `names` should be either a string of white-space/comma delimited names
        (values will start at `start`), or an iterator/mapping of name, value pairs.

        `module` should be set to the module this class is being created in;
        if it is not set, an attempt to find that module will be made, but if
        it fails the class will not be picklable.

        `qualname` should be set to the actual location this class can be found
        at in its module; by default it is set to the global scope.  If this is
        not correct, unpickling will fail in some circumstances.

        `type`, if set, will be mixed in as the first base class.
        """
        if names is None:  # simple value lookup
            return cls.__new__(cls, value)
        # otherwise, functional API: we're creating a new Enum type
        return cls._create_(
            value, names, module=module, qualname=qualname, type=type, start=start
        )

    def __contains__(cls, member):
        """Return True when ``member`` is an instance of ``cls`` whose name is registered."""
        return isinstance(member, cls) and member._name_ in cls._member_map_

    def __delattr__(cls, attr):
        """Delete a class attribute, refusing to delete enum members."""
        # nicer error message when someone tries to delete an attribute
        # (see issue19025).
        if attr in cls._member_map_:
            raise AttributeError("%s: cannot delete Enum member." % cls.__name__)
        super().__delattr__(attr)

    def __dir__(self):
        """List a fixed set of dunder names followed by the (non-alias) member names."""
        return [
            "__class__",
            "__doc__",
            "__members__",
            "__module__",
        ] + self._member_names_

    def __getattr__(cls, name):
        """Return the enum member matching `name`

        We use __getattr__ instead of descriptors or inserting into the enum
        class' __dict__ in order to support `name` and `value` being both
        properties for enum members (which live in the class' __dict__) and
        enum members themselves.
        """
        if _is_dunder(name):
            raise AttributeError(name)
        try:
            return cls._member_map_[name]
        except KeyError:
            raise AttributeError(name) from None

    def __getitem__(cls, name):
        """Look a member up by name; raises KeyError for an unknown name."""
        return cls._member_map_[name]

    def __iter__(cls):
        """Iterate over members in definition order (aliases excluded)."""
        return (cls._member_map_[name] for name in cls._member_names_)

    def __len__(cls):
        """Number of members, not counting aliases."""
        return len(cls._member_names_)

    @property
    def __members__(cls):
        """Returns a mapping of member name->value.

        This mapping lists all enum members, including aliases. Note that this
        is a read-only view of the internal mapping.
        """
        return MappingProxyType(cls._member_map_)

    def __repr__(cls):
        return "<enum %r>" % cls.__name__

    def __reversed__(cls):
        """Iterate over members in reverse definition order (aliases excluded)."""
        return (cls._member_map_[name] for name in reversed(cls._member_names_))

    def __setattr__(cls, name, value):
        """Block attempts to reassign Enum members.

        A simple assignment to the class namespace only changes one of the
        several possible ways to get an Enum member from the Enum class,
        resulting in an inconsistent Enumeration.
        """
        # Read from __dict__ with a default: while the class is still being
        # built ``_member_map_`` may not have been assigned yet.
        member_map = cls.__dict__.get("_member_map_", {})
        if name in member_map:
            raise AttributeError("Cannot reassign members.")
        super().__setattr__(name, value)

    def _create_(
        cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1
    ):
        """Convenience method to create a new Enum class.

        `names` can be:

        * A string containing member names, separated either with spaces or
          commas.  Values are incremented by 1 from `start`.
        * An iterable of member names.  Values are incremented by 1 from `start`.
        * An iterable of (member name, value) pairs.
        * A mapping of member name -> value pairs.
        """
        metacls = cls.__class__
        bases = (cls,) if type is None else (type, cls)
        _, first_enum = cls._get_mixins_(bases)
        classdict = metacls.__prepare__(class_name, bases)

        # special processing needed for names?
        if isinstance(names, str):
            names = names.replace(",", " ").split()
        if isinstance(names, (tuple, list)) and names and isinstance(names[0], str):
            original_names, names = names, []
            last_values = []
            for count, name in enumerate(original_names):
                value = first_enum._generate_next_value_(
                    name, start, count, last_values[:]
                )
                last_values.append(value)
                names.append((name, value))

        # Here, names is either an iterable of (name, value) or a mapping.
        for item in names:
            if isinstance(item, str):
                member_name, member_value = item, names[item]
            else:
                member_name, member_value = item
            classdict[member_name] = member_value
        enum_class = metacls.__new__(metacls, class_name, bases, classdict)

        # TODO: replace the frame hack if a blessed way to know the calling
        # module is ever developed
        if module is None:
            try:
                # NOTE(review): ``f_globals["__name__"]`` can raise KeyError,
                # which is not caught here; ``exc`` is bound but never used.
                module = sys._getframe(2).f_globals["__name__"]
            except (AttributeError, ValueError) as exc:
                pass
        if module is None:
            _make_class_unpicklable(enum_class)
        else:
            enum_class.__module__ = module
        if qualname is not None:
            enum_class.__qualname__ = qualname

        return enum_class

    @staticmethod
    def _get_mixins_(bases):
        """Returns the type for creating enum members, and the first inherited
        enum class.

        bases: the tuple of bases that was given to __new__
        """
        if not bases:
            return object, Enum

        # double check that we are not subclassing a class with existing
        # enumeration members; while we're at it, see if any other data
        # type has been mixed in so we can use the correct __new__
        member_type = first_enum = None
        for base in bases:
            if base is not Enum and issubclass(base, Enum) and base._member_names_:
                raise TypeError("Cannot extend enumerations")
        # base is now the last base in bases
        if not issubclass(base, Enum):
            raise TypeError(
                "new enumerations must be created as "
                "`ClassName([mixin_type,] enum_type)`"
            )

        # get correct mix-in type (either mix-in type of Enum subclass, or
        # first base if last base is Enum)
        if not issubclass(bases[0], Enum):
            member_type = bases[0]  # first data type
            first_enum = bases[-1]  # enum type
        else:
            for base in bases[0].__mro__:
                # most common: (IntEnum, int, Enum, object)
                # possible:    (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>,
                #               <class 'int'>, <Enum 'Enum'>,
                #               <class 'object'>)
                if issubclass(base, Enum):
                    if first_enum is None:
                        first_enum = base
                else:
                    if member_type is None:
                        member_type = base

        return member_type, first_enum

    @staticmethod
    def _find_new_(classdict, member_type, first_enum):
        """Returns the __new__ to be used for creating the enum members.

        classdict: the class dictionary given to __new__
        member_type: the data type whose __new__ will be used by default
        first_enum: enumeration to check for an overriding __new__

        Returns a ``(__new__, save_new, use_args)`` tuple.
        """
        # now find the correct __new__, checking to see of one was defined
        # by the user; also check earlier enum classes in case a __new__ was
        # saved as __new_member__
        __new__ = classdict.get("__new__", None)

        # should __new__ be saved as __new_member__ later?
        save_new = __new__ is not None

        if __new__ is None:
            # check all possibles for __new_member__ before falling back to
            # __new__
            for method in ("__new_member__", "__new__"):
                for possible in (member_type, first_enum):
                    target = getattr(possible, method, None)
                    if target not in {
                        None,
                        None.__new__,
                        object.__new__,
                        Enum.__new__,
                    }:
                        __new__ = target
                        break
                if __new__ is not None:
                    break
            else:
                # Neither lookup produced a usable override.
                __new__ = object.__new__

        # if a non-object.__new__ is used then whatever value/tuple was
        # assigned to the enum member name will be passed to __new__ and to the
        # new enum member's __init__
        if __new__ is object.__new__:
            use_args = False
        else:
            use_args = True

        return __new__, save_new, use_args
class Enum(metaclass=EnumMeta):
    """Generic enumeration.

    Derive from this class to define new enumerations.
    """

    def __new__(cls, value):
        """Return the existing member of ``cls`` whose value is ``value``.

        Falls back to ``cls._missing_(value)`` when no member matches.
        """
        # all enum instances are actually created during class construction
        # without calling this method; this method is called by the metaclass'
        # __call__ (i.e. Color(3) ), and by pickle
        if type(value) is cls:
            # For lookups like Color(Color.RED)
            return value
        # by-value search for a matching enum member
        # see if it's in the reverse mapping (for hashable values)
        try:
            if value in cls._value2member_map_:
                return cls._value2member_map_[value]
        except TypeError:
            # not there, now do long search -- O(n) behavior
            for member in cls._member_map_.values():
                if member._value_ == value:
                    return member
        # still not found -- try _missing_ hook
        return cls._missing_(value)

    def _generate_next_value_(name, start, count, last_values):
        """Return the most recent value that supports ``+ 1``, incremented.

        Returns ``start`` when ``last_values`` holds no such value.
        """
        for last_value in reversed(last_values):
            try:
                return last_value + 1
            except TypeError:
                pass
        else:
            return start

    @classmethod
    def _missing_(cls, value):
        """Hook called by ``__new__`` for an unknown value; raises ValueError here."""
        raise ValueError("%r is not a valid %s" % (value, cls.__name__))

    def __repr__(self):
        return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_)

    def __str__(self):
        return "%s.%s" % (self.__class__.__name__, self._name_)

    def __dir__(self):
        # Public (non-underscore) names defined anywhere in the MRO, minus
        # names that are enum members.
        added_behavior = [
            m
            for cls in self.__class__.mro()
            for m in cls.__dict__
            if m[0] != "_" and m not in self._member_map_
        ]
        return ["__class__", "__doc__", "__module__"] + added_behavior

    def __format__(self, format_spec):
        # mixed-in Enums should use the mixed-in type's __format__, otherwise
        # we can get strange results with the Enum name showing up instead of
        # the value

        # pure Enum branch
        if self._member_type_ is object:
            cls = str
            val = str(self)
        # mix-in branch
        else:
            cls = self._member_type_
            val = self._value_
        return cls.__format__(val, format_spec)

    def __hash__(self):
        return hash(self._name_)

    def __reduce_ex__(self, proto):
        # Pickle as "call the class with the member's value".
        return self.__class__, (self._value_,)

    # DynamicClassAttribute is used to provide access to the `name` and
    # `value` properties of enum members while keeping some measure of
    # protection from modification, while still allowing for an enumeration
    # to have members named `name` and `value`.  This works because enumeration
    # members are not set directly on the enum class -- __getattr__ is
    # used to look them up.

    @DynamicClassAttribute
    def name(self):
        """The name of the Enum member."""
        return self._name_

    @DynamicClassAttribute
    def value(self):
        """The value of the Enum member."""
        return self._value_

    @classmethod
    def _convert(cls, name, module, filter, source=None):
        """
        Create a new Enum subclass that replaces a collection of global constants
        """
        # convert all constants from source (or module) that pass filter() to
        # a new Enum called name, and export the enum and its members back to
        # module;
        # also, replace the __reduce_ex__ method so unpickling works in
        # previous Python versions
        module_globals = vars(sys.modules[module])
        if source:
            source = vars(source)
        else:
            source = module_globals
        # We build a sorted list of (name, value) pairs so that the
        # _value2member_map is populated in the same order every time
        # for a consistent reverse mapping of number to name when there
        # are multiple names for the same number rather than varying
        # between runs due to hash randomization of the module dictionary.
        members = [(name, source[name]) for name in source.keys() if filter(name)]
        try:
            # sort by value
            members.sort(key=lambda t: (t[1], t[0]))
        except TypeError:
            # unless some values aren't comparable, in which case sort by name
            members.sort(key=lambda t: t[0])
        cls = cls(name, members, module=module)
        cls.__reduce_ex__ = _reduce_ex_by_name
        module_globals.update(cls.__members__)
        module_globals[name] = cls
        return cls
class IntEnum(int, Enum):
    """Enum where members are also (and must be) ints"""

    # ``int`` is listed first so that EnumMeta._get_mixins_ picks it as the
    # member (mix-in) type.
def _reduce_ex_by_name(self, proto):
    """``__reduce_ex__`` replacement that reduces a member to its ``name``.

    Installed by ``Enum._convert``; ``proto`` is accepted but ignored.
    """
    return self.name
class Flag(Enum):
    """Support for flags"""

    def _generate_next_value_(name, start, count, last_values):
        """
        Generate the next value when not given.

        name: the name of the member
        start: the initial start value or None
        count: the number of existing members
        last_values: the values assigned so far (most recent last)
        """
        if not count:
            return start if start is not None else 1
        # NOTE(review): if ``last_values`` is empty while ``count`` is non-zero
        # the loop body never runs and ``high_bit`` is unbound below.
        for last_value in reversed(last_values):
            try:
                high_bit = _high_bit(last_value)
                break
            except Exception:
                raise TypeError("Invalid Flag value: %r" % last_value) from None
        return 2 ** (high_bit + 1)

    @classmethod
    def _missing_(cls, value):
        """Build (or fetch) the composite member for a value with no named member."""
        original_value = value
        if value < 0:
            # Work on the complement, then invert the resulting member back.
            value = ~value
        possible_member = cls._create_pseudo_member_(value)
        if original_value < 0:
            possible_member = ~possible_member
        return possible_member

    @classmethod
    def _create_pseudo_member_(cls, value):
        """
        Create a composite member iff value contains only members.
        """
        pseudo_member = cls._value2member_map_.get(value, None)
        if pseudo_member is None:
            # verify all bits are accounted for
            _, extra_flags = _decompose(cls, value)
            if extra_flags:
                raise ValueError("%r is not a valid %s" % (value, cls.__name__))
            # construct a singleton enum pseudo-member
            pseudo_member = object.__new__(cls)
            pseudo_member._name_ = None
            pseudo_member._value_ = value
            # use setdefault in case another thread already created a composite
            # with this value
            pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
        return pseudo_member

    def __contains__(self, other):
        """True when every bit set in ``other`` is also set in ``self``."""
        if not isinstance(other, self.__class__):
            return NotImplemented
        return other._value_ & self._value_ == other._value_

    def __repr__(self):
        cls = self.__class__
        if self._name_ is not None:
            return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_)
        # Unnamed composite: show the members it decomposes into.
        members, uncovered = _decompose(cls, self._value_)
        return "<%s.%s: %r>" % (
            cls.__name__,
            "|".join([str(m._name_ or m._value_) for m in members]),
            self._value_,
        )

    def __str__(self):
        cls = self.__class__
        if self._name_ is not None:
            return "%s.%s" % (cls.__name__, self._name_)
        members, uncovered = _decompose(cls, self._value_)
        if len(members) == 1 and members[0]._name_ is None:
            return "%s.%r" % (cls.__name__, members[0]._value_)
        else:
            return "%s.%s" % (
                cls.__name__,
                "|".join([str(m._name_ or m._value_) for m in members]),
            )

    def __bool__(self):
        return bool(self._value_)

    def __or__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ | other._value_)

    def __and__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ & other._value_)

    def __xor__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ ^ other._value_)

    def __invert__(self):
        """Return the union of all members sharing no bits with ``self``."""
        members, uncovered = _decompose(self.__class__, self._value_)
        inverted_members = [
            m
            for m in self.__class__
            if m not in members and not m._value_ & self._value_
        ]
        inverted = reduce(_or_, inverted_members, self.__class__(0))
        return self.__class__(inverted)
class IntFlag(int, Flag):
    """Support for integer-based Flags"""

    @classmethod
    def _missing_(cls, value):
        """Create a pseudo-member for any int value; reject non-ints with ValueError."""
        if not isinstance(value, int):
            raise ValueError("%r is not a valid %s" % (value, cls.__name__))
        new_member = cls._create_pseudo_member_(value)
        return new_member

    @classmethod
    def _create_pseudo_member_(cls, value):
        """Return the member for ``value``, creating pseudo-members as needed.

        Unlike ``Flag``, bits not covered by a named member are not an error:
        a pseudo-member is created for each such bit as well.
        """
        pseudo_member = cls._value2member_map_.get(value, None)
        if pseudo_member is None:
            need_to_create = [value]
            # get unaccounted for bits
            _, extra_flags = _decompose(cls, value)
            while extra_flags:
                bit = _high_bit(extra_flags)
                flag_value = 2 ** bit
                if (
                    flag_value not in cls._value2member_map_
                    and flag_value not in need_to_create
                ):
                    need_to_create.append(flag_value)
                if extra_flags == -flag_value:
                    # Only a negative power of two is left: stop here.
                    extra_flags = 0
                else:
                    extra_flags ^= flag_value
            # Create in reverse so that ``value`` itself is processed last and
            # is the member left in ``pseudo_member``.
            for value in reversed(need_to_create):
                # construct singleton pseudo-members
                pseudo_member = int.__new__(cls, value)
                pseudo_member._name_ = None
                pseudo_member._value_ = value
                # use setdefault in case another thread already created a composite
                # with this value
                pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
        return pseudo_member

    def __or__(self, other):
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        result = self.__class__(self._value_ | self.__class__(other)._value_)
        return result

    def __and__(self, other):
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        return self.__class__(self._value_ & self.__class__(other)._value_)

    def __xor__(self, other):
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        return self.__class__(self._value_ ^ self.__class__(other)._value_)

    # The reflected forms reuse the same implementations.
    __ror__ = __or__
    __rand__ = __and__
    __rxor__ = __xor__

    def __invert__(self):
        result = self.__class__(~self._value_)
        return result
def _high_bit(value):
    """Return the index of the highest set bit, or -1 if value is zero.

    The result comes from ``value.bit_length()``, which ignores the sign, so
    a negative value yields the highest bit of its magnitude (not -1).
    """
    return value.bit_length() - 1
def unique(enumeration):
    """Class decorator for enumerations ensuring unique member values.

    An entry of ``__members__`` whose key differs from the member's own
    ``name`` is an alias.  If any alias exists a ValueError listing every
    ``alias -> name`` pair is raised; otherwise the class is returned as is.
    """
    aliases = [
        (alias, member.name)
        for alias, member in enumeration.__members__.items()
        if alias != member.name
    ]
    if not aliases:
        return enumeration
    alias_details = ", ".join("%s -> %s" % pair for pair in aliases)
    raise ValueError(
        "duplicate values found in %r: %s" % (enumeration, alias_details)
    )
| def _decompose(flag, value): | |||
| """Extract all members from the value.""" | |||
| # _decompose is only called if the value is not named | |||
| not_covered = value | |||
| negative = value < 0 | |||
| # issue29167: wrap accesses to _value2member_map_ in a list to avoid race | |||
| # conditions between iterating over it and having more psuedo- | |||
| # members added to it | |||
| if negative: | |||
| # only check for named flags | |||
| flags_to_check = [ | |||
| (m, v) | |||
| for v, m in list(flag._value2member_map_.items()) | |||
| if m.name is not None | |||
| ] | |||
| else: | |||
| # check for named flags and powers-of-two flags | |||
| flags_to_check = [ | |||
| (m, v) | |||
| for v, m in list(flag._value2member_map_.items()) | |||
| if m.name is not None or _power_of_two(v) | |||
| ] | |||
| members = [] | |||
| for member, member_value in flags_to_check: | |||
| if member_value and member_value & value == member_value: | |||
| members.append(member) | |||
| not_covered &= ~member_value | |||
| if not members and value in flag._value2member_map_: | |||
| members.append(flag._value2member_map_[value]) | |||
| members.sort(key=lambda m: m._value_, reverse=True) | |||
| if len(members) > 1 and members[0].value == value: | |||
| # we have the breakdown, don't need the value member itself | |||
| members.pop(0) | |||
| return members, not_covered | |||
def _power_of_two(value):
    """Return whether ``value`` is a positive integral power of two.

    Values below 1 are rejected before the highest set bit is looked at.
    """
    return not value < 1 and value == 2 ** _high_bit(value)
| @@ -0,0 +1,94 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import warnings | |||
| from ..._imperative_rt.ops import OprAttr | |||
| from . import param_defs | |||
def make_param(param, ptype, kwargs):
    """Coerce ``param`` / ``kwargs`` into an instance of ``ptype``.

    * ``param`` already a ``ptype`` instance: returned unchanged.
    * any other non-None ``param``: used as the single positional argument of
      ``ptype``; ``ptype`` must then declare exactly one slot.
    * ``param`` is None: every key of ``kwargs`` that names a slot of
      ``ptype`` is popped from ``kwargs`` and passed as a keyword argument.

    ``kwargs`` is modified in place; entries that do not match a slot are
    left for the caller.
    """
    if param is None:
        slot_kwargs = {}
        for slot in ptype.__slots__:
            if slot in kwargs:
                slot_kwargs[slot] = kwargs.pop(slot)
        return ptype(**slot_kwargs)

    if isinstance(param, ptype):
        return param

    param = [param]
    assert len(param) == len(
        ptype.__slots__
    ), "{} needs {} params, but {} are provided".format(
        ptype, len(ptype.__slots__), len(param)
    )
    return ptype(*param)
class PodOpVisitor:
    """Base class for Python-side descriptions of builtin operators.

    Subclasses set ``name`` (the operator type string) and ``param_names``
    (the instance attributes holding the operator's parameter objects).
    ``to_c`` turns an instance into an ``OprAttr``.

    Instances compare equal when their ``to_c()`` results compare equal.
    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable.
    """

    # name -> most recently defined subclass using that name; filled in by
    # __init_subclass__.
    __name2subclass = {}
    # Cached result of to_c(); None until it has been computed.
    __c = None

    name = None
    param_names = []
    config = None

    def __init__(self, config, **params):
        """Store ``config`` and one attribute per entry of ``param_names``.

        The keyword arguments must be exactly the names in ``param_names``.
        """
        self.config = config
        assert set(params) == set(self.param_names)
        self.__dict__.update(params)

    def __init_subclass__(cls, **kwargs):
        """Register ``cls`` under ``cls.name``.

        Warns when the name is already taken by a class that ``cls`` does not
        derive from; the newest class replaces the previous entry.
        """
        super().__init_subclass__(**kwargs)  # python 3.5 does not have this
        name = cls.name
        if name in cls.__name2subclass:
            if not issubclass(cls, cls.__name2subclass[name]):
                warnings.warn("Multiple subclasses for bultin op: %s" % name)
        cls.__name2subclass[name] = cls

    def to_c(self):
        """Return the ``OprAttr`` for this op, building and caching it on first use."""
        if self.__c:
            return self.__c
        op = OprAttr()
        op.type = self.name
        if self.config is not None:
            op.config = self.config
        # first 4 bytes is TAG, has to remove them currently
        op.param = b"".join(self.__dict__[k].serialize()[4:] for k in self.param_names)
        self.__c = op
        return op

    def __eq__(self, rhs):
        """Compare by ``to_c()`` result.

        Returns ``NotImplemented`` for objects without a ``to_c`` attribute,
        so comparisons such as ``op == None`` evaluate to False instead of
        raising AttributeError.
        """
        rhs_to_c = getattr(rhs, "to_c", None)
        if rhs_to_c is None:
            return NotImplemented
        return self.to_c() == rhs_to_c()

    def __repr__(self):
        name = self.__class__.__name__

        # Once to_c() has run, only the cached form is reported.
        if self.__c:
            return "{}(<binary data>)".format(name)

        kwargs = {}
        for i in self.param_names:
            p = self.__dict__[i]
            if isinstance(p, param_defs._ParamDefBase):
                # Flatten a parameter definition into its individual fields.
                for k in p.__slots__:
                    v = getattr(p, k)
                    if isinstance(v, param_defs._EnumBase):
                        v = v.name
                    kwargs[k] = repr(v)
            else:
                kwargs[i] = repr(p)
        if self.config:
            if len(self.config.comp_node_arr) == 1:
                kwargs["device"] = "'%s'" % self.config.comp_node
        return "{}({})".format(
            name, ", ".join("{}={}".format(k, v) for k, v in kwargs.items())
        )
| @@ -0,0 +1,194 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import ctypes | |||
| from ..._imperative_rt import OperatorNodeConfig as Config | |||
| from . import param_defs | |||
| from .helper import PodOpVisitor, make_param | |||
| __all__ = ["ConvolutionBackwardData", "Dimshuffle", "Reshape", "AxisAddRemove"] | |||
class TensorShape:
    """Namespace for shape limits used to size the fixed-length ctypes arrays in this module."""

    # maximum number of dimensions
    MAX_NDIM = 7
class ConvolutionBackwardData(PodOpVisitor):
    """
    Operator description for ``ConvolutionBackwardDataV1``.

    Carries a ``param_defs.Convolution`` struct (``param``) and a
    ``param_defs.ExecutionPolicy`` struct (``execution_polity``).
    """

    name = "ConvolutionBackwardDataV1"
    param_names = ("param", "execution_polity")

    def __init__(
        self,
        *,
        param=None,
        execution_polity=None,
        name=None,
        comp_node=None,
        config=None,
        dtype=None,
        **kwargs
    ):
        """
        :param param: passed to ``make_param`` together with ``param_defs.Convolution``.
        :param execution_polity: passed to ``make_param`` together with
            ``param_defs.ExecutionPolicy``.
        :param name: if truthy, stored as ``config.name``.
        :param comp_node: if truthy, stored as ``config.comp_node``.
        :param config: existing config to update; a new ``Config()`` when falsy.
        :param dtype: if truthy, stored as ``config.dtype``.
        :param kwargs: extra fields handed to ``make_param``; anything left over
            afterwards trips the final assertion.
        """
        if not config:
            config = Config()
        # only truthy overrides are copied onto the config
        overrides = (("name", name), ("comp_node", comp_node), ("dtype", dtype))
        for attr, override in overrides:
            if override:
                setattr(config, attr, override)
        self.config = config

        # order matters: both calls share (and may consume from) the same kwargs dict
        self.param = make_param(param, param_defs.Convolution, kwargs)
        self.execution_polity = make_param(
            execution_polity, param_defs.ExecutionPolicy, kwargs
        )
        assert not kwargs, "extra kwargs: {}".format(kwargs)
class Dimshuffle(PodOpVisitor):
    """Operator description for ``Dimshuffle``; its only parameter is a ``Pattern`` struct."""

    name = "Dimshuffle"
    param_names = ("pattern",)

    class Pattern(ctypes.Structure):
        """Fixed-size ctypes layout: entry count, the entries, and ``ndim``."""

        Pattern_Array = ctypes.c_int32 * TensorShape.MAX_NDIM
        _fields_ = [
            ("length", ctypes.c_uint32),
            ("pattern", Pattern_Array),
            ("ndim", ctypes.c_uint32),
        ]

        def serialize(self):
            """Return the raw struct bytes prefixed by a 4-byte zero ``c_uint32``."""
            return bytes(ctypes.c_uint32(0)) + bytes(self)

    def __init__(self, pattern, ndim=0):
        """
        :param pattern: iterable of at most ``TensorShape.MAX_NDIM`` entries; the
            string ``"x"`` is stored as ``-1``, anything else as ``int(entry)``.
        :param ndim: value stored in the struct's ``ndim`` field.
        """
        # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc
        from collections.abc import Iterable

        assert isinstance(pattern, Iterable)
        # materialize once so iterables without __len__ (e.g. generators) also work
        pattern = tuple(pattern)
        assert len(pattern) <= TensorShape.MAX_NDIM
        pattern_array = Dimshuffle.Pattern.Pattern_Array()
        for idx, v in enumerate(pattern):
            pattern_array[idx] = ctypes.c_int32(-1 if v == "x" else int(v))
        self.pattern = Dimshuffle.Pattern(len(pattern), pattern_array, ndim)
class Reshape(PodOpVisitor):
    """Operator description for ``ReshapeV1``; its only parameter is an ``OptionalAxisV1``."""

    name = "ReshapeV1"
    param_names = ("unspec_axis",)

    def __init__(self, unspec_axis=None):
        """
        :param unspec_axis: forwarded to ``param_defs.OptionalAxisV1``; when None
            the struct is built with no arguments.
        """
        axis_args = () if unspec_axis is None else (unspec_axis,)
        self.unspec_axis = param_defs.OptionalAxisV1(*axis_args)
class AxisNum(ctypes.Structure):
    """ctypes struct with a single C ``int`` field, ``m_num``."""

    _fields_ = [("m_num", ctypes.c_int)]
class AxisDesc(ctypes.Structure):
    """ctypes struct pairing a ``Method`` code with an ``AxisNum``."""

    class Method(ctypes.c_int):
        """Integer codes for the two supported axis operations."""

        ADD_1 = 0
        REMOVE = 1

    _fields_ = [("method", Method), ("axis", AxisNum)]

    @classmethod
    def make_add(cls, axis):
        """Return a descriptor with method ``ADD_1`` for ``axis``."""
        target = AxisNum(axis)
        return cls(cls.Method.ADD_1, target)

    @classmethod
    def make_remove(cls, axis):
        """Return a descriptor with method ``REMOVE`` for ``axis``."""
        target = AxisNum(axis)
        return cls(cls.Method.REMOVE, target)
class AxisAddRemove(PodOpVisitor):
    """Operator description for ``AxisAddRemove``; its only parameter is a ``Param`` struct."""

    name = "AxisAddRemove"
    param_names = ("param",)

    # expose the module-level descriptor struct as a class attribute
    AxisDesc = AxisDesc

    class Param(ctypes.Structure):
        """Descriptor count followed by a fixed-size array of ``AxisDesc``."""

        MAX_DESC_SIZE = TensorShape.MAX_NDIM * 2
        _fields_ = [
            ("nr_desc", ctypes.c_uint32),
            ("desc", AxisDesc * MAX_DESC_SIZE),
        ]

        def __init__(self, *args):
            """Store ``args`` in order into ``desc`` and record their count in ``nr_desc``."""
            super().__init__()
            self.nr_desc = len(args)
            for slot, descriptor in enumerate(args):
                self.desc[slot] = descriptor

        def serialize(self):
            """Return the raw struct bytes prefixed by a 4-byte zero ``c_uint32``."""
            tag = bytes(ctypes.c_uint32(0))
            return tag + bytes(self)

    def __init__(self, param):
        """
        :param param: an ``AxisAddRemove.Param`` instance (asserted).
        """
        assert isinstance(param, self.Param)
        self.param = param
| del AxisDesc | |||
class IndexingOpBase(PodOpVisitor):
    """Shared base for the indexing operators; its only parameter is an ``IndexDescMaskDump``."""

    param_names = ("index_desc",)

    class IndexDescMaskDump(ctypes.Structure):
        """Item count followed by a fixed-size array of ``Item`` structs."""

        class Item(ctypes.Structure):
            """One axis number plus four boolean flags (begin, end, step, idx)."""

            _fields_ = [
                ("axis", ctypes.c_int8),
                ("begin", ctypes.c_bool),
                ("end", ctypes.c_bool),
                ("step", ctypes.c_bool),
                ("idx", ctypes.c_bool),
            ]

        Item_Array = Item * TensorShape.MAX_NDIM
        _fields_ = [("nr_item", ctypes.c_uint8), ("items", Item_Array)]

        def serialize(self):
            """Return the raw struct bytes prefixed by a 4-byte zero ``c_uint32``."""
            tag = bytes(ctypes.c_uint32(0))
            return tag + bytes(self)

    def __init__(self, items):
        """
        :param items: sequence of at most ``TensorShape.MAX_NDIM`` entries, each a
            5-element tuple or list unpacked into an ``Item``.
        """
        dump_cls = IndexingOpBase.IndexDescMaskDump
        count = len(items)
        assert count <= TensorShape.MAX_NDIM
        packed = dump_cls.Item_Array()
        for position, fields in enumerate(items):
            assert isinstance(fields, (tuple, list)) and len(fields) == 5
            packed[position] = dump_cls.Item(*fields)
        self.index_desc = dump_cls(count, packed)
def _gen_indexing_defs(*names):
    """
    Create one ``IndexingOpBase`` subclass per name, publish it as a module
    global and append the name to ``__all__``.

    :param names: operator names; each is used both as the class name and as
        the class's ``name`` attribute.
    """
    for op_name in names:
        op_cls = type(op_name, (IndexingOpBase,), {"name": op_name})
        globals()[op_name] = op_cls
        __all__.append(op_name)


_gen_indexing_defs(
    "Subtensor",
    "SetSubtensor",
    "IncrSubtensor",
    "IndexingMultiAxisVec",
    "IndexingSetMultiAxisVec",
    "IndexingIncrMultiAxisVec",
    "MeshIndexing",
    "IncrMeshIndexing",
    "SetMeshIndexing",
    "BatchedMeshIndexing",
    "BatchedIncrMeshIndexing",
    "BatchedSetMeshIndexing",
)
| @@ -0,0 +1,37 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import warnings | |||
| from typing import Union | |||
| from ..._imperative_rt import OpDef, ops | |||
| from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
| from .._internal import all_ops | |||
| from .._internal.helper import PodOpVisitor | |||
| # register OpDef as a "virtual subclass" of OpBase, so any of registered | |||
| # apply(OpBase, ...) rules could work well on OpDef | |||
| OpBase.register(OpDef) | |||
| # forward to apply(OpDef, ...) | |||
@apply.add
def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]):
    """Convert the ``PodOpVisitor`` with ``to_c()`` and re-dispatch ``apply`` on the result."""
    c_op = op.to_c()
    return apply(c_op, *args)
__all__ = ["OpDef", "PodOpVisitor"]

# Re-export every PodOpVisitor subclass found in all_ops, then every OpDef
# subclass found in ops, and list each exported name in __all__.
for _namespace, _base in ((all_ops, PodOpVisitor), (ops, OpDef)):
    for k, v in _namespace.__dict__.items():
        if isinstance(v, type) and issubclass(v, _base):
            globals()[k] = v
            __all__.append(k)
| @@ -0,0 +1,16 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..tensor.core import OpBase, TensorBase, apply | |||
class Const(OpBase):
    """Operator description that carries a constant value plus an optional dtype and device."""

    def __init__(self, value=None, *, dtype=None, device=None):
        """
        :param value: stored as ``self.value``.
        :param dtype: stored as ``self.dtype``.
        :param device: stored as ``self.device``.
        """
        self.value, self.dtype, self.device = value, dtype, device
| @@ -0,0 +1,9 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .tensor_wrapper import TensorWrapper as Tensor | |||
| @@ -0,0 +1,115 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import functools | |||
| import inspect | |||
| import sys | |||
| import typing | |||
| from abc import ABC | |||
| import multipledispatch | |||
class OpBase(ABC):
    """Abstract base for operator objects; calling an instance dispatches through ``apply``."""

    def __call__(self, *args):
        """Return ``apply(self, *args)``."""
        result = apply(self, *args)
        return result
class TensorBase:
    """Marker base class with no members; used as an argument type for ``apply`` overloads."""
class TensorWrapperBase:
    """Marker base class with no members; used as an argument type for ``apply`` overloads."""
class Dispatcher(multipledispatch.Dispatcher):
    """``multipledispatch.Dispatcher`` usable as a decorator and as a method descriptor."""

    def add(self, f, g=None):
        """
        Register an implementation.

        :param f: with ``g`` omitted, the function to register (its signature is
            derived by ``get_signature``); otherwise the signature itself.
        :param g: the function to register when ``f`` is a signature.
        :return: ``f``, so ``@dispatcher.add`` can be used as a decorator.
        """
        if g is None:
            super().add(get_signature(f), f)
        else:
            super().add(f, g)

        return f

    def __get__(self, instance, owner=None):
        """
        Descriptor hook: class-level access yields the dispatcher itself,
        instance-level access yields it with ``instance`` bound as first argument.
        """
        # Accessed on the class: nothing to bind. (The previous check was
        # inverted, binding None on class access and nothing on instance access.)
        if instance is None:
            return self
        return functools.partial(self, instance)
# ``parse_union(ann)`` returns the member types of a ``typing.Union``
# annotation, or None for anything else. The typing internals differ between
# Python versions, so one implementation is selected at import time.
if sys.version_info < (3, 6):

    def parse_union(ann):
        """Return the union members of ``ann`` or None (Python < 3.6)."""
        if type(ann) is typing.UnionMeta:
            return ann.__union_params__
        return None


elif sys.version_info < (3, 7):

    def parse_union(ann):
        """Return the union members of ``ann`` or None (Python 3.6)."""
        if type(ann) is typing._Union:
            return ann.__args__
        return None


elif sys.version_info < (3, 8):

    def parse_union(ann):
        """Return the union members of ``ann`` or None (Python 3.7)."""
        if type(ann) is typing._GenericAlias:
            if ann.__origin__ is not typing.Union:
                return None
        elif type(ann) is not typing.Union:
            return None
        return ann.__args__


else:

    def parse_union(ann):
        """Return the union members of ``ann`` or None (Python >= 3.8)."""
        if typing.get_origin(ann) is typing.Union:
            return typing.get_args(ann)
        return None
def get_signature(function, op_type=None):
    """
    Build a dispatch signature from the annotations of ``function``.

    :param function: callable whose parameters are inspected.
    :param op_type: unused.
    :return: tuple with one entry per positional parameter (its annotation, or
        the member tuple when the annotation is a ``Union``); a ``*args``
        parameter contributes a one-element list holding its annotation.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    collected = []
    for param in inspect.signature(function).parameters.values():
        annotation = param.annotation
        annotation = parse_union(annotation) or annotation
        if param.kind in positional_kinds:
            collected.append(annotation)
        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
            collected.append([annotation])
    return tuple(collected)
| apply = Dispatcher("apply") | |||
| OpBase.apply = apply | |||
@apply.add
def _(op: OpBase, *args: TensorBase):
    """Overload for ``TensorBase`` arguments: always raises ``NotImplementedError``."""
    raise NotImplementedError
@apply.add
def _(op: OpBase, *args: TensorWrapperBase):
    """
    Unwrap every argument through ``__wrapped__``, re-dispatch ``apply``, and
    wrap each output with the type of the first argument.
    """
    assert args
    wrapper_cls = type(args[0])
    unwrapped = (arg.__wrapped__ for arg in args)
    outputs = apply(op, *unwrapped)
    assert isinstance(outputs, tuple)
    return tuple(wrapper_cls(out) for out in outputs)
| @@ -0,0 +1,289 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from typing import Union | |||
| import numpy as np | |||
| # normal dtype related | |||
| from .._imperative_rt import bfloat16, intb1, intb2, intb4 | |||
def is_lowbit(dtype):
    """Return True when ``dtype`` is one of the ``intb1``, ``intb2`` or ``intb4`` objects."""
    return any(dtype is lowbit for lowbit in (intb1, intb2, intb4))
def is_bfloat16(dtype):
    """Return True when ``dtype`` is the ``bfloat16`` object."""
    matches = dtype is bfloat16
    return matches
| # quantization dtype related | |||
| _QuantDtypeMetadata = collections.namedtuple( | |||
| "QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] | |||
| ) | |||
| _metadata_dict = { | |||
| "quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), | |||
| "qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), | |||
| "quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), | |||
| "qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), | |||
| "qint32": _QuantDtypeMetadata( | |||
| "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, | |||
| ), | |||
| # NOTE: int2 is not supported for model dump yet | |||
| "quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3), | |||
| "qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1), | |||
| } | |||
def is_quantize(dtype):
    """Return True when ``dtype`` has a non-None ``metadata`` containing the key ``"mgb_dtype"``."""
    meta = getattr(dtype, "metadata", None)
    return meta is not None and "mgb_dtype" in meta
def get_scale(dtype):
    """Return the ``scale`` stored in the ``mgb_dtype`` metadata of a quantized dtype (asserted)."""
    assert is_quantize(dtype)
    mgb_meta = dtype.metadata["mgb_dtype"]
    return mgb_meta["scale"]
def get_zero_point(dtype):
    """
    Return the ``zero_point`` stored in the ``mgb_dtype`` metadata.

    Asserts that ``dtype`` is quantized and that its metadata name is
    ``Quantized8Asymm`` or ``Quantized4Asymm``.
    """
    assert is_quantize(dtype)
    mgb_meta = dtype.metadata["mgb_dtype"]
    asymmetric_names = ("Quantized8Asymm", "Quantized4Asymm")
    assert mgb_meta["name"] in asymmetric_names
    return mgb_meta["zero_point"]
def _check_zero_point(zp: int, dtype_str: str):
    """
    Raise ``ValueError`` unless ``zp`` lies within the ``[qmin, qmax]`` range
    registered for ``dtype_str`` in ``_metadata_dict``.
    """
    info = _metadata_dict[dtype_str]
    lower, upper = info.qmin, info.qmax
    out_of_range = zp < lower or zp > upper
    if out_of_range:
        raise ValueError(
            "zero_point should be within [{}, {}] for {}".format(lower, upper, dtype_str)
        )
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
    r"""
    Build a numpy dtype carrying ``mgb_dtype`` metadata as described by
    ``_metadata_dict``.

    Unsigned dtypes require an integral ``zp``, which is range-checked and
    stored as ``zero_point``. For signed dtypes ``zp`` is not used and no
    ``zero_point`` entry is written.

    :param dtype_str: key into ``_metadata_dict`` selecting the dtype to build.
    :param scale: number stored (as ``float``) under ``scale``.
    :param zp: number stored under ``zero_point`` for unsigned dtypes.
    :raises ValueError: for unsigned dtypes, when ``zp`` is None, not
        integer-valued, or outside the dtype's range.
    """
    info = _metadata_dict[dtype_str]
    zero_point_entry = {}
    if info.is_unsigned:
        if zp is None or int(zp) != zp:
            raise ValueError("zero_point should be an integer")
        zp = int(zp)
        _check_zero_point(zp, dtype_str)
        zero_point_entry["zero_point"] = zp
    # built after the zero-point checks so the validation order is unchanged
    mgb_meta = {"name": info.name, "scale": float(scale)}
    mgb_meta.update(zero_point_entry)
    return np.dtype(info.np_dtype_str, metadata={"mgb_dtype": mgb_meta})
def quint8(scale, zero_point):
    """
    Construct a quantized unsigned int8 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint8 data type is
    float_val = scale * (uint8_val - zero_point)
    """
    quantized = get_quantized_dtype("quint8", scale, zero_point)
    return quantized
def qint8(scale):
    """
    Construct a quantized int8 data type with ``scale`` (float). The real value
    represented by a qint8 data type is float_val = scale * int8_val
    """
    quantized = get_quantized_dtype("qint8", scale, None)
    return quantized
def qint32(scale):
    """
    Construct a quantized int32 data type with ``scale`` (float). The real value
    represented by a qint32 data type is float_val = scale * int32_val
    """
    quantized = get_quantized_dtype("qint32", scale, None)
    return quantized
def quint4(scale, zero_point):
    """
    Construct a quantized unsigned int4 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint4 data type is
    float_val = scale * (uint4_val - zero_point)
    """
    quantized = get_quantized_dtype("quint4", scale, zero_point)
    return quantized
def qint4(scale):
    """
    Construct a quantized int4 data type with ``scale`` (float). The real value
    represented by a qint4 data type is float_val = scale * int4_val
    """
    quantized = get_quantized_dtype("qint4", scale, None)
    return quantized
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
    """
    Quantize ``arr`` into ``dtype``.

    Values are divided by the dtype's scale, rounded, shifted by the zero point
    (unsigned dtypes only), clipped to the range registered for ``dtype_str``
    and cast to ``dtype``.

    :param arr: input ndarray.
    :param dtype: target dtype; must be the quantized dtype named by ``dtype_str``.
    :param dtype_str: key into ``_metadata_dict``.
    :raises ValueError: if ``arr`` is not an ndarray or ``dtype`` is not a
        matching quantized dtype.
    """
    metadata = _metadata_dict[dtype_str]
    if not isinstance(arr, np.ndarray):
        raise ValueError("arr parameter should be instance of np.ndarray")
    # Validate before reading dtype.metadata["mgb_dtype"]: a plain dtype has
    # metadata None, which used to surface as a TypeError instead.
    if (
        not is_quantize(dtype)
        or dtype.metadata["mgb_dtype"]["name"] != metadata.name
    ):
        raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
    arr_metadata = dtype.metadata["mgb_dtype"]
    is_unsigned = metadata.is_unsigned
    if is_unsigned:
        scale, zp = (
            arr_metadata["scale"],
            arr_metadata["zero_point"],
        )
        return (
            (np.round(arr / scale) + zp)
            .clip(metadata.qmin, metadata.qmax)
            .astype(dtype)
        )
    else:
        # signed dtypes carry no "zero_point" entry (see get_quantized_dtype),
        # so the two branches must not be merged into one lookup
        scale = arr_metadata["scale"]
        return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype)
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str):
    """
    Dequantize ``arr`` into float32 values.

    The array is cast to ``np.float32``, the zero point is subtracted (unsigned
    dtypes only) and the result is multiplied by the scale.

    :param arr: ndarray whose dtype is the quantized dtype named by ``dtype_str``.
    :param dtype_str: key into ``_metadata_dict``.
    :raises ValueError: if ``arr`` is not an ndarray or its dtype is not a
        matching quantized dtype.
    """
    metadata = _metadata_dict[dtype_str]
    # Check the type first: reading arr.dtype.metadata on a non-ndarray or on a
    # plain dtype used to fail with AttributeError/TypeError before these checks.
    if not isinstance(arr, np.ndarray):
        raise ValueError("arr parameter should be instance of np.ndarray")
    if (
        not is_quantize(arr.dtype)
        or arr.dtype.metadata["mgb_dtype"]["name"] != metadata.name
    ):
        raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
    arr_metadata = arr.dtype.metadata["mgb_dtype"]
    is_unsigned = metadata.is_unsigned
    if is_unsigned:
        scale, zp = (
            arr_metadata["scale"],
            arr_metadata["zero_point"],
        )
        return (arr.astype(np.float32) - zp) * scale
    else:
        # signed dtypes carry no "zero_point" entry (see get_quantized_dtype),
        # so the two branches must not be merged into one lookup
        scale = arr_metadata["scale"]
        return (arr.astype(np.float32)) * scale
def convert_to_quint8(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a quint8 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a quint8.
    """
    quantized = _convert_to_quantized_dtype(arr, q, "quint8")
    return quantized
def convert_from_quint8(arr: np.ndarray):
    """
    Dequantize a quint8 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    """
    dequantized = _convert_from_quantized_dtype(arr, "quint8")
    return dequantized
def convert_to_qint8(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint8 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a qint8.
    """
    quantized = _convert_to_quantized_dtype(arr, q, "qint8")
    return quantized
def convert_from_qint8(arr: np.ndarray):
    """
    Dequantize a qint8 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    """
    dequantized = _convert_from_quantized_dtype(arr, "qint8")
    return dequantized
def convert_to_qint32(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint32 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a qint32.
    """
    quantized = _convert_to_quantized_dtype(arr, q, "qint32")
    return quantized
def convert_from_qint32(arr):
    """
    Dequantize a qint32 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    """
    dequantized = _convert_from_quantized_dtype(arr, "qint32")
    return dequantized
def convert_to_quint4(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a quint4 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a quint4.
    """
    quantized = _convert_to_quantized_dtype(arr, q, "quint4")
    return quantized
def convert_from_quint4(arr: np.ndarray):
    """
    Dequantize a quint4 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    """
    dequantized = _convert_from_quantized_dtype(arr, "quint4")
    return dequantized
def convert_to_qint4(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint4 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a qint4.
    """
    quantized = _convert_to_quantized_dtype(arr, q, "qint4")
    return quantized
def convert_from_qint4(arr: np.ndarray):
    """
    Dequantize a qint4 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    """
    dequantized = _convert_from_quantized_dtype(arr, "qint4")
    return dequantized
| @@ -0,0 +1,158 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..ops.builtin import OpDef | |||
| from .core import TensorBase, TensorWrapperBase, apply | |||
| from .raw_tensor import RawTensor | |||
| from .tensor import Tensor, push_context | |||
| from .tensor_wrapper import TensorWrapper | |||
class Function:
    """
    Defines a block of operations with customizable differentiation.

    The computation should be defined in ``forward`` method, with gradient
    computation defined in ``backward`` method.

    Each instance of ``Function`` should be used only once during forwarding.

    Examples:

    .. testcode::

        class Sigmoid(Function):
            def forward(self, x):
                y = 1 / (1 + F.exp(-x))
                self.y = y
                return y

            def backward(self, output_grads):
                y = self.y
                return output_grads * y * (1-y)

    """

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, *args):
        """Dispatch through ``apply``; a one-element tuple result is unpacked to its element."""
        ret = apply(self, *args)
        if type(ret) == tuple and len(ret) == 1:
            return ret[0]
        return ret

    def forward(self, *args, **kwargs):
        """
        Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses.

        :param input: Input tensors.
        :return: A tuple of Tensor or a single Tensor.

        .. note::

            This method should return a tuple of Tensor or a single Tensor representing the output
            of the function.
        """
        raise NotImplementedError

    def backward(self, *output_grads):
        """
        Compute the gradient of the forward function. It must be overridden by all subclasses.

        :param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`

        .. note::

            In case when some tensors of outputs are not related to loss function, the corresponding
            values in ``output_grads`` would be ``None``.

        .. note::

            This method should return a tuple which containing the gradients of all inputs, in the same order
            as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned
            instead if there is only one input. If users want to stop the propagation of some gradients,
            the corresponding returned values should be set ``None`` .
        """
        raise NotImplementedError

    def get_backward_fn(self):
        """
        Return a callable that wraps incoming gradients in ``TensorWrapper``,
        calls ``backward`` and unwraps the results, or None when ``backward``
        has been set to None.
        """
        if self.backward is None:
            return None

        def _backward(*output_grads):
            # *output_grads is always a tuple. None marks an output unrelated
            # to the loss and is handed to backward() as-is rather than wrapped.
            _output_grads = tuple(
                None if grad is None else TensorWrapper(grad) for grad in output_grads
            )
            ret = self.backward(*_output_grads)
            if type(ret) is not tuple:
                ret = (ret,)
            # backward() may return None to stop propagation for an input;
            # None has no __wrapped__, so it is passed through unchanged.
            return tuple(None if i is None else i.__wrapped__ for i in ret)

        return _backward
| Function.apply = Function.__call__ | |||
@apply.add
def _(op: Function, *args: TensorWrapperBase):
    """
    Run a user-defined ``Function`` on wrapped tensors.

    ``op.forward`` runs with every input's ``_extra_data`` temporarily replaced
    by an empty dict; the saved data is restored afterwards, also when
    ``forward`` raises. Then, for every key present in the inputs'
    ``_extra_data``, ``apply`` is dispatched again on the per-key data and the
    results are stored in the outputs' ``_extra_data``.

    :return: tuple of outputs wrapped with the type of the first argument.
    """
    assert args
    Wrapper = type(args[0])

    # compute the value for self define function
    extra_data_dic = {}
    for arg in args:
        raw = arg.__wrapped__
        # Keep only the first snapshot: when the same tensor is passed twice,
        # the second visit would otherwise save the blank dict installed below
        # and the original data would be lost on restore.
        if raw not in extra_data_dic:
            extra_data_dic[raw] = raw._extra_data
        raw._extra_data = {}

    try:
        rets = op.forward(*args)
    finally:
        # restore even if forward() raised, so inputs are never left blanked
        for raw, saved in extra_data_dic.items():
            raw._extra_data = saved

    # update the gradient information for self define function
    inputs = tuple(map(lambda i: i.__wrapped__, args))
    outputs = (
        tuple(map(lambda i: i.__wrapped__, rets))
        if type(rets) is tuple
        else (rets.__wrapped__,)
    )

    for output in outputs:
        output._extra_data = {}

    with push_context() as ctx:
        ctx.inputs = inputs
        ctx.outputs = outputs
        for k in set().union(*(i._extra_data for i in inputs if isinstance(i, Tensor))):
            ctx.key = k
            data = tuple(
                i._extra_data.get(k) if isinstance(i, Tensor) else i for i in inputs
            )
            # data are instances of Tracer
            # dispatched to apply.add@grad.py
            rets = apply(op, *data)
            if rets is not None:
                assert len(outputs) == len(rets)
                for t, i in zip(outputs, rets):
                    t._extra_data[k] = i

    return tuple(map(Wrapper, outputs))
@apply.add
def _(op: Function, *args: Tensor):
    """Overload for ``Tensor`` arguments: always raises ``NotImplementedError``."""
    raise NotImplementedError
@apply.add
def _(op: Function, *args: RawTensor):
    """Overload for ``RawTensor`` arguments: always raises ``NotImplementedError``."""
    raise NotImplementedError
| @@ -0,0 +1,251 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| from .core import TensorBase, TensorWrapperBase, apply | |||
def remove_ellipsis(tensor, tuple_val):
    """
    Replace a single ``Ellipsis`` in ``tuple_val`` with full slices.

    Each non-ellipsis entry counts as ``entry.ndim`` dimensions when it has an
    ``ndim`` attribute and as one dimension otherwise; the ellipsis expands to
    ``tensor.ndim`` minus that total ``slice(None, None, None)`` entries.

    :param tensor: object providing ``ndim``.
    :param tuple_val: tuple of index entries.
    :return: ``tuple_val`` itself when it has no ellipsis, otherwise a new tuple.
    :raises IndexError: if ``tuple_val`` contains more than one ``Ellipsis``.
    """
    total_ndim = tensor.ndim
    ellipsis_positions = [
        position for position, entry in enumerate(tuple_val) if entry is Ellipsis
    ]
    if len(ellipsis_positions) > 1:
        raise IndexError("only one ellipsis is allowed")
    if not ellipsis_positions:
        return tuple_val

    consumed = sum(
        entry.ndim if hasattr(entry, "ndim") else 1
        for entry in tuple_val
        if entry is not Ellipsis
    )
    pos = ellipsis_positions[0]
    filler = (slice(None, None, None),) * (total_ndim - consumed)
    return tuple_val[:pos] + filler + tuple_val[pos + 1 :]
def check_bool_index(tensor, tuple_val):
    """
    Flatten multi-dimensional boolean indices and reshape ``tensor`` to match.

    For each entry of ``tuple_val`` that has a ``dtype`` equal to ``np.bool_``
    and more than one dimension, the entry's shape is compared with the tensor
    dimensions it covers, the entry is flattened with ``reshape(-1)`` and
    ``tensor`` is reshaped so those dimensions collapse into one. Every other
    entry is passed through unchanged.

    :param tensor: object providing ``shape`` and ``reshape``.
    :param tuple_val: sequence of index entries.
    :return: ``(tensor, new_tuple_val)``: the possibly reshaped tensor and a
        list of entries.
    :raises IndexError: if a boolean index's shape differs from the tensor
        dimensions it is compared against.
    """
    cur_shape = tensor.shape
    new_tuple_val = []
    # incremented by one per collapsed mask; subtracted when mapping ``tdim``
    # onto positions of the already reshaped ``cur_shape``
    offset = 0
    # running dimension counter advanced by the entries processed so far
    tdim = 0
    for idx, i in enumerate(tuple_val):
        if hasattr(i, "dtype") and i.dtype == np.bool_:
            if i.ndim > 1:
                tot = i.ndim
                for j in range(i.ndim):
                    if cur_shape[tdim + j - offset] != i.shape[j]:
                        raise IndexError(
                            "boolean index did not match tensor along dimension {}; dimension is {} but corresponding boolean dimension is {}".format(
                                tdim + j, cur_shape[tdim + j - offset], i.shape[j]
                            )
                        )
                i = i.reshape(-1)
                # replace the covered dimensions by the flattened mask length
                cur_shape = (
                    cur_shape[:idx] + (i.shape[0],) + cur_shape[tdim + tot - offset :]
                )
                offset += 1
                tensor = tensor.reshape(cur_shape)
                tdim += tot
            # NOTE(review): a boolean entry with ndim <= 1 does not advance
            # ``tdim`` here; confirm this is intended.
            new_tuple_val.append(i)
        else:
            new_tuple_val.append(i)
            tdim += 1
    return tensor, new_tuple_val
def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
    """
    Translate a ``__getitem__`` index into operator inputs.

    :param inp: tensor being indexed; must provide ``ndim``, ``shape``,
        ``reshape`` and ``device``.
    :param tuple_val: a single index or a tuple of indices.
    :param allow_newaxis: accepted but not consulted; any ``np.newaxis`` entry
        makes the function raise.
    :return: ``(inp, tensors, items, use_subtensor)`` where ``inp`` may have
        been reshaped by ``check_bool_index``, ``tensors`` are the index
        tensors in push order, ``items`` holds one
        ``[axis, has_begin, has_end, has_step, has_idx]`` list per indexed
        axis, and ``use_subtensor`` is False as soon as one entry is neither a
        scalar nor a slice.
    :raises IndexError: on too many indices, more than one ellipsis, or any
        ``np.newaxis`` entry.
    """
    if not isinstance(tuple_val, tuple):
        tuple_val = (tuple_val,)
    ndim_indexed = 0
    for i in tuple_val:
        if not i is Ellipsis:
            ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim
    if ndim_indexed > inp.ndim:
        raise IndexError(
            "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
                inp.ndim, ndim_indexed
            )
        )

    tuple_val = remove_ellipsis(inp, tuple_val)
    use_subtensor = True
    inp, tuple_val = check_bool_index(inp, tuple_val)

    # The helpers below do not depend on the loop state, so they are defined
    # once here instead of once per iteration.

    def is_scalar(d):
        # test the argument itself (this used to read the enclosing loop variable)
        if isinstance(d, int):
            return True
        if type(d).__module__ == np.__name__:
            return np.isscalar(d)
        return False

    def is_bool_list(x):
        if not isinstance(x, list):
            return False
        for elem in x:
            if not isinstance(elem, bool):
                return False
        return True

    def get_index(i):
        # non-tensor indices are first turned into tensors on inp's device
        if not isinstance(i, (TensorBase, TensorWrapperBase)):
            if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
                (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
            else:
                (i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
                return i
        assert isinstance(i, (TensorBase, TensorWrapperBase))
        if i.dtype != np.bool_:
            return i
        # boolean masks are replaced by the second output of CondTake
        _, ind = apply(builtin.CondTake(), i, i)
        return ind

    def push(v, item, tensors):
        # record presence of the component; present components add a tensor
        if v is None:
            item.append(False)
        else:
            item.append(True)
            v = get_index(v)
            # np.bool_ instead of the alias np.bool, which NumPy 1.24 removed
            assert np.issubdtype(v.dtype, np.integer) or np.issubdtype(
                v.dtype, np.bool_
            ), "var type in the subscript must be int or bool"
            tensors.append(v)

    new_axes = []
    tensors = []
    items = []
    cur_axis = -1
    for i_idx, i in enumerate(tuple_val):
        cur_axis += 1
        if i is np.newaxis:
            if cur_axis >= 0:
                new_axes.append(cur_axis)
            continue

        if i is Ellipsis:
            # defensive: remove_ellipsis above already replaced any Ellipsis
            cur_axis = -1
            for j in tuple_val[:i_idx:-1]:
                if j is Ellipsis:
                    raise IndexError("only one ellipsis is allowed")
                if j is np.newaxis:
                    new_axes.append(cur_axis)
                cur_axis -= 1
            continue

        if (
            not is_scalar(i)
            and not i is np.newaxis
            and not i is Ellipsis
            and not isinstance(i, slice)
        ):
            use_subtensor = False

        item = [
            cur_axis,
        ]

        if isinstance(i, slice):
            # a full slice selects everything on this axis: nothing to record
            if i.start is None and i.stop is None and i.step is None:
                continue
            push(i.start, item, tensors)
            push(i.stop, item, tensors)
            push(i.step, item, tensors)
            item.append(False)  # idx
        else:
            item += [False,] * 3  # begin, end, step
            push(i, item, tensors)
        assert len(item) == 5
        items.append(item)
    if new_axes:
        raise IndexError("newaxis is not allowed here")
    return inp, tensors, items, use_subtensor
def try_condtake(tensor, index):
    """
    Apply ``CondTake`` when ``index`` is a boolean mask shaped like ``tensor``.

    :param tensor: tensor to take from.
    :param index: candidate mask; used only if it has ``dtype`` and ``shape``
        attributes, its dtype equals ``np.bool_`` and its shape equals
        ``tensor.shape``.
    :return: an empty list when ``index`` does not qualify, otherwise the
        result of ``apply(builtin.CondTake(), tensor, index)``.
    :raises TypeError: if ``tensor`` is not a tensor type.
    :raises ValueError: if ``tensor`` and ``index`` are on different devices.
    """
    if not hasattr(index, "dtype") or not hasattr(index, "shape"):
        return []
    if index.dtype != np.bool_ or index.shape != tensor.shape:
        return []
    if isinstance(index, np.ndarray):
        # Turn the numpy mask into a tensor on the input's device. This line
        # used to reference the undefined names ``i`` and ``inp``.
        (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
    assert isinstance(index, (TensorBase, TensorWrapperBase))
    if not isinstance(tensor, (TensorWrapperBase, TensorBase)):
        raise TypeError("input must be a tensor")
    if tensor.device != index.device:
        raise ValueError(
            "ambiguous device: {} vs {}".format(tensor.device, index.device)
        )
    return apply(builtin.CondTake(), tensor, index)
def getitem(tensor, index):
    """Evaluate ``tensor[index]``.

    A full-shape boolean mask is served by ``try_condtake``. Otherwise the
    subscript is unpacked and dispatched to ``Subtensor`` or
    ``IndexingMultiAxisVec``. If any index tensor has a leading dimension of
    0, an empty constant of the input's dtype and device is returned.
    """
    taken = try_condtake(tensor, index)
    if len(taken) == 2:
        return taken[0]

    tensor, index_tensors, items, use_subtensor = unpack_getitem(tensor, index)
    if any(t.shape[0] == 0 for t in index_tensors):
        make_empty = Const([], dtype=tensor.dtype, device=tensor.device)
        (empty,) = make_empty(tensor)
        return empty

    op_type = builtin.Subtensor if use_subtensor else builtin.IndexingMultiAxisVec
    (selected,) = apply(op_type(items=items), tensor, *index_tensors)
    return selected
def setitem(tensor, index, value):
    """Return a tensor equal to ``tensor`` with ``tensor[index]`` set to ``value``.

    A full-shape boolean mask is turned into flat indices via ``try_condtake``
    and the input is flattened for the update; the result is reshaped back to
    the input's original shape. Non-tensor values are lifted to constants.

    Raises:
        ValueError: ``value`` cannot be broadcast to the selected region.
    """
    original_shape = tensor.shape
    taken = try_condtake(tensor, index)
    if len(taken) == 2:
        index = taken[1]
        if index.shape[0] == 0:
            # Mask selects nothing: there is nothing to write.
            return tensor
        tensor = tensor.reshape(-1)

    if not isinstance(value, (TensorBase, TensorWrapperBase)):
        (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)

    tensor, index_tensors, items, use_subtensor = unpack_getitem(tensor, index)
    if any(t.shape[0] == 0 for t in index_tensors):
        return tensor

    # Read the target region first, only to learn its shape.
    read_type = builtin.Subtensor if use_subtensor else builtin.IndexingMultiAxisVec
    (target,) = apply(read_type(items=items), tensor, *index_tensors)

    if value.shape != target.shape:
        overlap = min(len(value.shape), len(target.shape))
        # Compare trailing dimensions; a value dimension of 1 may broadcast.
        for back in range(1, overlap + 1):
            if (
                value.shape[-back] != 1
                and value.shape[-back] != target.shape[-back]
            ):
                raise ValueError(
                    "cannot copy tensor with shape {} to subtensor with shape {}".format(
                        value.shape, target.shape
                    )
                )
        value = value.broadcast(target.shape)

    write_type = (
        builtin.SetSubtensor if use_subtensor else builtin.IndexingSetMultiAxisVec
    )
    (updated,) = apply(write_type(items=items), tensor, value, *index_tensors)
    return updated.reshape(original_shape)
| @@ -0,0 +1,196 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections
import collections.abc
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor

from .. import _imperative_rt
from .._wrap import device as as_device
from ..ops.builtin import OpDef
from .core import OpBase, TensorBase, apply
class CompiledFunction:
    """Runs a compiled function on the owning graph's executor.

    ``execute`` submits ``function.execute`` to ``graph._executor``;
    ``wait`` blocks for completion and returns its result. At most one
    execution may be in flight at a time.
    """

    def __init__(self, graph, function):
        self._graph = graph
        self._function = function
        self._future = None  # pending execution, or None when idle

    def execute(self, *args):
        """Submit one execution; must not be called while one is pending."""
        assert self._future is None
        self._future = self._graph._executor.submit(self._function.execute, *args)

    def wait(self):
        """Wait for the pending execution and return its result.

        Any exception raised by the submitted call is re-raised here. The
        pending slot is cleared on every exit path, so a failure in
        ``function.wait()`` no longer blocks later ``execute`` calls.
        """
        assert self._future is not None
        try:
            # Block until the submitted call finishes; ``exception()`` returns
            # (rather than raises) an error raised by the call itself.
            self._future.exception()
            self._function.wait()
            return self._future.result()
        finally:
            self._future = None

    def __call__(self, *args):
        self.execute(*args)
        return self.wait()
class Graph(_imperative_rt.ComputingGraph):
    """Computing graph that hands out cached Python wrappers for its nodes."""

    def __init__(self):
        super().__init__()
        # Raw node -> wrapper, one cache per node kind.
        # NOTE(review): each wrapper stores its node in ``_node``, so the value
        # keeps the weak key alive; confirm entries can actually be reclaimed.
        self._var_cache = weakref.WeakKeyDictionary()
        self._op_cache = weakref.WeakKeyDictionary()
        # Single worker used by CompiledFunction.execute.
        self._executor = ThreadPoolExecutor(1)

    def _wrap(self, obj):
        """Return the cached wrapper for a raw ``VarNode`` / ``OperatorNode``.

        Raises:
            TypeError: ``obj`` is neither node type. The original fell through
                to an ``UnboundLocalError`` in this case.
        """
        if type(obj) is _imperative_rt.VarNode:
            wrapper, cache = VarNode, self._var_cache
        elif type(obj) is _imperative_rt.OperatorNode:
            wrapper, cache = OpNode, self._op_cache
        else:
            raise TypeError(
                "cannot wrap object of type {}".format(type(obj).__name__)
            )
        if obj not in cache:
            cache[obj] = wrapper(obj)
        return cache[obj]

    def compile(self, *args):
        """Compile the unwrapped ``args`` and return a ``CompiledFunction``."""
        return CompiledFunction(self, super().compile(_unwrap(args)))
class VarNode(TensorBase):
    """Python-side wrapper around a raw ``_imperative_rt.VarNode``."""

    def __init__(self, node: _imperative_rt.VarNode):
        self._node = node

    @property
    def graph(self) -> Graph:
        """Graph reported by the underlying node."""
        raw = self._node
        return raw.graph

    @property
    def op(self):
        """Wrapped owner of the underlying node."""
        graph = self.graph
        return graph._wrap(self._node.owner)

    @property
    def dtype(self):
        raw = self._node
        return raw.dtype

    @property
    def device(self):
        """Device converted from the node's ``comp_node``."""
        comp_node = self._node.comp_node
        return as_device(comp_node)
class OpNode:
    """Python-side wrapper around a raw ``_imperative_rt.OperatorNode``."""

    def __init__(self, node: _imperative_rt.OperatorNode):
        self._node = node

    @property
    def graph(self) -> Graph:
        """Graph reported by the underlying node."""
        raw = self._node
        return raw.graph

    @property
    def inputs(self):
        """Tuple of wrapped input nodes."""
        wrap = self.graph._wrap
        return tuple(wrap(raw) for raw in self._node.inputs)

    @property
    def outputs(self):
        """Tuple of wrapped output nodes."""
        wrap = self.graph._wrap
        return tuple(wrap(raw) for raw in self._node.outputs)
| def _wrap(x): | |||
| if isinstance(x, collections.Sequence): | |||
| return type(x)(map(_wrap, x)) | |||
| return x.graph._wrap(x) | |||
| def _unwrap(x): | |||
| if isinstance(x, collections.Sequence): | |||
| return type(x)(map(_unwrap, x)) | |||
| return x._node | |||
@apply.add
def _(op: OpDef, *args: VarNode):
    """Apply ``op`` to graph variables by invoking it on their raw nodes."""
    raw_inputs = _unwrap(args)
    raw_outputs = _imperative_rt.invoke_op(op, raw_inputs)
    return _wrap(raw_outputs)
def input_callback(callback, *args, device=None, dtype=None, graph=None):
    """Create an input-callback node and return its two wrapped outputs.

    Returns:
        ``(value, dummy)`` as produced by ``_imperative_rt.input_callback``.
    """
    make_node = _imperative_rt.input_callback
    c_device = as_device(device).to_c()
    raw_outputs = make_node(callback, c_device, dtype, _unwrap(args), graph=graph)
    value, dummy = _wrap(raw_outputs)
    return value, dummy
class InputNode(OpNode):
    """Input node fed through a ``DeviceTensorNDRendezvous``."""

    def __init__(self, *args: VarNode, device=None, dtype=None, graph=None):
        rendezvous = _imperative_rt.DeviceTensorNDRendezvous()
        # ``None`` is passed through untouched; anything else is converted.
        c_device = None if device is None else as_device(device).to_c()
        raw_outputs = _imperative_rt.input_callback(
            rendezvous, c_device, dtype, _unwrap(args), graph=graph
        )
        super().__init__(raw_outputs[0].owner)
        self._rendezvous = rendezvous

    def set_value(self, value):
        """Hand ``value`` (a ``DeviceTensorND``) to the rendezvous."""
        assert isinstance(value, _imperative_rt.DeviceTensorND)
        self._rendezvous.set(value)

    def reset(self):
        self._rendezvous.reset()

    @property
    def device(self):
        first_output = self.outputs[0]
        return first_output.device

    @property
    def dtype(self):
        first_output = self.outputs[0]
        return first_output.dtype
def output_callback(callback, var, *args):
    """Attach ``callback`` to ``var`` (plus ``args``) and return the wrapped result."""
    all_vars = (var,) + args
    raw_dummy = _imperative_rt.output_callback(callback, _unwrap(all_vars))
    return _wrap(raw_dummy)
class OutputNode(OpNode):
    """Output node whose value is fetched through a ``DeviceTensorNDRendezvous``."""

    def __init__(self, var, *args):
        all_vars = (var,) + args
        rendezvous = _imperative_rt.DeviceTensorNDRendezvous()
        raw_dummy = _imperative_rt.output_callback(rendezvous, _unwrap(all_vars))
        super().__init__(raw_dummy.owner)
        self._rendezvous = rendezvous

    def get_value(self):
        """Return whatever the rendezvous' ``get()`` yields."""
        return self._rendezvous.get()

    def reset(self):
        self._rendezvous.reset()
class TensorAttr:
    """Plain record holding a tensor's ``shape``, ``dtype`` and ``device``."""

    def __init__(self, shape, dtype, device):
        self.shape, self.dtype, self.device = shape, dtype, device
class AttrOutputNode(OpNode):
    """Output node that reports tensor attributes through a ``TensorAttrRendezvous``."""

    def __init__(self, var, *args):
        all_vars = (var,) + args
        rendezvous = _imperative_rt.TensorAttrRendezvous()
        raw_dummy = _imperative_rt.attr_output_callback(
            rendezvous, _unwrap(all_vars)
        )
        super().__init__(raw_dummy.owner)
        self._rendezvous = rendezvous

    def get_value(self):
        """Return a ``TensorAttr`` built from the rendezvous' current value."""
        raw = self._rendezvous.get()
        return TensorAttr(raw.shape, raw.dtype, as_device(raw.comp_node))

    def reset(self):
        self._rendezvous.reset()
| @@ -0,0 +1,108 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import numpy as np | |||
| from ..._imperative_rt import CompNode, DeviceTensorND | |||
| from ..._imperative_rt.imperative import ( | |||
| _get_dev_tensor, | |||
| apply_op, | |||
| delete, | |||
| get_device, | |||
| get_dtype, | |||
| get_shape, | |||
| get_value, | |||
| put, | |||
| ) | |||
| from ..._wrap import device as as_device | |||
| from ...ops.builtin import Copy, OpDef, TypeCvt | |||
| from ...ops.special import Const | |||
| from ..core import OpBase, TensorBase, apply | |||
class RawTensor(TensorBase):
    """Tensor backed by a runtime handle.

    ``_init_cb`` / ``_del_cb`` are optional class-level callbacks invoked at
    construction and destruction respectively.
    """

    _init_cb = None
    _del_cb = None

    def __init__(self, handle):
        self._handle = handle
        init_cb = self._init_cb
        if init_cb:
            init_cb()

    @property
    def dtype(self):
        handle = self._handle
        return get_dtype(handle)

    @property
    def device(self):
        raw_device = get_device(self._handle)
        return as_device(raw_device)

    @property
    def shape(self):
        handle = self._handle
        return get_shape(handle)

    def numpy(self):
        """Return the value obtained from ``get_value`` for this handle."""
        handle = self._handle
        return get_value(handle)

    def _dev_tensor(self):
        handle = self._handle
        return _get_dev_tensor(handle)

    def __repr__(self):
        cls_name = type(self).__qualname__
        value_repr = repr(self.numpy())
        return "{}({}, device='{}')".format(cls_name, value_repr, self.device)

    def __del__(self):
        del_cb = self._del_cb
        if del_cb:
            del_cb()
        # Release the handle after the optional callback has run.
        delete(self._handle)
@apply.add
def _(op: OpDef, *args: RawTensor):
    """Apply ``op`` to raw tensors and wrap every returned handle."""
    handles = tuple(t._handle for t in args)
    out_handles = apply_op(op, handles)
    return tuple(RawTensor(h) for h in out_handles)
@apply.add
def _(op: Const, *args: RawTensor):
    """Materialise a ``Const`` op as a one-element tuple holding a raw tensor."""
    target_dtype = op.dtype
    target_device = as_device(op.device).to_c()
    constant = as_raw_tensor(op.value, dtype=target_dtype, device=target_device)
    return (constant,)
@functools.singledispatch
def as_raw_tensor(obj, dtype=None, device=None):
    """Fallback conversion: go through a numpy array.

    ``obj`` is converted with ``np.asarray``; float64 results are narrowed to
    float32 and int64 results to int32, then the array is dispatched again.
    """
    array = np.asarray(obj, dtype=dtype)
    if array.dtype == np.float64:
        array = array.astype(np.float32)
    elif array.dtype == np.int64:
        array = array.astype(np.int32)
    return as_raw_tensor(array, device=device)
@as_raw_tensor.register(np.ndarray)
def _(array: np.ndarray, dtype=None, device=None):
    """Upload a numpy array with ``put`` and wrap the returned handle."""
    if device is not None:
        device = as_device(device).to_c()
    handle = put(array, dtype=dtype, device=device)
    return RawTensor(handle)
@as_raw_tensor.register(RawTensor)
def _(tensor: RawTensor, dtype=None, device=None):
    """Return ``tensor``, converted with ``TypeCvt`` / ``Copy`` only when the
    requested dtype or device differs from its current one."""
    if dtype is not None:
        wanted_dtype = np.dtype(dtype)
        if wanted_dtype != tensor.dtype:
            (tensor,) = apply(TypeCvt(dtype=wanted_dtype), tensor)
    if device is not None:
        wanted_device = as_device(device)
        if wanted_device != tensor.device:
            (tensor,) = apply(Copy(comp_node=wanted_device.to_c()), tensor)
    return tensor
| @@ -0,0 +1,251 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import io | |||
| import weakref | |||
class partial(functools.partial):
    """``functools.partial`` that binds like a method.

    Accessed on the class it returns itself; accessed on an instance it
    returns a ``functools.partial`` with that instance appended as the next
    positional argument.
    """

    def __get__(self, instance, owner=None):
        return self if instance is None else functools.partial(self, instance)
def hook(f):
    """Build a decorator that replaces ``impl`` with ``partial(f, impl)``.

    The result carries ``impl``'s metadata via ``functools.update_wrapper``.
    """

    def decorator(impl):
        hooked = partial(f, impl)
        functools.update_wrapper(hooked, impl)
        return hooked

    return decorator
def on_input(impl, value):
    """Call ``impl(value)`` and, if a trace is active, record an ``InputEvent``
    for the produced tensor."""
    tensor = impl(value)
    trace = get_trace()
    if not trace:
        return tensor
    trace.append(InputEvent(trace.get_var(tensor)))
    return tensor
def on_read_dtype(impl, self):
    """Record a ``ReadDtypeEvent`` on the active trace (if any), then delegate
    to ``impl``."""
    active = get_trace()
    if active:
        active.append(ReadDtypeEvent(active.get_var(self)))
    return impl(self)
def on_read_device(impl, self):
    """Record a ``ReadDeviceEvent`` on the active trace (if any), then delegate
    to ``impl``."""
    active = get_trace()
    if active:
        active.append(ReadDeviceEvent(active.get_var(self)))
    return impl(self)
def on_read_shape(impl, self):
    """Record a ``ReadShapeEvent`` on the active trace (if any), then delegate
    to ``impl``."""
    active = get_trace()
    if active:
        active.append(ReadShapeEvent(active.get_var(self)))
    return impl(self)
def on_read_value(impl, self):
    """Record a ``ReadValueEvent`` on the active trace (if any), then delegate
    to ``impl``."""
    active = get_trace()
    if active:
        active.append(ReadValueEvent(active.get_var(self)))
    return impl(self)
def on_builtin_op(impl, op, *args):
    """Run ``impl(op, *args)`` and, if a trace is active, record an ``OpEvent``.

    A falsy ``outputs`` value (``None`` or empty) is stored in the event as is.
    """
    outputs = impl(op, *args)
    trace = get_trace()
    if not trace:
        return outputs
    input_vars = tuple(trace.get_var(arg) for arg in args)
    output_vars = tuple(trace.get_var(out) for out in outputs) if outputs else outputs
    trace.append(OpEvent(op, input_vars, output_vars))
    return outputs
def on_del(impl, self):
    """Record a ``DelEvent`` on the active trace (if any), then delegate to
    ``impl``."""
    active = get_trace()
    if active:
        active.append(DelEvent(active.get_var(self)))
    return impl(self)
class Trace(list):
    """List of recorded events that also assigns integer ids to tensors.

    Used as a context manager it installs itself as the module-level current
    trace and restores the previous one on exit.
    """

    def __init__(self):
        self._var_id = 1  # next id to hand out
        self._t2v = weakref.WeakKeyDictionary()  # tensor -> id
        self._v2t = weakref.WeakValueDictionary()  # id -> tensor

    def get_var(self, x):
        """Return the id of ``x``, allocating a fresh one on first sight."""
        known = self._t2v.get(x)
        if known:
            return known
        fresh = self._var_id
        self._var_id = fresh + 1
        self._t2v[x] = fresh
        self._v2t[fresh] = x
        return fresh

    def __bool__(self):
        # An empty trace must still count as "tracing is active".
        return True

    def __enter__(self):
        global _current_trace
        if hasattr(self, "_prev_trace"):
            # Already entered and not yet exited.
            raise RuntimeError
        self._prev_trace = _current_trace
        _current_trace = self
        return self

    def __exit__(self, *_):
        global _current_trace
        if _current_trace is not self:
            raise RuntimeError
        _current_trace = self._prev_trace
        del self._prev_trace
class Event:
    """Base class of everything recorded in a trace."""


class InputEvent(Event):
    """A tensor (identified by ``var``) entered the trace as an input."""

    def __init__(self, var):
        self.var = var


class ReadEvent(Event):
    """Base class for reads of a property of the tensor ``var``."""

    def __init__(self, var):
        self.var = var


class ReadDtypeEvent(ReadEvent):
    """The dtype of ``var`` was read."""


class ReadDeviceEvent(ReadEvent):
    """The device of ``var`` was read."""


class ReadShapeEvent(ReadEvent):
    """The shape of ``var`` was read."""


class ReadValueEvent(ReadEvent):
    """The value of ``var`` was read."""


class OpEvent(Event):
    """``op`` was applied to ``inputs`` producing ``outputs`` (tuples of ids)."""

    def __init__(self, op, inputs, outputs):
        self.op, self.inputs, self.outputs = op, inputs, outputs


class DelEvent(Event):
    """The tensor ``var`` was deleted."""

    def __init__(self, var):
        self.var = var
# Trace currently installed by ``Trace.__enter__``; None when not tracing.
_current_trace = None


def get_trace() -> Trace:
    """Return the currently active trace, or ``None`` when none is installed."""
    return _current_trace
def format_trace(trace):
    """Render a sequence of trace events as pseudo-code, one line per action.

    Variables are printed as ``_<id>``. A variable that is used before any
    recorded event produced it is first declared with ``_<id> = input()``.

    Raises:
        TypeError: an element of ``trace`` is not a known event type.
    """
    buf = io.StringIO()
    active_vars = set()  # ids that have already been declared or produced

    def write(fmt, *args, **kwargs):
        print(fmt.format(*args, **kwargs), file=buf)

    def init_vars(*args):
        for i in args:
            if i in active_vars:
                continue
            active_vars.add(i)
            write("_{} = input()", i)

    for event in trace:
        if isinstance(event, InputEvent):
            init_vars(event.var)
        elif isinstance(event, ReadDtypeEvent):
            init_vars(event.var)
            write("output(_{}.dtype)", event.var)
        elif isinstance(event, ReadDeviceEvent):
            init_vars(event.var)
            write("output(_{}.device)", event.var)
        elif isinstance(event, ReadShapeEvent):
            init_vars(event.var)
            write("output(_{}.shape)", event.var)
        elif isinstance(event, ReadValueEvent):
            # The original had a duplicated branch here that printed ``.dtype``
            # and shadowed this one, so value reads were misreported.
            init_vars(event.var)
            write("output(_{}.value)", event.var)
        elif isinstance(event, OpEvent):
            # on_builtin_op stores a falsy ``outputs`` (e.g. None) unchanged.
            outputs = event.outputs or ()
            init_vars(*event.inputs)
            active_vars.update(outputs)
            ovars = ", ".join(map("_{}".format, outputs))
            ivars = ", ".join(map("_{}".format, event.inputs))
            if ovars:
                write("{} = {}({})", ovars, repr(event.op), ivars)
            else:
                write("{}({})", repr(event.op), ivars)
        elif isinstance(event, DelEvent):
            init_vars(event.var)
            write("del _{}", event.var)
        else:
            raise TypeError(type(event))
    return buf.getvalue()
def compile_trace(trace):
    """Placeholder: materialises ``trace`` into a list and returns ``None``."""
    events = list(trace)
def static_function(f):
    """Decorate ``f`` so that its first call runs inside a fresh ``Trace``.

    The trace object is kept in the closure; every later call simply
    forwards to ``f``.
    """
    trace = None

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal trace
        if trace is not None:
            return f(*args, **kwargs)
        # First call: bind the new trace to the closure variable and record.
        with Trace() as trace:
            return f(*args, **kwargs)

    return wrapper
| @@ -0,0 +1,263 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import weakref | |||
| # Concepts | |||
| # | |||
| # * Internal tensor | |||
| # Tensor produced by the static sequence | |||
| # | |||
| # * External tensor | |||
| # Tensor not produced, but used as input, by the static sequence | |||
| # | |||
| # * Irrelevant tensor | |||
| # Tensor not present in input/output of any op | |||
| # | |||
| # * Escape | |||
| # An internal tensor is said to escape if it is still alive | |||
| # at the end of the sequence | |||
| # JIT-ed execution | |||
| # | |||
| # 1. read attr (dtype, device, shape) | |||
| # a. internal tensor | |||
| # read out as soon as tensor is produced | |||
| # b. external or irrelevant tensor | |||
| # fallback | |||
| # | |||
| # 2. apply op | |||
| # bind external tensors in input | |||
| # | |||
| # 3. del | |||
class Action:
    """Base class of the steps stored in an execution trajectory."""


class ReadAttrAction(Action):
    """Read attribute ``name`` of ``var``; ``getter()`` supplies the value."""

    def __init__(self, var, name, getter):
        self.var, self.name, self.getter = var, name, getter


class ReadValueAction(Action):
    """Read the value of ``var``; ``getter()`` supplies it."""

    def __init__(self, var, getter):
        self.var, self.getter = var, getter


class GetTensorAction(Action):
    """Fetch the data of ``var``; ``getter()`` supplies it."""

    def __init__(self, var, getter):
        self.var, self.getter = var, getter


class OpAction(Action):
    """Apply ``op`` to ``inputs`` producing ``outputs``.

    ``input_receivers`` is parallel to ``inputs``.
    """

    def __init__(self, op, inputs, outputs, input_receivers):
        self.op, self.inputs = op, inputs
        self.outputs, self.input_receivers = outputs, input_receivers
class TensorAttr:
    """Cache slot for a tensor's ``shape``, ``dtype`` and ``device``.

    Every field starts as ``None``, meaning "not read yet".
    """

    def __init__(self):
        self.shape = self.dtype = self.device = None
class Bailout(Exception):
    """Raised by a trajectory step; the handler calls ``bailout()``."""


class Fallback(Exception):
    """Raised by a trajectory step; the handler defers to the plain
    implementation."""
def handle_bailout_fallback_finalize(f):
    """Wrap a trajectory method ``f(self, impl, *args, **kwargs)``.

    * On normal return, the result of ``f`` is returned.
    * On ``Bailout``, ``self.bailout()`` is called.
    * On ``Fallback``, the error is swallowed.
    * In every case, ``self.finalize()`` runs if ``self.pc == len(self)``.

    If control gets past the handlers, the plain implementation
    ``impl(*args, **kwargs)`` is used instead.
    """

    @functools.wraps(f)
    def wrapper(self, impl, *args, **kwargs):
        try:
            # ``f`` is a plain function at decoration time, so the receiver and
            # ``impl`` must be forwarded explicitly; the original dropped both.
            return f(self, impl, *args, **kwargs)
        except Bailout:
            self.bailout()
        except Fallback:
            pass
        finally:
            if self.pc == len(self):
                self.finalize()
        return impl(*args, **kwargs)

    return wrapper
class ExecTrajectory(list):
    """Ordered list of actions replayed during JIT-ed execution.

    ``pc`` is the index of the next action. Used as a context manager it
    installs itself as the module-level current trajectory.

    NOTE(review): nothing in this class advances ``pc`` or fills
    ``tensor_to_var``; the class looks unfinished, so confirm the intended
    design before relying on it.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def __bool__(self):
        # An empty trajectory must still count as "active".
        return True

    def __enter__(self):
        global _current_trajectory
        if hasattr(self, "_prev_trajectory"):
            raise RuntimeError
        self._prev_trajectory = _current_trajectory
        _current_trajectory = self
        self._exited = False
        return self

    def __exit__(self, *exc_info):
        # cleanup should be done at completion,
        # which is before exiting context manager
        assert self._exited == (exc_info == (None, None, None))
        if not self._exited:
            assert self.pc < len(self)
            self.bailout()

    def _exit(self):
        # clean up self and global variable
        global _current_trajectory
        assert not self._exited
        self.reset()
        if _current_trajectory is not self:
            raise RuntimeError
        _current_trajectory = self._prev_trajectory
        del self._prev_trajectory

    def reset(self):
        """Return to the idle state: exited, ``pc`` at 0, all bindings dropped."""
        self._exited = True
        self.pc = 0
        self.attr_cache = weakref.WeakKeyDictionary()

        ### Internal and External Tensor ###
        # internal tensors are those produced by us
        # external tensors are those received from outside
        # during JIT-ed execution, internal tensors are just placeholders.

        # var_to_tensor is the binding table for all tensors
        self.var_to_tensor = {}  # var -> weakref[tensor]
        # tensor_to_var is the reverse binding table for internal tensors
        # note that external tensors could map to >1 vars.
        self.tensor_to_var = weakref.WeakKeyDictionary()
        # internal tensor will be materialized if its .data is accessed from outside
        # after being materialized, an internal tensor is much like an external tensor

    def finalize(self):
        assert self.pc == len(self)
        self._exit()

    def bailout(self):
        """Clean up, then raise ``NotImplementedError`` (bailout is not implemented)."""
        self._exit()
        raise NotImplementedError

    def next_action(self):
        """Return the action at ``pc`` without advancing it."""
        assert not self._exited
        assert self.pc < len(self)
        return self[self.pc]

    @handle_bailout_fallback_finalize
    def read_attr(self, impl, tensor, name):
        """Return attribute ``name`` of ``tensor``, cached per tensor.

        ``impl`` is accepted for parity with the other decorated methods: the
        decorator's wrapper always receives ``(self, impl, ...)``.
        """
        attrs = self.attr_cache.setdefault(tensor, TensorAttr())
        value = getattr(attrs, name, None)
        if value is None:
            action = self.next_action()
            if not isinstance(action, ReadAttrAction):
                raise Bailout
            if name != action.name:
                raise Bailout
            value = action.getter()
            setattr(attrs, name, value)
        return value

    @handle_bailout_fallback_finalize
    def read_value(self, impl, tensor):
        # possibilities:
        # 1. internal tensor
        # 2. external tensor
        # 3. irrelevant tensor (not an input / output of any op)
        if tensor not in self.tensor_to_var:
            raise Fallback
        assert tensor._data is None
        action = self.next_action()
        if not isinstance(action, ReadValueAction):
            raise Bailout
        return action.getter()

    @handle_bailout_fallback_finalize
    def apply_op(self, impl, op, *args):
        from . import RawTensor

        action = self.next_action()
        if not isinstance(action, OpAction):
            raise Bailout
        if len(args) != len(action.inputs):
            raise Bailout
        # The original read ``actions.inputs`` here, which is a NameError.
        assert len(action.inputs) == len(action.input_receivers)
        for v, t, r in zip(action.inputs, args, action.input_receivers):
            if v in self.var_to_tensor:
                assert r is None
                if t is not self.var_to_tensor[v]():
                    raise Bailout
            else:
                # NOTE: not checking for aliasing (>=2 vars map to 1 tensor)
                # the static execution backend must handle this
                self.var_to_tensor[v] = weakref.ref(t)
                r(t)
        outputs = []
        for v in action.outputs:
            assert v not in self.var_to_tensor
            # NOTE(review): the RawTensor visible in this change requires a
            # ``handle`` argument; confirm which RawTensor is meant here.
            t = RawTensor()
            t._data_getter = functools.partial(self.get_data, v)
            outputs.append(t)
            self.var_to_tensor[v] = weakref.ref(t)
        return tuple(outputs)

    def get_data(self, var):
        tensor = self.var_to_tensor[var]()
        assert tensor is not None
        assert tensor._data is None
        assert tensor in self.tensor_to_var
        action = self.next_action()
        if not isinstance(action, GetTensorAction):
            self.bailout()
        elif action.var != var:
            self.bailout()
        else:
            tensor._data = action.getter()
            del tensor._data_getter
            del self.tensor_to_var[tensor]
        assert "_data_getter" not in tensor.__dict__
        # NOTE(review): with the instance attribute removed, this presumably
        # resolves to a class-level ``_data_getter``; confirm.
        return tensor._data_getter()
# Trajectory installed by ``ExecTrajectory.__enter__``; None when idle.
_current_trajectory = None


def get_trajectory():
    """Return the currently installed trajectory, or ``None``."""
    current = _current_trajectory
    return current
def compile_trace(trace):
    """Translate recorded trace events into an ``ExecTrajectory`` (unfinished).

    Only dtype-read events are handled so far, and nothing is returned yet.
    """
    # The event class is spelled ``ReadDtypeEvent`` where it is defined in this
    # change. NOTE(review): confirm that module is the ``.jit`` imported here.
    from .jit import ReadDtypeEvent, ReadDeviceEvent, ReadShapeEvent, OpEvent, DelEvent

    # The class defined in this module is ``ExecTrajectory``; the original
    # referenced a non-existent ``ExecutionTrajectory``.
    traj = ExecTrajectory()
    active_vars = set()
    for event in trace:
        if isinstance(event, ReadDtypeEvent):
            # NOTE(review): ReadAttrAction requires (var, name, getter); this
            # call still fails until the getter wiring is designed.
            traj.append(ReadAttrAction())
| @@ -0,0 +1,106 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import contextlib | |||
| import copy | |||
| from .core import Dispatcher, OpBase, TensorBase, apply | |||
class Tensor(TensorBase):
    """Tensor wrapper carrying per-key extra data and user attributes.

    ``shape``, ``dtype``, ``device`` and ``numpy()`` are forwarded to the
    wrapped ``_data`` object.
    """

    def __init__(self, data: TensorBase):
        self._data = data
        # _extra_data is set up in Grad.wrt
        self._extra_data = {}
        self._user_data = {}

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Read ``_user_data`` through
        # ``__dict__`` so that a half-initialised instance (for example one
        # made with ``object.__new__``) raises AttributeError instead of
        # recursing forever.
        user_data = self.__dict__.get("_user_data")
        if user_data is not None and name in user_data:
            return user_data[name]
        raise AttributeError(name)

    def reset(self, other):
        """Make this tensor a shallow copy of ``other``.

        ``_data`` is shared; ``_extra_data`` and ``_user_data`` are copied.
        """
        assert isinstance(other, __class__)
        # Snapshot before clearing so that ``t.reset(t)`` keeps working.
        # The original read ``other.data``, an attribute that does not exist.
        data = other._data
        extra_data = other._extra_data.copy()
        user_data = other._user_data.copy()
        self.__dict__.clear()
        self._data = data
        self._extra_data = extra_data
        self._user_data = user_data

    def copy(self):
        """Return a new instance of the same type initialised via ``reset``."""
        other = object.__new__(type(self))
        other.reset(self)
        return other

    # tensor interface
    @property
    def shape(self):
        return self._data.shape

    @property
    def dtype(self):
        return self._data.dtype

    @property
    def device(self):
        return self._data.device

    def numpy(self):
        return self._data.numpy()
class ApplyContext:
    """Mutable record describing the ``apply`` call currently in progress."""

    def __init__(self):
        self.inputs = self.outputs = self.key = None


# Context installed by ``push_context``; None outside of it.
_context = None


@contextlib.contextmanager
def push_context():
    """Install a fresh ``ApplyContext`` for the duration of the ``with`` block.

    The previously installed context is restored on exit, even on error.
    """
    global _context
    previous = _context
    try:
        _context = ApplyContext()
        yield _context
    finally:
        _context = previous


def get_context():
    """Return the currently installed ``ApplyContext``, or ``None``."""
    current = _context
    return current
@apply.add
def tensor_apply(op: OpBase, *args: Tensor):
    """Apply ``op`` to ``Tensor`` arguments.

    The op is first applied to the wrapped ``_data`` objects, and the results
    are wrapped as ``Tensor``. Then, for every key present in any argument's
    ``_extra_data``, the op is applied again to the per-key extra data; any
    non-``None`` results are stored under that key on the outputs.
    """

    def substitute(pick):
        # Replace every Tensor argument by ``pick(arg)``; pass others through.
        return tuple(pick(a) if isinstance(a, Tensor) else a for a in args)

    # type(Tensor._data) is RawTensor
    # dispatched to apply.add@RawTensor.py if passed Tensor args
    raw_outputs = apply(op, *substitute(lambda a: a._data))
    ret = tuple(Tensor(out) for out in raw_outputs)

    with push_context() as ctx:
        ctx.inputs = args
        ctx.outputs = ret
        keys = set().union(*(a._extra_data for a in args if isinstance(a, Tensor)))
        for key in keys:
            ctx.key = key
            # data are instances of Tracer
            # dispatched to apply.add@grad.py
            extra_outputs = apply(op, *substitute(lambda a: a._extra_data.get(key)))
            if extra_outputs is None:
                continue
            assert len(extra_outputs) == len(ret)
            for out_tensor, extra in zip(ret, extra_outputs):
                out_tensor._extra_data[key] = extra
    return ret
| @@ -0,0 +1,367 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import abc | |||
| import collections | |||
| import numpy as np | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| from . import utils | |||
| from .core import OpBase, TensorBase, TensorWrapperBase, apply | |||
| from .indexing import getitem as _getitem | |||
| from .indexing import setitem as _setitem | |||
| from .raw_tensor import RawTensor, as_raw_tensor | |||
| from .tensor import Tensor | |||
def _elwise(*args, mode):
    """Apply the element-wise op ``mode`` to ``args`` after input conversion."""
    elemwise = builtin.Elemwise(mode=mode)
    converted = utils.convert_inputs(*args)
    (out,) = apply(elemwise, *converted)
    return out
def _matmul(inp1, inp2):
    """Matrix-multiply ``inp1`` by ``inp2`` without transposition, using the
    default compute mode and format."""
    matmul = builtin.MatrixMul(
        transposeA=False, transposeB=False, compute_mode="DEFAULT", format="DEFAULT"
    )
    lhs, rhs = utils.convert_inputs(inp1, inp2)
    (product,) = apply(matmul, lhs, rhs)
    return product
def _transpose(data, axes):
    """Permute the dimensions of ``data`` with a ``Dimshuffle`` built from ``axes``."""
    shuffle = builtin.Dimshuffle(axes)
    (converted,) = utils.convert_inputs(data)
    (shuffled,) = apply(shuffle, converted)
    return shuffled
def _broadcast(inp, shape):
    """Broadcast ``inp`` to ``shape``; the shape is first turned into an int32
    1-d tensor on ``inp``'s device."""
    target_shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
    (expanded,) = apply(builtin.Broadcast(), inp, target_shape)
    return expanded
def _reshape(x, shape):
    """Reshape ``x`` to ``shape``; at most one entry may be ``-1``.

    Raises:
        ValueError: an entry is below -1, or more than one entry is -1.
    """
    if isinstance(shape, (TensorBase, TensorWrapperBase)):
        shape = shape.numpy()
    dims = tuple(int(d) for d in shape)

    wildcard = None  # position of the single -1 entry, if any
    for axis, extent in enumerate(dims):
        if extent >= 0:
            continue
        if extent != -1:
            raise ValueError("expect shape[{}] >= -1, got {}".format(axis, extent))
        if wildcard is not None:
            raise ValueError("multiple -1 in shape: {} & {}".format(wildcard, axis))
        wildcard = axis

    # TODO: device should be None (cpu)
    (shape_tensor,) = Const(dims, dtype=np.int32, device=x.device)(x)
    op_kwargs = {} if wildcard is None else {"unspec_axis": wildcard}
    (reshaped,) = apply(builtin.Reshape(**op_kwargs), x, shape_tensor)
    return reshaped
| def _unary_elwise(mode): | |||
| def f(self): | |||
| return _elwise(self, mode=mode) | |||
| return f | |||
| def _binary_elwise(mode, rev=False): | |||
| if not rev: | |||
| def f(self, value): | |||
| return _elwise(self, value, mode=mode) | |||
| else: | |||
| def f(self, value): | |||
| return _elwise(value, self, mode=mode) | |||
| return f | |||
| def _logical_unary_elwise(mode, rev=False): | |||
| def f(self): | |||
| if self.dtype != np.bool_: | |||
| raise TypeError("{} requires a bool tensor".format(mode)) | |||
| return _elwise(self, mode=mode) | |||
| return f | |||
| def _logical_binary_elwise(mode, rev=False): | |||
| if not rev: | |||
| def f(self, value): | |||
| if self.dtype != np.bool_ or value.dtype != np.bool_: | |||
| raise TypeError("{} requires 2 bool tensors".format(mode)) | |||
| return _elwise(self, value, mode=mode) | |||
| else: | |||
| def f(self, value): | |||
| if self.dtype != np.bool_ or value.dtype != np.bool_: | |||
| raise TypeError("{} requires 2 bool tensors".format(mode)) | |||
| return _elwise(value, self, mode=mode) | |||
| return f | |||
| def _reduce(mode): | |||
| def f(self, axis=None): | |||
| inp = self | |||
| if axis is None: | |||
| inp = self.flatten() | |||
| axis = 0 | |||
| op = builtin.Reduce(mode=mode, axis=axis) | |||
| (result,) = utils.convert_inputs(inp) | |||
| (result,) = apply(op, result) | |||
| return result | |||
| return f | |||
| def _inplace(f): | |||
| def g(self, value): | |||
| result = f(self, value) | |||
| if result is NotImplemented: | |||
| raise NotImplementedError | |||
| self._reset(result) | |||
| return self | |||
| return g | |||
| def _todo(*_): | |||
| raise NotImplementedError | |||
class ArrayMethodMixin(abc.ABC):
    """Mixin that gives a tensor wrapper NumPy-style operators and methods.

    Concrete subclasses provide ``_reset``, ``dtype``, ``shape`` and ``numpy``;
    everything else is a thin adapter over the module-level helpers
    (``_elwise``, ``_matmul``, ``_getitem``, ``_setitem``, ``_reshape``, ...).
    """

    # NOTE(review): presumably chosen large so that NumPy defers mixed
    # ndarray/tensor binary operators to this class -- confirm.
    __array_priority__ = 233333

    @abc.abstractmethod
    def _reset(self, other):
        """Rebind this object to the value carried by ``other``."""
        pass

    @property
    @abc.abstractmethod
    def dtype(self) -> np.dtype:
        pass

    @property
    @abc.abstractmethod
    def shape(self) -> tuple:
        pass

    @abc.abstractmethod
    def numpy(self) -> np.ndarray:
        pass

    __hash__ = None  # because __eq__ deviates from the Python convention

    # Comparisons are computed elementwise and cast to bool; ``>`` and ``>=``
    # reuse the LT / LEQ modes with swapped operands.
    __lt__ = lambda self, value: _elwise(self, value, mode="LT").astype("bool")
    __le__ = lambda self, value: _elwise(self, value, mode="LEQ").astype("bool")
    __gt__ = lambda self, value: _elwise(value, self, mode="LT").astype("bool")
    __ge__ = lambda self, value: _elwise(value, self, mode="LEQ").astype("bool")
    __eq__ = lambda self, value: _elwise(self, value, mode="EQ").astype("bool")
    __ne__ = lambda self, value: _elwise(
        _elwise(self, value, mode="EQ").astype("bool"), mode="NOT"
    )

    __neg__ = _unary_elwise("NEGATE")
    __pos__ = lambda self: self
    __abs__ = _unary_elwise("ABS")
    __invert__ = _logical_unary_elwise("NOT")
    __round__ = _unary_elwise("ROUND")
    __trunc__ = _todo
    __floor__ = _unary_elwise("FLOOR")
    __ceil__ = _unary_elwise("CEIL")

    __add__ = _binary_elwise("ADD")
    __sub__ = _binary_elwise("SUB")
    __mul__ = _binary_elwise("MUL")
    __matmul__ = lambda self, other: _matmul(self, other)
    __truediv__ = _binary_elwise("TRUE_DIV")
    __floordiv__ = _binary_elwise("FLOOR_DIV")
    __mod__ = _binary_elwise("MOD")
    # __divmode__
    __pow__ = _binary_elwise("POW")
    __lshift__ = _binary_elwise("SHL")
    __rshift__ = _binary_elwise("SHR")
    __and__ = _logical_binary_elwise("AND")
    __or__ = _logical_binary_elwise("OR")
    __xor__ = _logical_binary_elwise("XOR")

    __radd__ = _binary_elwise("ADD", rev=1)
    __rsub__ = _binary_elwise("SUB", rev=1)
    __rmul__ = _binary_elwise("MUL", rev=1)
    __rmatmul__ = lambda self, other: _matmul(other, self)
    __rtruediv__ = _binary_elwise("TRUE_DIV", rev=1)
    __rfloordiv__ = _binary_elwise("FLOOR_DIV", rev=1)
    __rmod__ = _binary_elwise("MOD", rev=1)
    # __rdivmode__
    __rpow__ = _binary_elwise("POW", rev=1)
    __rlshift__ = _binary_elwise("SHL", rev=1)
    __rrshift__ = _binary_elwise("SHR", rev=1)
    __rand__ = _logical_binary_elwise("AND", rev=1)
    __ror__ = _logical_binary_elwise("OR", rev=1)
    __rxor__ = _logical_binary_elwise("XOR", rev=1)

    # In-place variants compute the out-of-place result and store it back
    # through ``_reset``.
    __iadd__ = _inplace(__add__)
    __isub__ = _inplace(__sub__)
    __imul__ = _inplace(__mul__)
    __imatmul__ = _inplace(__matmul__)
    __itruediv__ = _inplace(__truediv__)
    __ifloordiv__ = _inplace(__floordiv__)
    __imod__ = _inplace(__mod__)
    __ipow__ = _inplace(__pow__)
    __ilshift__ = _inplace(__lshift__)
    __irshift__ = _inplace(__rshift__)
    __iand__ = _inplace(__and__)
    __ior__ = _inplace(__or__)
    __ixor__ = _inplace(__xor__)

    # Scalar conversions go through ``item()``, which requires a single element.
    __index__ = lambda self: self.item().__index__()
    __bool__ = lambda self: bool(self.item())
    __int__ = lambda self: int(self.item())
    __float__ = lambda self: float(self.item())
    __complex__ = lambda self: complex(self.item())

    def __len__(self):
        """Return the size of the first axis; raise ``TypeError`` for 0-dim."""
        shape = self.shape
        if shape:
            return int(shape[0])
        raise TypeError("ndim is 0")

    def __iter__(self):
        """Iterate over the first axis, yielding ``self[i]``."""
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, index):
        return _getitem(self, index)

    def __setitem__(self, index, value):
        # ``x[...] = value`` replaces the whole tensor, so ``_setitem`` is
        # skipped and ``value`` is handed to ``_reset`` directly.
        if index is not Ellipsis:
            value = _setitem(self, index, value)
        self._reset(value)

    __contains__ = _todo

    @staticmethod
    def _unpack_shape_args(args):
        """Return ``args[0]`` if ``args`` is a single sequence, else ``args``.

        Lets ``reshape``/``broadcast``/``transpose`` accept both ``f(2, 3)``
        and ``f((2, 3))``.
        """
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``. Imported locally so this block does not depend
        # on the file's import section.
        from collections.abc import Sequence

        if len(args) == 1 and isinstance(args[0], Sequence):
            return args[0]
        return args

    @property
    def ndim(self):
        """Number of axes."""
        return len(self.shape)

    @property
    def size(self):
        """Total number of elements, as a Python scalar."""
        return np.prod(self.shape).item()

    @property
    def T(self):
        """Transpose with all axes reversed."""
        return self.transpose()

    def item(self, *args):
        """Return a Python scalar.

        Without arguments the tensor must hold exactly one element; with
        arguments they are used as an index and ``item()`` is applied to the
        selected element.
        """
        if not args:
            assert self.size == 1
            return self.numpy().item()
        return self[args].item()

    def tolist(self):
        return self.numpy().tolist()

    def astype(self, dtype):
        return utils.astype(self, dtype)

    def reshape(self, *args):
        """Reshape to the shape given as separate ints or as one sequence."""
        return _reshape(self, self._unpack_shape_args(args))

    def broadcast(self, *args):
        """Broadcast to the shape given as separate ints or as one sequence."""
        return _broadcast(self, self._unpack_shape_args(args))

    def transpose(self, *args):
        """Permute axes; with no arguments the axis order is reversed."""
        if not args:
            args = reversed(range(self.ndim))
        else:
            args = self._unpack_shape_args(args)
        return _transpose(self, args)

    def flatten(self):
        """Return a 1-D view/copy via ``reshape(-1)``."""
        return self.reshape(-1)

    sum = _reduce("SUM")
    prod = _reduce("PRODUCT")
    min = _reduce("MIN")
    max = _reduce("MAX")
    mean = _reduce("MEAN")
class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
    """Tensor wrapper that forwards ``dtype``/``shape``/``device``/``numpy`` to
    the object stored in ``__wrapped__``."""

    def __init__(self, data):
        self.__wrapped__ = data

    def _reset(self, other):
        """Adopt the wrapped object of ``other`` and return ``self``.

        :raises TypeError: if ``other`` is not an instance of this class.
        """
        if isinstance(other, __class__):
            self.__wrapped__ = other.__wrapped__
            return self
        raise TypeError(type(other))

    @property
    def dtype(self):
        wrapped = self.__wrapped__
        return wrapped.dtype

    @property
    def shape(self):
        wrapped = self.__wrapped__
        return wrapped.shape

    @property
    def device(self):
        wrapped = self.__wrapped__
        return wrapped.device

    def numpy(self):
        wrapped = self.__wrapped__
        return wrapped.numpy()
class TensorWrapper(GenericTensorWrapper):
    """Wrapper that can be built from another wrapper, a ``TensorBase`` or raw
    data (which is converted through ``as_raw_tensor``)."""

    def __init__(self, data, dtype=None, device=None):
        if isinstance(data, TensorWrapperBase):
            # Share the underlying tensor of an existing wrapper.
            payload = data.__wrapped__
        elif isinstance(data, TensorBase):
            payload = data
        else:
            assert data is not None, "Cannot init a tensor with data as None"
            payload = Tensor(as_raw_tensor(data, dtype=dtype, device=device))
        super().__init__(payload)

    def _reset(self, other):
        """Rebind to ``other``; raw values are first wrapped using this
        tensor's current dtype and device."""
        if isinstance(other, TensorWrapperBase):
            self.__wrapped__ = other.__wrapped__
            return
        if isinstance(other, TensorBase):
            self.__wrapped__ = other
            return
        self._reset(type(self)(other, dtype=self.dtype, device=self.device))

    def __repr__(self):
        # Values are printed with 4-digit precision and without scientific
        # notation; dtype is shown only when it is not float32.
        with np.printoptions(precision=4, suppress=True):
            values = "{}".format(str(self.numpy()))
        parts = ["Tensor(", values]
        if self.dtype != np.float32:
            parts.append(", dtype={}".format(np.dtype(self.dtype).name))
        parts.append(", device={}".format(self.device) + ")")
        return "".join(parts)
| @@ -0,0 +1,154 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from typing import Iterable, Union | |||
| import numpy as np | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
def dtype_promotion(raw_inputs):
    """Compute the common dtype for a mix of tensors/arrays and Python scalars.

    Inputs that expose a ``dtype`` attribute are used as they are. Plain Python
    ``int`` / ``float`` / ``bool`` scalars take part as 0-d arrays of dtype
    ``int32`` / ``float32`` / ``bool_``. Any other object without a ``dtype``
    is ignored.

    :param raw_inputs: iterable of operands.
    :return: the promoted dtype; a ``float64`` result is narrowed to ``float32``.
    :raises AssertionError: if there is nothing to promote, or if the inputs
        that carry a dtype promote to ``float64`` or ``int64``.
    :raises TypeError: if the result is bool but some operand is not a
        bool-typed object with a ``dtype``.
    """

    def add_dtype(i):
        # Exact type comparison on purpose: ``bool`` is a subclass of ``int``
        # and must not be treated as int32.
        if type(i) == int:
            return np.array(i, dtype=np.int32)
        if type(i) == float:
            return np.array(i, dtype=np.float32)
        if type(i) == bool:
            return np.array(i, dtype=np.bool_)
        return None

    scalar_inputs = []
    for i in raw_inputs:
        if hasattr(i, "dtype"):
            continue
        converted = add_dtype(i)
        # Compare against None rather than testing truthiness: 0, 0.0 and
        # False convert to falsy 0-d arrays but must still take part in the
        # promotion (e.g. ``int_tensor + 0.0`` has to become floating point).
        if converted is not None:
            scalar_inputs.append(converted)

    inputs = [i for i in raw_inputs if hasattr(i, "dtype")]
    assert len(scalar_inputs + inputs) > 0

    dtype = np.result_type(*inputs)
    dtype_all = np.result_type(*(inputs + scalar_inputs))
    assert (
        dtype != np.float64 and dtype != np.int64
    ), "unsupport dtype {} by dtype_promotion, please use explict type convert".format(
        dtype
    )

    if dtype_all == np.bool_:
        for i in raw_inputs:
            if not hasattr(i, "dtype") or i.dtype != np.bool_:
                raise TypeError(
                    "bool dtype can not be operated with an element without bool dtype"
                )

    if dtype_all == np.float64:
        dtype_all = np.float32
    return dtype_all
def get_device(inputs):
    """Return the single device shared by all tensor operands in ``inputs``.

    Operands that are neither ``TensorWrapperBase`` nor ``TensorBase`` are
    ignored.

    :raises ValueError: if two tensors report different devices.
    :raises AssertionError: if no tensor operand is present.
    """
    found = None
    for item in inputs:
        if not isinstance(item, (TensorWrapperBase, TensorBase)):
            continue
        if found is None:
            found = item.device
        elif found != item.device:
            raise ValueError("ambiguous device: {} vs {}".format(found, item.device))
    assert found is not None
    return found
def concatenate(inputs, axis=0, *, device=None):
    """Concatenate ``inputs`` along ``axis`` after promoting them to one dtype.

    :param inputs: operands; at least one must be a tensor.
    :param axis: axis passed to ``builtin.Concat``.
    :param device: NOTE(review): this keyword is overwritten by the device
        inferred from ``inputs`` and therefore has no effect -- confirm whether
        that is intended.
    :return: the concatenated tensor.
    """
    dtype = dtype_promotion(inputs)
    device = get_device(inputs)
    converted = tuple(
        convert_single_value(x, inputs, dtype=dtype) for x in inputs
    )
    op = builtin.Concat(axis=axis, comp_node=device.to_c())
    (result,) = apply(op, *converted)
    return result
def astype(x, dtype):
    """Return ``x`` converted to ``dtype``; ``x`` itself when it already matches."""
    target = np.dtype(dtype)
    needs_cast = x.dtype != target
    if needs_cast:
        (x,) = apply(builtin.TypeCvt(param=target), x)
    return x
def convert_single_value(v, inputs, *, dtype=None, device=None):
    """Convert one operand to a tensor of ``dtype``.

    Tensors are cast with ``astype``; anything else becomes a ``Const`` applied
    to the tensor operands found in ``inputs``.

    :raises AssertionError: if ``inputs`` contains no tensor.
    """
    tensor_types = (TensorBase, TensorWrapperBase)
    tensors = list(filter(lambda i: isinstance(i, tensor_types), inputs))
    assert len(tensors) > 0
    if isinstance(v, (TensorWrapperBase, TensorBase)):
        return astype(v, dtype)
    (const,) = Const(v, dtype=dtype, device=device)(*tensors)
    return const
def convert_inputs(*args: TensorBase):
    """Convert all operands to tensors sharing one promoted dtype and device.

    ``None`` entries are passed through unchanged.

    :return: a tuple with one converted entry per argument.
    """
    dtype = dtype_promotion(args)
    device = get_device(args)
    return tuple(
        None
        if value is None
        else convert_single_value(value, args, dtype=dtype, device=device)
        for value in args
    )
def result_type(*args):
    """Return ``np.result_type`` over the dtypes derived from ``args``.

    Tensors contribute their ``dtype``; other arguments contribute
    ``np.dtype(arg)`` and are skipped when that raises ``TypeError``.
    """
    dtypes = []
    for arg in args:
        if isinstance(arg, (TensorWrapperBase, TensorBase)):
            dtypes.append(arg.dtype)
        else:
            try:
                candidate = np.dtype(arg)
            except TypeError:
                continue
            dtypes.append(candidate)
    return np.result_type(*dtypes)
def isscalar(x):
    """Tell whether ``x`` is a scalar.

    Objects exposing ``ndim`` are scalars when ``ndim == 0``; if reading or
    comparing ``ndim`` fails, the decision falls back to ``np.isscalar``.
    """
    try:
        return x.ndim == 0
    except Exception:
        # Best-effort probe. ``Exception`` (not a bare ``except``) so that
        # KeyboardInterrupt / SystemExit are never swallowed here.
        pass
    return np.isscalar(x)
def astensor1d(x, *reference, dtype=None, device=None):
    """
    Convert something to 1D tensor. Support following types

    * sequence of scalar literal / tensor
    * numpy array
    * tensor (returned as is, regardless of dtype and device)

    :param x: value to convert.
    :param reference: tensors forwarded to ``Const`` when a constant is built.
    :param dtype: dtype for constants / for the concatenated result.
    :param device: device forwarded to ``Const`` / ``concatenate``.
    :raises ValueError: if ``x`` has an ``ndim`` attribute that is not 1.
    :raises TypeError: if ``x`` has no ``ndim`` and is not a sequence.
    """
    # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Imported locally so this block is self-contained.
    from collections.abc import Sequence

    try:
        ndim = x.ndim
    except AttributeError:
        pass
    else:
        # Array-like input (numpy array or tensor).
        if ndim != 1:
            raise ValueError("ndim != 1: %d" % ndim)
        if not isinstance(x, (TensorBase, TensorWrapperBase)):
            (x,) = Const(x, dtype=dtype, device=device)(*reference)
        return x

    if not isinstance(x, Sequence):
        raise TypeError(
            "expected a sequence, numpy array or tensor, got {}".format(type(x))
        )

    # A sequence holding at least one tensor is concatenated; a sequence of
    # plain literals becomes a single constant.
    if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x):
        x = concatenate(x, device=device)
        if dtype is not None:
            x = astype(x, dtype)
        return x

    (x,) = Const(x, dtype=dtype, device=device)(*reference)
    return x
| @@ -0,0 +1,17 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .collator import Collator | |||
| from .dataloader import DataLoader | |||
| from .sampler import ( | |||
| Infinite, | |||
| RandomSampler, | |||
| ReplacementSampler, | |||
| Sampler, | |||
| SequentialSampler, | |||
| ) | |||
| @@ -0,0 +1,139 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import binascii | |||
| import os | |||
| import queue | |||
| import subprocess | |||
| from multiprocessing import Queue | |||
| import pyarrow | |||
| import pyarrow.plasma as plasma | |||
# Value handed to the plasma store server's ``-m`` option; overridable through
# the MGE_PLASMA_MEMORY environment variable. Defaults to 4GB.
MGE_PLASMA_MEMORY = int(os.getenv("MGE_PLASMA_MEMORY", 4 * 10 ** 9))

# Each process only needs to start one plasma store, so the manager is kept in
# a module-level global and created lazily.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None
def _clear_plasma_store():
    """Drop the module-level plasma store manager, if one exists.

    `_PlasmaStoreManager.__del__` will not be called automatically in a
    subprocess, so this function should be called explicitly.
    """
    global MGE_PLASMA_STORE_MANAGER
    if MGE_PLASMA_STORE_MANAGER is None:
        return
    # Release the reference to the manager, then leave the global defined as
    # None so later checks keep working.
    del MGE_PLASMA_STORE_MANAGER
    MGE_PLASMA_STORE_MANAGER = None
class _PlasmaStoreManager:
    """Owns one ``plasma-store-server`` subprocess and its socket path.

    The server is started in ``__init__`` and killed in ``__del__`` if it is
    still running.
    """

    __initialized = False

    def __init__(self):
        self.socket_name = "/tmp/mge_plasma_{}".format(
            binascii.hexlify(os.urandom(8)).decode()
        )
        # Environment values are strings, so ``bool("0")`` would be True.
        # Treat empty / "0" / "false" / "no" / "off" as disabled instead.
        raw_debug = str(os.environ.get("MGE_DATALOADER_PLASMA_DEBUG", ""))
        debug_flag = raw_debug.strip().lower() not in ("", "0", "false", "no", "off")
        # NOTE: this is a hack. Directly use `plasma_store` may make subprocess
        # difficult to handle the exception happened in `plasma-store-server`.
        # For `plasma_store` is just a wrapper of `plasma-store-server`, which use
        # `os.execv` to call the executable `plasma-store-server`.
        cmd_path = os.path.join(pyarrow.__path__[0], "plasma-store-server")
        # Server output is discarded unless debugging is enabled.
        output = None if debug_flag else subprocess.DEVNULL
        self.plasma_store = subprocess.Popen(
            [cmd_path, "-s", self.socket_name, "-m", str(MGE_PLASMA_MEMORY)],
            stdout=output,
            stderr=output,
        )
        self.__initialized = True

    def __del__(self):
        # ``__initialized`` stays False when ``__init__`` failed before the
        # process was started, in which case there is nothing to kill.
        if self.__initialized and self.plasma_store.returncode is None:
            self.plasma_store.kill()
class PlasmaShmQueue:
    def __init__(self, maxsize: int = 0):
        r"""Use pyarrow in-memory plasma store to implement shared memory queue.

        Compared to native `multiprocess.Queue`, `PlasmaShmQueue` avoid pickle/unpickle
        and communication overhead, leading to better performance in multi-process
        application.

        Payloads are stored in the plasma store; only their object ids travel
        through an internal ``multiprocessing.Queue``.

        :type maxsize: int
        :param maxsize: maximum size of the queue, forwarded unchanged to
            ``multiprocessing.Queue``. (default: ``0``)
        """
        # Lazy start the plasma store manager
        global MGE_PLASMA_STORE_MANAGER
        if MGE_PLASMA_STORE_MANAGER is None:
            try:
                MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
            except Exception as e:
                err_info = (
                    "Please make sure pyarrow installed correctly!\n"
                    "You can try reinstall pyarrow and see if you can run "
                    "`plasma_store -s /tmp/mge_plasma_xxx -m 1000` normally."
                )
                # Chain the original exception so its traceback is kept.
                raise RuntimeError(
                    "Exception happened in starting plasma_store: {}\n"
                    "Tips: {}".format(str(e), err_info)
                ) from e

        self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name

        # TODO: how to catch the exception happened in `plasma.connect`?
        # The client is created on first use by ``put`` / ``get``.
        self.client = None

        # Used to store the header for the data.(ObjectIDs)
        self.queue = Queue(maxsize)  # type: Queue

    def put(self, data, block=True, timeout=None):
        """Store ``data`` in plasma and enqueue its object id.

        :raises RuntimeError: if the plasma store is full.
        :raises queue.Full: if the id queue is full; the stored object is
            deleted again before the exception propagates.
        """
        if self.client is None:
            self.client = plasma.connect(self.socket_name)
        try:
            object_id = self.client.put(data)
        except plasma.PlasmaStoreFull as e:
            raise RuntimeError("plasma store out of memory!") from e
        try:
            self.queue.put(object_id, block, timeout)
        except queue.Full:
            self.client.delete([object_id])
            raise

    def get(self, block=True, timeout=None):
        """Dequeue the next object id, fetch its payload and delete it from plasma.

        :raises RuntimeError: if the id is no longer present in the store.
        """
        if self.client is None:
            self.client = plasma.connect(self.socket_name)
        object_id = self.queue.get(block, timeout)
        if not self.client.contains(object_id):
            raise RuntimeError(
                "ObjectID: {} not found in plasma store".format(object_id)
            )
        data = self.client.get(object_id)
        self.client.delete([object_id])
        return data

    def qsize(self):
        return self.queue.qsize()

    def empty(self):
        return self.queue.empty()

    def join(self):
        self.queue.join()

    def disconnect_client(self):
        """Disconnect the plasma client, if connected."""
        if self.client is not None:
            self.client.disconnect()
            # Forget the handle so a repeated call is a no-op and a later
            # put/get does not reuse a disconnected client.
            self.client = None

    def close(self):
        """Close the id queue, disconnect, and drop the process-wide store."""
        self.queue.close()
        self.disconnect_client()
        _clear_plasma_store()

    def cancel_join_thread(self):
        self.queue.cancel_join_thread()
| @@ -0,0 +1,76 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # Copyright (c) 2016- Facebook, Inc (Adam Paszke) | |||
| # Copyright (c) 2014- Facebook, Inc (Soumith Chintala) | |||
| # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | |||
| # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) | |||
| # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | |||
| # Copyright (c) 2011-2013 NYU (Clement Farabet) | |||
| # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | |||
| # Copyright (c) 2006 Idiap Research Institute (Samy Bengio) | |||
| # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | |||
| # --------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # ---------------------------------------------------------------------- | |||
| import collections.abc | |||
| import re | |||
| import numpy as np | |||
# Searched against ``ndarray.dtype.str``; ``Collator.apply`` rejects arrays
# whose dtype string matches.
np_str_obj_array_pattern = re.compile("[aO]")

# Error message template; ``{}`` receives the offending type or dtype.
default_collate_err_msg_format = "default_collator: inputs must contain numpy arrays, numbers, Unicode strings, bytes, dicts or lists; found {}"
class Collator:
    r"""
    Merges a list of samples into a mini-batch. Used for batched loading from a dataset.
    modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
    """

    def apply(self, inputs):
        """
        input : sequence_N(tuple(CHW, C, CK))
        output : tuple(NCHW, NC, NCK)

        The first sample decides how the whole batch is merged:
        numpy arrays are stacked, numpy/Python numbers become one array,
        strings/bytes are returned unchanged, and mappings, namedtuples and
        sequences are collated field by field (recursively).

        :raises TypeError: for unsupported sample types.
        """
        first = inputs[0]
        first_type = type(first)
        type_name = first_type.__name__

        if first_type.__module__ == "numpy" and type_name not in ("str_", "string_"):
            if type_name == "ndarray":
                # array of string classes and object
                if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                    raise TypeError(default_collate_err_msg_format.format(first.dtype))
                stacked = np.stack(inputs)
                return np.ascontiguousarray(stacked)
            elif first.shape == ():  # scalars
                return np.array(inputs)
        elif isinstance(first, float):
            return np.array(inputs, dtype=np.float64)
        elif isinstance(first, int):
            return np.array(inputs)
        elif isinstance(first, (str, bytes)):
            return inputs
        elif isinstance(first, collections.abc.Mapping):
            merged = {}
            for key in first:
                merged[key] = self.apply([sample[key] for sample in inputs])
            return merged
        elif isinstance(first, tuple) and hasattr(first, "_fields"):  # namedtuple
            return first_type(*(self.apply(field) for field in zip(*inputs)))
        elif isinstance(first, collections.abc.Sequence):
            return [self.apply(field) for field in zip(*inputs)]

        raise TypeError(default_collate_err_msg_format.format(first_type))
| @@ -0,0 +1,500 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import math | |||
| import multiprocessing | |||
| import queue | |||
| import random | |||
| import time | |||
| import numpy as np | |||
| from ..logger import get_logger | |||
| from ..random.rng import _random_seed_generator | |||
| from .collator import Collator | |||
| from .dataset import Dataset | |||
| from .sampler import Sampler, SequentialSampler | |||
| from .transform import PseudoTransform, Transform | |||
| # Module-level logger, named after this module. | |||
| logger = get_logger(__name__) | |||
| # Timeout passed to the blocking ``Queue.get`` calls in the worker loops below. | |||
| MP_QUEUE_GET_TIMEOUT = 5 | |||
class DataLoader:
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        divide: bool = False,
    ):
        r"""Provides a convenient way to iterate on a given dataset.

        `DataLoader` combines a dataset with a sampler, a transform and a
        collator to produce minibatches from the dataset.

        :type dataset: Dataset
        :param dataset: dataset from which to load the minibatch.
        :type sampler: Sampler
        :param sampler: defines the strategy to sample data from the dataset.
            When ``None``, a ``SequentialSampler`` with ``batch_size=1`` and
            ``drop_last=False`` is used. (default: ``None``)
        :type transform: Transform
        :param transform: defines the transforming strategy for a sampled batch.
            When ``None``, a ``PseudoTransform`` is used. (default: ``None``)
        :type collator: Collator
        :param collator: defines the merging strategy for a transformed batch.
            When ``None``, a ``Collator`` is used. (default: ``None``)
        :type num_workers: int
        :param num_workers: the number of sub-processes that load, transform and
            collate the batch. ``0`` means using single-process. (default: ``0``)
        :type timeout: int
        :param timeout: if positive, the timeout value (seconds) for collecting a
            batch from workers. (default: 0)
        :type divide: bool
        :param divide: defines the parallel strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. (default: ``False``)
        :raises ValueError: for a negative ``num_workers`` or ``timeout``, for
            ``divide=True`` with ``num_workers <= 1``, or when the sampler's
            batch size fails the divide-mode check below.
        """
        for value, message in (
            (num_workers, "num_workers should not be negative"),
            (timeout, "timeout should not be negative"),
        ):
            if value < 0:
                raise ValueError(message)

        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.divide = divide

        self.sampler = (
            SequentialSampler(dataset, batch_size=1, drop_last=False)
            if sampler is None
            else sampler
        )

        if divide:
            # NOTE(review): the check rejects batch_size == num_workers while
            # the message reads "must not smaller than" -- confirm which one
            # is intended.
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        self.transform = PseudoTransform() if transform is None else transform
        self.collator = Collator() if collator is None else collator

        self.__initialized = True

    def __iter__(self):
        """Return a serial iterator when ``num_workers == 0``, else a parallel one."""
        iterator_cls = (
            _SerialDataLoaderIter if self.num_workers == 0 else _ParallelDataLoaderIter
        )
        return iterator_cls(self)

    def __len__(self):
        """Number of batches, as reported by the sampler."""
        return len(self.sampler)
| class _BaseDataLoaderIter: | |||
| def __init__(self, loader): | |||
| self.dataset = loader.dataset | |||
| self.sampler = loader.sampler | |||
| self.seed = _random_seed_generator().__next__() | |||
| self.transform = loader.transform | |||
| self.collator = loader.collator | |||
| self.num_workers = loader.num_workers | |||
| self.timeout = loader.timeout | |||
| self.divide = loader.divide | |||
| self.num_processed = 0 | |||
| def _get_next_batch(self): | |||
| raise NotImplementedError | |||
| def __len__(self): | |||
| return len(self.sampler) | |||
| def __iter__(self): | |||
| return self | |||
| def __next__(self): | |||
| if self.num_processed >= len(self): | |||
| raise StopIteration | |||
| minibatch = self._get_next_batch() | |||
| self.num_processed += 1 | |||
| return minibatch | |||
class _SerialDataLoaderIter(_BaseDataLoaderIter):
    """Single-process iterator: loads, transforms and collates in the caller."""

    def __init__(self, loader):
        super().__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        batch_indices = next(self.indices_iter)
        samples = [self.dataset[index] for index in batch_indices]
        transformed = self.transform.apply_batch(samples)
        return self.collator.apply(transformed)
class _ParallelDataLoaderIter(_BaseDataLoaderIter):
    """Multi-process iterator.

    Starts one task-feeding process, ``num_workers`` worker processes and one
    data-collecting process, connected by queues. Finished batches are read
    from ``batch_queue``.
    """

    __initialized = False

    def __init__(self, loader):
        super(_ParallelDataLoaderIter, self).__init__(loader)

        # One task queue and one transformed-data queue per worker.
        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        # Shared integers: next batch index to feed, next batch index to
        # collect, and the shutdown flag polled by every helper process.
        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    # Each worker gets its own seed derived from the iterator seed.
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

        # Divide mode gathers the pieces of one batch from all workers; the
        # other mode selects whole batches. Both take the same arguments.
        collecting_target = (
            _data_gathering_loop if self.divide else _data_selecting_loop
        )
        self.data_collecting_worker = multiprocessing.Process(
            target=collecting_target,
            args=(
                self.trans_data_queues,
                self.batch_queue,
                self.collator,
                len(self),
                self.num_workers,
                self.shutdown_flag,
                self.target_batch_idx,
            ),
            daemon=True,
        )
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        """Raise ``RuntimeError`` if any helper process exited with a non-zero code."""
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            # Report the exit code of the process that actually died.
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _try_get_next_batch(self):
        """Poll ``batch_queue`` once per second until a batch arrives.

        Worker health is checked before every poll.

        :raises RuntimeError: if a worker died, or if ``timeout`` is positive
            and has been exceeded.
        """
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _get_next_batch(self):
        batch_data = self._try_get_next_batch()
        return batch_data

    def _shutdown(self):
        """Signal shutdown, terminate and join all processes, close all queues."""
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        # Only clean up when ``__init__`` ran to completion.
        if self.__initialized:
            self._shutdown()
| def _task_feeding_loop( | |||
| indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx | |||
| ): | |||
| # Feed the indices into the task queues | |||
| while True: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| batch_idx = feed_batch_idx.value | |||
| try: | |||
| indices = next(indices_iter) | |||
| except StopIteration: | |||
| break | |||
| if divide: | |||
| # make sure all task_queues is ready for put | |||
| while any([q.full() for q in task_queues]): | |||
| if shutdown_flag.value == 1: | |||
| return | |||
| # divide into small pieces, feed to different workers. | |||
| sub_num = math.ceil(len(indices) / num_workers) | |||
| for worker_id in range(num_workers): | |||
| sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num] | |||
| task_queues[worker_id].put((batch_idx, sub_indices)) | |||
| else: | |||
| # distribute tasks to different workers uniformly. | |||
| target_id = batch_idx % num_workers | |||
| while task_queues[target_id].full(): | |||
| if shutdown_flag.value == 1: | |||
| return | |||
| task_queues[target_id].put((batch_idx, indices)) | |||
| with feed_batch_idx.get_lock(): | |||
| feed_batch_idx.value += 1 | |||
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    """Load dataset items for each task and push the transformed result.

    Both ``random`` and ``np.random`` are seeded with ``seed`` first.  Each
    task is a ``(batch_idx, indices)`` pair read from ``task_queue``; the
    items at ``indices`` are fetched from ``dataset``, passed through
    ``transform.apply_batch`` and forwarded as ``(batch_idx, trans_items)``
    on ``trans_data_queue``.  The loop runs until ``shutdown_flag.value``
    becomes 1.
    """
    random.seed(seed)
    np.random.seed(seed)

    while shutdown_flag.value != 1:
        try:
            batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
        except queue.Empty:
            continue

        if len(indices) > 0:
            trans_items = transform.apply_batch([dataset[idx] for idx in indices])
        else:
            # in case of incomplete last batch
            trans_items = ()

        delivered = False
        while not delivered:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                delivered = True
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")
def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    """Assemble the per-worker pieces of each batch and publish full batches.

    For every batch index below ``length`` one ``(batch_idx, trans_items)``
    piece is read from each of the ``num_workers`` queues in
    ``trans_data_queues``.  The pieces are concatenated in worker order,
    merged with ``collator.apply`` and put on ``batch_queue``; ``target_idx``
    is then incremented under its lock.

    The loop stops once all batches are produced or ``shutdown_flag.value``
    becomes 1, and ``batch_queue.disconnect_client()`` is called on the way
    out.  A shutdown observed while waiting on a queue abandons the batch in
    progress instead of continuing with missing data.

    Raises:
        RuntimeError: if a worker delivers a piece whose batch index differs
            from the batch currently being gathered.
    """
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            break

        interrupted = False
        full_trans_items = []
        for worker_id in range(num_workers):
            # Queue items are always tuples, so None safely means "nothing yet".
            piece = None
            while piece is None:
                try:
                    piece = trans_data_queues[worker_id].get(
                        timeout=MP_QUEUE_GET_TIMEOUT
                    )
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if piece is None:
                # Shutdown requested before this worker delivered its piece.
                interrupted = True
                break

            batch_idx, trans_items = piece
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            full_trans_items.extend(trans_items)
        if interrupted:
            break

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    interrupted = True
                    break
                logger.debug("batch queue is full!")
        if interrupted:
            # The batch was never delivered, so the target index stays put.
            break

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    """Publish whole batches in the same order their indices were generated.

    Batch ``i`` is read from ``trans_data_queues[i % num_workers]``, merged
    with ``collator.apply`` and put on ``batch_queue``; ``target_idx`` is then
    incremented under its lock.

    The loop stops once ``length`` batches are produced or
    ``shutdown_flag.value`` becomes 1, and
    ``batch_queue.disconnect_client()`` is called on the way out.  A shutdown
    observed while waiting on a queue abandons the batch in progress instead
    of continuing with missing data.

    Raises:
        RuntimeError: if the piece read from the selected worker carries a
            batch index other than the expected one.
    """
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            break

        target_worker_id = target_batch_idx % num_workers
        # Queue items are always tuples, so None safely means "nothing yet".
        piece = None
        while piece is None:
            try:
                piece = trans_data_queues[target_worker_id].get(
                    timeout=MP_QUEUE_GET_TIMEOUT
                )
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )
        if piece is None:
            # Shutdown requested before the worker delivered the batch.
            break

        batch_idx, trans_items = piece
        batch_data = collator.apply(trans_items)
        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatch the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )

        delivered = False
        while not delivered:
            try:
                batch_queue.put(batch_data, timeout=1)
                delivered = True
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")
        if not delivered:
            # The batch was never delivered, so the target index stays put.
            break

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
| @@ -0,0 +1,10 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset | |||
| from .vision import * | |||
| @@ -0,0 +1,73 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Tuple | |||
class Dataset(ABC):
    r"""
    Abstract root of the dataset hierarchy.

    ``__init__`` is declared abstract, so the class cannot be instantiated
    until a subclass provides its own constructor.
    """

    @abstractmethod
    def __init__(self):
        pass
class MapDataset(Dataset):
    r"""
    Abstract base for datasets addressed by index.

    Besides ``__init__``, concrete subclasses additionally have to implement
    ``__getitem__`` and ``__len__``.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __getitem__(self, index):
        pass

    @abstractmethod
    def __len__(self):
        pass
class StreamDataset(Dataset):
    r"""
    Abstract base for datasets consumed as a stream.

    Besides ``__init__``, concrete subclasses additionally have to implement
    ``__iter__``.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __iter__(self):
        pass
class ArrayDataset(MapDataset):
    r"""
    Dataset backed by one or more in-memory arrays.

    Sample ``i`` is the tuple made of element ``i`` of every array, so all
    arrays must have the same length along their first dimension.
    """

    def __init__(self, *arrays):
        r"""
        ArrayDataset is a dataset for numpy array data, one or more numpy arrays
        are needed to initiate the dataset. And the dimensions represented sample number
        are expected to be the same.

        :param arrays: one or more arrays of equal length.
        :raises ValueError: if no array is given or the lengths differ.
        """
        super().__init__()
        # Without this check an empty dataset is created silently and only
        # fails later, with an IndexError, inside __len__.
        if not arrays:
            raise ValueError("at least one input array is required")
        expected = len(arrays[0])
        if not all(len(array) == expected for array in arrays):
            raise ValueError("lengths of input arrays are inconsistent")
        self.arrays = arrays

    def __getitem__(self, index: int) -> Tuple:
        """Return element ``index`` of every array as a tuple."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the number of samples, i.e. the length of the first array."""
        return len(self.arrays[0])
| @@ -0,0 +1,17 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .cifar import CIFAR10, CIFAR100 | |||
| from .cityscapes import Cityscapes | |||
| from .coco import COCO | |||
| from .folder import ImageFolder | |||
| from .imagenet import ImageNet | |||
| from .meta_vision import VisionDataset | |||
| from .mnist import MNIST | |||
| from .objects365 import Objects365 | |||
| from .voc import PascalVOC | |||
| @@ -0,0 +1,171 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import pickle | |||
| import tarfile | |||
| from typing import Tuple | |||
| import numpy as np | |||
| from ....logger import get_logger | |||
| from .meta_vision import VisionDataset | |||
| from .utils import _default_dataset_root, load_raw_data_from_url | |||
| logger = get_logger(__name__) | |||
class CIFAR10(VisionDataset):
    r""" ``Dataset`` for CIFAR10 meta data

    Samples are ``(image, label)`` pairs read from the pickled batch files of
    the ``cifar-10-python`` archive.  The archive is downloaded into ``root``
    and unpacked when the extracted directory is missing and ``download`` is
    set.
    """

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-10-python.tar.gz"
    raw_file_md5 = "c58f30108f718f92721af3b95e74349a"
    raw_file_dir = "cifar-10-batches-py"
    train_batch = [
        "data_batch_1",
        "data_batch_2",
        "data_batch_3",
        "data_batch_4",
        "data_batch_5",
    ]
    test_batch = ["test_batch"]
    meta_info = {"name": "batches.meta"}

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        """
        :param root: directory holding (or receiving) the data; the
            per-class default root is used when ``None``.
        :param train: load the training batches if ``True``, otherwise the
            test batch.
        :param download: allow creating ``root`` and downloading the archive
            when the extracted data is missing.
        :param timeout: value forwarded to ``load_raw_data_from_url``.
        :raises ValueError: if ``root`` or the extracted data is missing and
            ``download`` is ``False``.
        """
        super().__init__(root, order=("image", "image_category"))

        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            # exist_ok avoids a crash when another process creates the
            # directory between an existence check and the creation.
            os.makedirs(self.root, exist_ok=True)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if not download:
                    raise ValueError("dir %s does not exist" % self.root)
                logger.debug(
                    "dir %s does not exist, will be automatically created",
                    self.root,
                )
                os.makedirs(self.root, exist_ok=True)

        self.target_file = os.path.join(self.root, self.raw_file_dir)

        # An already extracted directory is used no matter what ``download``
        # is set to; otherwise the archive has to be fetched first.
        if not os.path.exists(self.target_file):
            if not download:
                raise ValueError(
                    "dir does not contain target file %s, please set download=True"
                    % (self.target_file)
                )
            self.download()

        batches = self.train_batch if train else self.test_batch
        self.arrays = self.bytes2array(batches)

    def __getitem__(self, index: int) -> Tuple:
        """Return element ``index`` of every loaded array as a tuple."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.arrays[0])

    @property
    def _default_root(self):
        """Default data directory: the dataset root joined with the class name."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        """Unpickled content of the meta file shipped with the archive."""
        meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(meta_path, "rb") as f:
            # NOTE(review): pickle executes code from the file; the data
            # comes from the downloaded archive.
            meta = pickle.load(f, encoding="bytes")
        return meta

    def download(self):
        """Fetch the raw archive into ``self.root`` and unpack it."""
        url = self.url_path + self.raw_file_name
        load_raw_data_from_url(
            url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
        )
        self.process()

    def untar(self, file_path, dirs):
        """Extract the ``.tar.gz`` archive ``file_path`` into ``dirs``."""
        assert file_path.endswith(".tar.gz")
        logger.debug("untar file %s to %s", file_path, dirs)
        # The context manager closes the archive even if extraction fails.
        # NOTE(review): extractall() trusts member paths inside the archive.
        with tarfile.open(file_path) as t:
            t.extractall(path=dirs)

    def bytes2array(self, filenames):
        """Load the pickled batch files and return ``(images, labels)``.

        Each batch is reshaped to ``(-1, 3, 32, 32)``, transposed to put the
        channel axis last, and its channel order is reversed.  ``images`` is
        a list of per-sample arrays and ``labels`` an ``int32`` array.
        """
        data = []
        label = []
        for filename in filenames:
            path = os.path.join(self.root, self.raw_file_dir, filename)
            logger.debug("unpickle file %s", path)
            with open(path, "rb") as fo:
                # NOTE(review): pickle executes code from the file; the data
                # comes from the downloaded archive.
                dic = pickle.load(fo, encoding="bytes")
            batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
            data.extend(list(batch_data[..., [2, 1, 0]]))
            label.extend(dic[b"labels"])
        label = np.array(label, dtype=np.int32)
        return (data, label)

    def process(self):
        """Unpack the downloaded archive inside ``self.root``."""
        logger.info("process raw data ...")
        self.untar(os.path.join(self.root, self.raw_file_name), self.root)
class CIFAR100(CIFAR10):
    r""" ``Dataset`` for CIFAR100 meta data

    Same loading scheme as the parent class, but every sample carries a fine
    and a coarse label.
    """

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-100-python.tar.gz"
    raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    raw_file_dir = "cifar-100-python"
    train_batch = ["train"]
    test_batch = ["test"]
    meta_info = {"name": "meta"}

    @property
    def meta(self):
        """Unpickled content of the meta file shipped with the archive."""
        location = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(location, "rb") as stream:
            return pickle.load(stream, encoding="bytes")

    def bytes2array(self, filenames):
        """Load the pickled files and return ``(images, fine, coarse)``.

        ``images`` is a list of per-sample arrays with the channel axis last
        and the channel order reversed; both label collections are ``int32``
        arrays.
        """
        images, fine, coarse = [], [], []
        for name in filenames:
            batch_path = os.path.join(self.root, self.raw_file_dir, name)
            logger.debug("unpickle file %s", batch_path)
            with open(batch_path, "rb") as stream:
                content = pickle.load(stream, encoding="bytes")
            pixels = content[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
            images.extend(list(pixels[..., [2, 1, 0]]))
            fine.extend(content[b"fine_labels"])
            coarse.extend(content[b"coarse_labels"])
        fine_label = np.array(fine, dtype=np.int32)
        coarse_label = np.array(coarse, dtype=np.int32)
        return images, fine_label, coarse_label
| @@ -0,0 +1,151 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # --------------------------------------------------------------------- | |||
| # Part of the following code in this file refs to torchvision | |||
| # BSD 3-Clause License | |||
| # | |||
| # Copyright (c) Soumith Chintala 2016, | |||
| # All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| import json | |||
| import os | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
class Cityscapes(VisionDataset):
    r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Images are listed from ``<root>/leftImg8bit/<image_set>/<city>/`` and the
    matching semantic masks are expected under
    ``<root>/<mode>/<image_set>/<city>/``.
    """

    supported_order = (
        "image",
        "mask",
        "info",
    )

    def __init__(self, root, image_set, mode, *, order=None):
        """
        :param root: dataset root directory.
        :param image_set: sub-directory naming the split; it is used as given
            and not validated.
        :param mode: annotation directory name, also used as the prefix of
            the mask file suffix.
        :param order: fields returned by ``__getitem__``, taken from
            ``supported_order``.
        :raises RuntimeError: if the root directory does not exist.
        """
        super().__init__(root, order=order, supported_order=self.supported_order)

        city_root = self.root
        if not os.path.isdir(city_root):
            raise RuntimeError("Dataset not found or corrupted.")

        self.mode = mode
        self.images_dir = os.path.join(city_root, "leftImg8bit", image_set)
        self.masks_dir = os.path.join(city_root, self.mode, image_set)
        self.images, self.masks = [], []
        # self.target_type = ["instance", "semantic", "polygon", "color"]

        # for semantic segmentation
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            mask_dir = os.path.join(self.masks_dir, city)
            for file_name in os.listdir(img_dir):
                mask_name = "{}_{}".format(
                    file_name.split("_leftImg8bit")[0],
                    self._get_target_suffix(self.mode, "semantic"),
                )
                self.images.append(os.path.join(img_dir, file_name))
                self.masks.append(os.path.join(mask_dir, mask_name))

    def __getitem__(self, index):
        """Return the fields listed in ``self.order`` for sample ``index``.

        ``"image"`` yields the colour image, ``"mask"`` the remapped label
        mask with a trailing channel axis, and ``"info"`` the list
        ``[height, width, image_path]``.
        """
        # Must exist before the loop: "info" may be requested without, or
        # ahead of, "image", in which case the image is loaded on demand.
        image = None
        target = []
        for k in self.order:
            if k == "image":
                image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "mask":
                mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                mask = self._trans_mask(mask)
                mask = mask[:, :, np.newaxis]
                target.append(mask)
            elif k == "info":
                if image is None:
                    image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
                info = [image.shape[0], image.shape[1], self.images[index]]
                target.append(info)
            else:
                raise NotImplementedError
        return tuple(target)

    def __len__(self):
        """Return the number of images found at construction time."""
        return len(self.images)

    def _trans_mask(self, mask):
        """Map raw label ids to consecutive ids, 255 for every other value.

        The position of a raw id in ``trans_labels`` becomes its new id; the
        result is returned as ``uint8``.
        """
        trans_labels = [
            7,
            8,
            11,
            12,
            13,
            17,
            19,
            20,
            21,
            22,
            23,
            24,
            25,
            26,
            27,
            28,
            31,
            32,
            33,
        ]
        label = np.ones(mask.shape) * 255
        for i, tl in enumerate(trans_labels):
            label[mask == tl] = i
        return label.astype(np.uint8)

    def _get_target_suffix(self, mode, target_type):
        """Return the file-name suffix for ``target_type`` under ``mode``."""
        if target_type == "instance":
            return "{}_instanceIds.png".format(mode)
        elif target_type == "semantic":
            return "{}_labelIds.png".format(mode)
        elif target_type == "color":
            return "{}_color.png".format(mode)
        else:
            return "{}_polygons.json".format(mode)

    def _load_json(self, path):
        """Read and parse the JSON file at ``path``."""
        with open(path, "r") as file:
            data = json.load(file)
        return data

    class_names = (
        "road",
        "sidewalk",
        "building",
        "wall",
        "fence",
        "pole",
        "traffic light",
        "traffic sign",
        "vegetation",
        "terrain",
        "sky",
        "person",
        "rider",
        "car",
        "truck",
        "bus",
        "train",
        "motorcycle",
        "bicycle",
    )
| @@ -0,0 +1,366 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # --------------------------------------------------------------------- | |||
| # Part of the following code in this file refs to maskrcnn-benchmark | |||
| # MIT License | |||
| # | |||
| # Copyright (c) 2018 Facebook | |||
| # --------------------------------------------------------------------- | |||
| import json | |||
| import os | |||
| from collections import defaultdict | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| min_keypoints_per_image = 10 | |||
| def _count_visible_keypoints(anno): | |||
| return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |||
def has_valid_annotation(anno, order):
    """Tell whether ``anno`` can serve the fields requested in ``order``.

    An empty annotation list is never valid.  Box fields require a ``"bbox"``
    entry in the first annotation.  The keypoint field requires a
    ``"keypoints"`` entry in the first annotation and at least
    ``min_keypoints_per_image`` visible keypoints over all annotations.
    """
    # an empty list means there is no annotation at all
    if len(anno) == 0:
        return False

    first = anno[0]
    needs_boxes = "boxes" in order or "boxes_category" in order
    if needs_boxes and "bbox" not in first:
        return False

    if "keypoints" in order:
        if "keypoints" not in first:
            return False
        # for keypoint detection tasks, only consider valid images those
        # containing at least min_keypoints_per_image
        if _count_visible_keypoints(anno) < min_keypoints_per_image:
            return False

    return True
class COCO(VisionDataset):
    r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset.

    Annotations are read from a COCO-format JSON file and grouped by image
    id; images are loaded from ``root`` on access.
    """

    supported_order = (
        "image",
        "boxes",
        "boxes_category",
        "keypoints",
        # TODO: need to check
        # "polygons",
        "info",
    )

    def __init__(
        self, root, ann_file, remove_images_without_annotations=False, *, order=None
    ):
        """
        :param root: directory containing the image files.
        :param ann_file: path of the annotation JSON file.
        :param remove_images_without_annotations: drop images that have no
            usable annotation for the requested fields.
        :param order: fields returned by ``__getitem__``, taken from
            ``supported_order``.
        """
        super().__init__(root, order=order, supported_order=self.supported_order)

        with open(ann_file, "r") as f:
            dataset = json.load(f)

        self.imgs = dict()
        for img in dataset["images"]:
            # for saving memory
            for unused_key in ("license", "coco_url", "date_captured", "flickr_url"):
                img.pop(unused_key, None)
            self.imgs[img["id"]] = img

        self.img_to_anns = defaultdict(list)
        for ann in dataset["annotations"]:
            # for saving memory
            if (
                "boxes" not in self.order
                and "boxes_category" not in self.order
                and "bbox" in ann
            ):
                del ann["bbox"]
            if "polygons" not in self.order and "segmentation" in ann:
                del ann["segmentation"]
            self.img_to_anns[ann["image_id"]].append(ann)

        self.cats = dict()
        for cat in dataset["categories"]:
            self.cats[cat["id"]] = cat

        self.ids = list(sorted(self.imgs.keys()))

        # filter images without detection annotations
        if remove_images_without_annotations:
            ids = []
            for img_id in self.ids:
                anno = self.img_to_anns[img_id]
                # filter crowd annotations
                anno = [obj for obj in anno if obj["iscrowd"] == 0]
                # "bbox" may have been dropped above to save memory, in which
                # case there is no box to check for degenerate size.
                anno = [
                    obj
                    for obj in anno
                    if "bbox" not in obj
                    or (obj["bbox"][2] > 0 and obj["bbox"][3] > 0)
                ]
                # self.order, not the raw ``order`` argument, which is None
                # when the caller relies on the default.
                if has_valid_annotation(anno, self.order):
                    ids.append(img_id)
                    self.img_to_anns[img_id] = anno
                else:
                    del self.imgs[img_id]
                    del self.img_to_anns[img_id]
            self.ids = ids

        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.cats.keys())
        }

        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }

    def __getitem__(self, index):
        """Return the fields listed in ``self.order`` for sample ``index``.

        ``"boxes"`` are converted from ``xywh`` to ``xyxy`` as ``float32``
        with shape ``(-1, 4)``; ``"boxes_category"`` holds contiguous category
        ids as ``int32``; ``"keypoints"`` has shape
        ``(-1, len(self.keypoint_names), 3)``; ``"info"`` is the list
        ``[height, width, file_name]``.
        """
        img_id = self.ids[index]
        anno = self.img_to_anns[img_id]

        target = []
        for k in self.order:
            if k == "image":
                file_name = self.imgs[img_id]["file_name"]
                path = os.path.join(self.root, file_name)
                image = cv2.imread(path, cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "boxes":
                boxes = [obj["bbox"] for obj in anno]
                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                # transfer boxes from xywh to xyxy
                boxes[:, 2:] += boxes[:, :2]
                target.append(boxes)
            elif k == "boxes_category":
                boxes_category = [obj["category_id"] for obj in anno]
                boxes_category = [
                    self.json_category_id_to_contiguous_id[c] for c in boxes_category
                ]
                boxes_category = np.array(boxes_category, dtype=np.int32)
                target.append(boxes_category)
            elif k == "keypoints":
                keypoints = [obj["keypoints"] for obj in anno]
                keypoints = np.array(keypoints, dtype=np.float32).reshape(
                    -1, len(self.keypoint_names), 3
                )
                target.append(keypoints)
            elif k == "polygons":
                polygons = [obj["segmentation"] for obj in anno]
                polygons = [
                    [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
                    for ps in polygons
                ]
                target.append(polygons)
            elif k == "info":
                info = self.imgs[img_id]
                info = [info["height"], info["width"], info["file_name"]]
                target.append(info)
            else:
                raise NotImplementedError

        return tuple(target)

    def __len__(self):
        """Return the number of images kept in the dataset."""
        return len(self.ids)

    def get_img_info(self, index):
        """Return the image record from the annotation file for ``index``."""
        img_id = self.ids[index]
        img_info = self.imgs[img_id]
        return img_info

    class_names = (
        "person",
        "bicycle",
        "car",
        "motorcycle",
        "airplane",
        "bus",
        "train",
        "truck",
        "boat",
        "traffic light",
        "fire hydrant",
        "stop sign",
        "parking meter",
        "bench",
        "bird",
        "cat",
        "dog",
        "horse",
        "sheep",
        "cow",
        "elephant",
        "bear",
        "zebra",
        "giraffe",
        "backpack",
        "umbrella",
        "handbag",
        "tie",
        "suitcase",
        "frisbee",
        "skis",
        "snowboard",
        "sports ball",
        "kite",
        "baseball bat",
        "baseball glove",
        "skateboard",
        "surfboard",
        "tennis racket",
        "bottle",
        "wine glass",
        "cup",
        "fork",
        "knife",
        "spoon",
        "bowl",
        "banana",
        "apple",
        "sandwich",
        "orange",
        "broccoli",
        "carrot",
        "hot dog",
        "pizza",
        "donut",
        "cake",
        "chair",
        "couch",
        "potted plant",
        "bed",
        "dining table",
        "toilet",
        "tv",
        "laptop",
        "mouse",
        "remote",
        "keyboard",
        "cell phone",
        "microwave",
        "oven",
        "toaster",
        "sink",
        "refrigerator",
        "book",
        "clock",
        "vase",
        "scissors",
        "teddy bear",
        "hair drier",
        "toothbrush",
    )

    classes_originID = {
        "person": 1,
        "bicycle": 2,
        "car": 3,
        "motorcycle": 4,
        "airplane": 5,
        "bus": 6,
        "train": 7,
        "truck": 8,
        "boat": 9,
        "traffic light": 10,
        "fire hydrant": 11,
        "stop sign": 13,
        "parking meter": 14,
        "bench": 15,
        "bird": 16,
        "cat": 17,
        "dog": 18,
        "horse": 19,
        "sheep": 20,
        "cow": 21,
        "elephant": 22,
        "bear": 23,
        "zebra": 24,
        "giraffe": 25,
        "backpack": 27,
        "umbrella": 28,
        "handbag": 31,
        "tie": 32,
        "suitcase": 33,
        "frisbee": 34,
        "skis": 35,
        "snowboard": 36,
        "sports ball": 37,
        "kite": 38,
        "baseball bat": 39,
        "baseball glove": 40,
        "skateboard": 41,
        "surfboard": 42,
        "tennis racket": 43,
        "bottle": 44,
        "wine glass": 46,
        "cup": 47,
        "fork": 48,
        "knife": 49,
        "spoon": 50,
        "bowl": 51,
        "banana": 52,
        "apple": 53,
        "sandwich": 54,
        "orange": 55,
        "broccoli": 56,
        "carrot": 57,
        "hot dog": 58,
        "pizza": 59,
        "donut": 60,
        "cake": 61,
        "chair": 62,
        "couch": 63,
        "potted plant": 64,
        "bed": 65,
        "dining table": 67,
        "toilet": 70,
        "tv": 72,
        "laptop": 73,
        "mouse": 74,
        "remote": 75,
        "keyboard": 76,
        "cell phone": 77,
        "microwave": 78,
        "oven": 79,
        "toaster": 80,
        "sink": 81,
        "refrigerator": 82,
        "book": 84,
        "clock": 85,
        "vase": 86,
        "scissors": 87,
        "teddy bear": 88,
        "hair drier": 89,
        "toothbrush": 90,
    }

    keypoint_names = (
        "nose",
        "left_eye",
        "right_eye",
        "left_ear",
        "right_ear",
        "left_shoulder",
        "right_shoulder",
        "left_elbow",
        "right_elbow",
        "left_wrist",
        "right_wrist",
        "left_hip",
        "right_hip",
        "left_knee",
        "right_knee",
        "left_ankle",
        "right_ankle",
    )
| @@ -0,0 +1,90 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # BSD 3-Clause License | |||
| # Copyright (c) Soumith Chintala 2016, | |||
| # All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| import os | |||
| from typing import Dict, List, Tuple | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| from .utils import is_img | |||
class ImageFolder(VisionDataset):
    def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
        r"""
        ImageFolder is a class for loading image data and labels from an organized folder.

        The folder is expected to be organized as follows:

            root/cls/xxx.img_ext

        Labels are indices of sorted classes in the root directory.

        :param root: root directory of an image folder
        :param check_valid_func: a function used to check if files in folder are
                                 expected image files, if ``None``, default function
                                 that checks file extensions will be called
        :param class_name: if ``True``, return class name instead of class index
        """
        super().__init__(root, order=("image", "image_category"))

        self.root = root

        if check_valid_func is not None:
            self.check_valid = check_valid_func
        else:
            self.check_valid = is_img

        self.class_name = class_name

        self.class_dict = self.collect_class()
        self.samples = self.collect_samples()

    def collect_samples(self) -> List:
        """Walk every class directory and list ``(path, label)`` pairs.

        Classes and file names are visited in sorted order and symbolic
        links are followed.  ``label`` is the class name when
        ``self.class_name`` is set, otherwise the class index.
        """
        samples = []
        directory = os.path.expanduser(self.root)
        for key in sorted(self.class_dict.keys()):
            d = os.path.join(directory, key)
            if not os.path.isdir(d):
                continue
            label = key if self.class_name else self.class_dict[key]
            for r, _, filename in sorted(os.walk(d, followlinks=True)):
                for name in sorted(filename):
                    path = os.path.join(r, name)
                    if self.check_valid(path):
                        samples.append((path, label))
        return samples

    def collect_class(self) -> Dict:
        """Map each sub-directory name of the root to its sorted index."""
        # Expand "~" the same way collect_samples() does, so both methods
        # look at the same directory.
        directory = os.path.expanduser(self.root)
        classes = sorted(d.name for d in os.scandir(directory) if d.is_dir())
        return {name: np.int32(i) for i, name in enumerate(classes)}

    def __getitem__(self, index: int) -> Tuple:
        """Load sample ``index`` and return ``(image, label)``."""
        path, label = self.samples[index]
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        return img, label

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)
| @@ -0,0 +1,248 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # BSD 3-Clause License | |||
| # | |||
| # Copyright (c) Soumith Chintala 2016, | |||
| # All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| import os | |||
| import shutil | |||
| from tqdm import tqdm | |||
| from ....distributed.group import is_distributed | |||
| from ....logger import get_logger | |||
| from ....serialization import load, save | |||
| from .folder import ImageFolder | |||
| from .utils import _default_dataset_root, calculate_md5, untar, untargz | |||
| logger = get_logger(__name__) | |||
class ImageNet(ImageFolder):
    r"""
    Load ImageNet from raw files or folder, expected folder looks like

    .. code-block:: bash

        ${root}/
        |       [REQUIRED TAR FILES]
        |-  ILSVRC2012_img_train.tar
        |-  ILSVRC2012_img_val.tar
        |-  ILSVRC2012_devkit_t12.tar.gz
        |       [OPTIONAL IMAGE FOLDERS]
        |-  train/cls/xxx.${img_ext}
        |-  val/cls/xxx.${img_ext}
        |-  ILSVRC2012_devkit_t12/data/meta.mat
        |-  ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt

    If the image folders don't exist, raw tar files are required to get extracted and processed.
    """

    # split name -> (raw archive file name, expected md5 checksum)
    raw_file_meta = {
        "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
        "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
        "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
    }  # ImageNet raw files
    default_train_dir = "train"
    default_val_dir = "val"
    default_devkit_dir = "ILSVRC2012_devkit_t12"

    def __init__(self, root: str = None, train: bool = True, **kwargs):
        r"""
        initialization:

        * if ``root`` contains ``self.target_folder`` depending on ``train``:

            * initialize ImageFolder with target_folder

        * else:

            * if all raw files are in ``root``:

                * parse ``self.target_folder`` from raw files
                * initialize ImageFolder with ``self.target_folder``

            * else:

                * raise error

        :param root: root directory of imagenet data, if root is ``None``, used default_dataset_root
        :param train: if ``True``, load the train split, otherwise load the validation split
        :param kwargs: forwarded unchanged to ``ImageFolder.__init__``
        :raises FileNotFoundError: if ``root`` does not exist, or if neither the
            image folder nor the complete set of raw files is available
        :raises RuntimeError: if raw files would have to be extracted while
            ``is_distributed()`` is true
        """
        # process the root path
        if root is None:
            self.root = self._default_root
        else:
            self.root = root

        if not os.path.exists(self.root):
            raise FileNotFoundError("dir %s does not exist" % self.root)

        # the devkit holds the class meta data and the validation ground truth
        self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)

        if not os.path.exists(self.devkit_dir):
            logger.warning("devkit directory %s does not exists", self.devkit_dir)
            self._prepare_devkit()

        self.train = train

        if train:
            self.target_folder = os.path.join(self.root, self.default_train_dir)
        else:
            self.target_folder = os.path.join(self.root, self.default_val_dir)

        if not os.path.exists(self.target_folder):
            logger.warning(
                "expected image folder %s does not exist, try to load from raw file",
                self.target_folder,
            )
            if not self.check_raw_file():
                raise FileNotFoundError(
                    "expected image folder %s does not exist, and raw files do not exist in %s"
                    % (self.target_folder, self.root)
                )
            elif is_distributed():
                raise RuntimeError(
                    "extracting raw file shouldn't be done in distributed mode, use single process instead"
                )
            elif train:
                self._prepare_train()
            else:
                self._prepare_val()

        super().__init__(self.target_folder, **kwargs)

    @property
    def _default_root(self):
        """Default dataset directory: ``<default dataset root>/<class name>``."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def valid_ground_truth(self):
        """Validation labels read from the devkit ground-truth file, one ``int`` per line.

        :raises FileNotFoundError: if the ground-truth file is missing.
        """
        ground_truth_path = os.path.join(
            self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
        )
        if os.path.exists(ground_truth_path):
            with open(ground_truth_path, "r") as f:
                val_labels = f.readlines()
                return [int(val_label) for val_label in val_labels]
        else:
            raise FileNotFoundError(
                "valid ground truth file %s does not exist" % ground_truth_path
            )

    @property
    def meta(self):
        """Return ``(idx_to_wnid, wnid_to_classes)`` for the synsets without children.

        The pair is loaded from the cached ``meta.pkl`` inside the devkit
        directory when that file exists; otherwise it is parsed from
        ``data/meta.mat`` (which needs ``scipy``) and then cached.

        :raises FileNotFoundError: if neither ``meta.pkl`` nor ``meta.mat`` exists.
        """
        cache_path = os.path.join(self.devkit_dir, "meta.pkl")
        try:
            return load(cache_path)
        except FileNotFoundError:
            # scipy is imported lazily so it is only needed when no cache exists
            import scipy.io

            meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
            if not os.path.exists(meta_path):
                raise FileNotFoundError("meta file %s does not exist" % meta_path)
            meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
            # field 4 of every synset record is its number of children;
            # only records with no children are kept
            nums_children = list(zip(*meta))[4]
            meta = [
                meta[idx]
                for idx, num_children in enumerate(nums_children)
                if num_children == 0
            ]
            idcs, wnids, classes = list(zip(*meta))[:3]
            classes = [tuple(clss.split(", ")) for clss in classes]
            idx_to_wnid = dict(zip(idcs, wnids))
            wnid_to_classes = dict(zip(wnids, classes))
            logger.info(
                "saving cached meta file to %s", cache_path,
            )
            save(
                (idx_to_wnid, wnid_to_classes), cache_path,
            )
            return idx_to_wnid, wnid_to_classes

    def check_raw_file(self) -> bool:
        """Return ``True`` only if every file listed in ``raw_file_meta`` exists under ``root``."""
        return all(
            os.path.exists(os.path.join(self.root, file_name))
            for file_name, _ in self.raw_file_meta.values()
        )

    def _organize_val_data(self):
        """Move the extracted validation images into one sub-directory per wnid.

        Images are paired with ground-truth labels by sorted file name order.
        """
        id2wnid = self.meta[0]
        val_idcs = self.valid_ground_truth
        val_wnids = [id2wnid[idx] for idx in val_idcs]

        val_images = sorted(
            [
                os.path.join(self.target_folder, image)
                for image in os.listdir(self.target_folder)
            ]
        )

        logger.debug("mkdir for val set wnids")
        for wnid in set(val_wnids):
            os.makedirs(os.path.join(self.root, self.default_val_dir, wnid))

        logger.debug("mv val images into wnids dir")
        for wnid, img_file in tqdm(zip(val_wnids, val_images)):
            shutil.move(
                img_file,
                os.path.join(
                    self.root, self.default_val_dir, wnid, os.path.basename(img_file)
                ),
            )

    def _prepare_val(self):
        """Verify and extract the validation archive, then sort images into class folders."""
        assert not self.train
        raw_filename, checksum = self.raw_file_meta["val"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum valid tar file %s ...", raw_file)
        assert (
            calculate_md5(raw_file) == checksum
        ), "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract valid tar file... this may take 10-20 minutes")
        # raw_file already contains self.root; joining it with the root a second
        # time would duplicate the prefix whenever root is a relative path
        untar(raw_file, self.target_folder)
        self._organize_val_data()

    def _prepare_train(self):
        """Verify and extract the training archive and each per-class archive inside it."""
        assert self.train
        raw_filename, checksum = self.raw_file_meta["train"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum train tar file %s ...", raw_file)
        assert (
            calculate_md5(raw_file) == checksum
        ), "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract train tar file.. this may take several hours")
        # raw_file already contains self.root; see _prepare_val
        untar(raw_file, self.target_folder)
        # every entry extracted above is itself an archive: unpack it into a
        # directory named after the archive (extension stripped) and delete it
        paths = [
            os.path.join(self.target_folder, child_dir)
            for child_dir in os.listdir(self.target_folder)
        ]
        for path in tqdm(paths):
            untar(path, os.path.splitext(path)[0], remove=True)

    def _prepare_devkit(self):
        """Verify the devkit archive and extract it next to itself."""
        raw_filename, checksum = self.raw_file_meta["devkit"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum devkit tar file %s ...", raw_file)
        assert (
            calculate_md5(raw_file) == checksum
        ), "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract devkit file..")
        untargz(raw_file)
| @@ -0,0 +1,41 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections.abc | |||
| import os | |||
| from ..meta_dataset import MapDataset | |||
class VisionDataset(MapDataset):
    """Base class for vision datasets.

    Stores the dataset ``root`` and the validated ``order`` of the fields
    an item is made of. Subclasses must implement ``__getitem__`` and
    ``__len__``.
    """

    _repr_indent = 4

    def __init__(self, root, *, order=None, supported_order=None):
        """
        :param root: dataset location; ``~`` is expanded when it is ``str`` or ``bytes``,
            any other object is stored untouched
        :param order: sequence of field names, defaults to ``("image",)``
        :param supported_order: when given, every entry of ``order`` must be in it
        :raises ValueError: if ``order`` is not a sequence
        :raises NotImplementedError: if ``order`` holds an entry missing from ``supported_order``
        """
        self.root = (
            os.path.expanduser(root) if isinstance(root, (str, bytes)) else root
        )

        requested = ("image",) if order is None else order
        if not isinstance(requested, collections.abc.Sequence):
            raise ValueError(
                "order should be a sequence, but got order={}".format(requested)
            )

        if supported_order is not None:
            assert isinstance(supported_order, collections.abc.Sequence)
            for field in requested:
                if field not in supported_order:
                    raise NotImplementedError(
                        "{} is unsupported data type".format(field)
                    )

        self.order = requested

    def __getitem__(self, index):
        """Must be provided by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Must be provided by subclasses."""
        raise NotImplementedError
| @@ -0,0 +1,197 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import gzip | |||
| import os | |||
| import struct | |||
| from typing import Tuple | |||
| import numpy as np | |||
| from tqdm import tqdm | |||
| from ....logger import get_logger | |||
| from .meta_vision import VisionDataset | |||
| from .utils import _default_dataset_root, load_raw_data_from_url | |||
| logger = get_logger(__name__) | |||
class MNIST(VisionDataset):
    r""" ``Dataset`` for MNIST meta data
    """

    url_path = "http://yann.lecun.com/exdb/mnist/"
    """
    url prefix for downloading raw file
    """
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    """
    raw file names of both training set and test set (10k)
    """
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]
    """
    md5 for checking raw files
    """

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        r"""
        :param root: path for mnist dataset downloading or loading, if ``None``,
            set ``root`` to the ``_default_root``
        :param train: if ``True``, loading trainingset, else loading test set
        :param download: if raw files do not exists and download sets to ``True``,
            download raw files and process, otherwise raise ValueError, default is True
        :param timeout: passed to ``load_raw_data_from_url`` for every downloaded file
        :raises ValueError: if ``root`` does not exist or lacks the raw files
            while ``download`` is ``False``
        """
        super().__init__(root, order=("image", "image_category"))

        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            # expand "~" here as well: storing the raw argument would make a
            # root such as "~/mnist" refer to a literal "~" directory
            self.root = os.path.expanduser(root)
            if not os.path.exists(self.root):
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    os.makedirs(self.root)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        if self._check_raw_files():
            self.process(train)
        elif download:
            self.download()
            self.process(train)
        else:
            raise ValueError(
                "root does not contain valid raw files, please set download=True"
            )

    def __getitem__(self, index: int) -> Tuple:
        """Return the ``index``-th entry of every array in ``self.arrays`` as a tuple."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the length of the first array in ``self.arrays``."""
        return len(self.arrays[0])

    @property
    def _default_root(self):
        """Default dataset directory: ``<default dataset root>/<class name>``."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        """Header meta data of the image and label files, as stored by :meth:`process`."""
        return self._meta_data

    def _check_raw_files(self):
        """Return ``True`` only if every file in ``raw_file_name`` exists under ``root``."""
        return all(
            os.path.exists(os.path.join(self.root, path))
            for path in self.raw_file_name
        )

    def download(self):
        """Fetch every raw file into ``root`` through ``load_raw_data_from_url``."""
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)

    def process(self, train):
        """Parse the raw files of the chosen split into ``self.arrays`` and ``self._meta_data``.

        :param train: ``True`` selects the training files, ``False`` the test files.
        """
        # load raw files and transform them into meta data and datasets Tuple(np.array)
        logger.info("process the raw files of %s set...", "train" if train else "test")
        # raw_file_name is ordered as: train images, train labels, test images, test labels
        image_idx, label_idx = (0, 1) if train else (2, 3)
        meta_data_images, images = parse_idx3(
            os.path.join(self.root, self.raw_file_name[image_idx])
        )
        meta_data_labels, labels = parse_idx1(
            os.path.join(self.root, self.raw_file_name[label_idx])
        )

        self._meta_data = {
            "images": meta_data_images,
            "labels": meta_data_labels,
        }
        self.arrays = (images, labels.astype(np.int32))
def parse_idx3(idx3_file):
    """Read a gzip-compressed idx3 file into header meta data and images.

    :param idx3_file: path of the file to parse; must end with ``.gz``.
    :return: ``(meta_data, images)``. ``meta_data`` holds the four header
        integers under ``magic``, ``imgs``, ``height`` and ``width``;
        ``images`` is a list of ``uint8`` arrays shaped ``(height, width, 1)``.
    """
    logger.debug("parse idx3 file %s ...", idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as stream:
        raw = stream.read()

    # header: four big-endian 32-bit integers at the start of the buffer
    header_fmt = ">iiii"
    magic, imgs, height, width = struct.unpack_from(header_fmt, raw, 0)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # payload: height * width unsigned bytes per image, right after the header
    payload = raw[struct.calcsize(header_fmt):]
    pixel_fmt = ">" + str(height * width) + "B"

    images = []
    progress = tqdm(total=meta_data["imgs"], ncols=80)
    for pixels in struct.iter_unpack(pixel_fmt, payload):
        frame = np.array(pixels, dtype=np.uint8)
        images.append(frame.reshape((height, width, 1)))
        progress.update()
    progress.close()
    return meta_data, images
def parse_idx1(idx1_file):
    """Read a gzip-compressed idx1 file into header meta data and labels.

    :param idx1_file: path of the file to parse; must end with ``.gz``.
    :return: ``(meta_data, labels)``. ``meta_data`` holds the two header
        integers under ``magic`` and ``imgs``; ``labels`` is an integer
        array with ``imgs`` entries.
    """
    logger.debug("parse idx1 file %s ...", idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as stream:
        raw = stream.read()

    # header: two big-endian 32-bit integers at the start of the buffer
    header_fmt = ">ii"
    magic, imgs = struct.unpack_from(header_fmt, raw, 0)
    meta_data = {"magic": magic, "imgs": imgs}

    # payload: one unsigned byte per label, right after the header
    payload = raw[struct.calcsize(header_fmt):]
    labels = np.empty(imgs, dtype=int)
    progress = tqdm(total=meta_data["imgs"], ncols=80)
    position = 0
    for (value,) in struct.iter_unpack(">B", payload):
        labels[position] = value
        position += 1
        progress.update()
    progress.close()
    return meta_data, labels
| @@ -0,0 +1,498 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # --------------------------------------------------------------------- | |||
| # Part of the following code in this file refs to maskrcnn-benchmark | |||
| # MIT License | |||
| # | |||
| # Copyright (c) 2018 Facebook | |||
| # --------------------------------------------------------------------- | |||
| import json | |||
| import os | |||
| from collections import defaultdict | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| class Objects365(VisionDataset): | |||
| r"""`Objects365 <https://www.objects365.org/overview.html>`_ Dataset. | |||
| """ | |||
| supported_order = ( | |||
| "image", | |||
| "boxes", | |||
| "boxes_category", | |||
| "info", | |||
| ) | |||
def __init__(
    self, root, ann_file, remove_images_without_annotations=False, *, order=None
):
    """Index the images, annotations and categories of an annotation file.

    :param root: directory that image ``file_name`` entries are joined with
    :param ann_file: path of a JSON file providing the ``images``,
        ``annotations`` and ``categories`` lists
    :param remove_images_without_annotations: if true, drop crowd
        annotations (``iscrowd != 0``) and boxes whose width or height is
        not positive, then drop every image left without annotations
    :param order: field names an item is made of, validated against
        ``supported_order`` by the base class
    """
    super().__init__(root, order=order, supported_order=self.supported_order)

    with open(ann_file, "r") as f:
        dataset = json.load(f)

    self.imgs = dict()
    for img in dataset["images"]:
        self.imgs[img["id"]] = img

    self.img_to_anns = defaultdict(list)
    for ann in dataset["annotations"]:
        self.img_to_anns[ann["image_id"]].append(ann)

    self.cats = dict()
    for cat in dataset["categories"]:
        self.cats[cat["id"]] = cat

    self.ids = list(sorted(self.imgs.keys()))

    # filter images without detection annotations
    if remove_images_without_annotations:
        ids = []
        for img_id in self.ids:
            # indexing the defaultdict creates an empty entry for images
            # without annotations, so the ``del`` below always succeeds
            anno = self.img_to_anns[img_id]
            # filter crowd annotations
            anno = [obj for obj in anno if obj["iscrowd"] == 0]
            anno = [
                obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
            ]
            if len(anno) > 0:
                ids.append(img_id)
                self.img_to_anns[img_id] = anno
            else:
                del self.imgs[img_id]
                del self.img_to_anns[img_id]
        self.ids = ids

    # for saving memory: boxes are dropped only after the filter above has
    # read them, otherwise filtering would fail with KeyError on "bbox"
    if "boxes" not in self.order and "boxes_category" not in self.order:
        for anns in self.img_to_anns.values():
            for ann in anns:
                ann.pop("bbox", None)

    # category ids of the file are remapped to contiguous ids starting at 1
    self.json_category_id_to_contiguous_id = {
        v: i + 1 for i, v in enumerate(self.cats.keys())
    }

    self.contiguous_category_id_to_json_id = {
        v: k for k, v in self.json_category_id_to_contiguous_id.items()
    }
| def __getitem__(self, index): | |||
| img_id = self.ids[index] | |||
| anno = self.img_to_anns[img_id] | |||
| target = [] | |||
| for k in self.order: | |||
| if k == "image": | |||
| file_name = self.imgs[img_id]["file_name"] | |||
| path = os.path.join(self.root, file_name) | |||
| image = cv2.imread(path, cv2.IMREAD_COLOR) | |||
| target.append(image) | |||
| elif k == "boxes": | |||
| boxes = [obj["bbox"] for obj in anno] | |||
| boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||
| # transfer boxes from xywh to xyxy | |||
| boxes[:, 2:] += boxes[:, :2] | |||
| target.append(boxes) | |||
| elif k == "boxes_category": | |||
| boxes_category = [obj["category_id"] for obj in anno] | |||
| boxes_category = [ | |||
| self.json_category_id_to_contiguous_id[c] for c in boxes_category | |||
| ] | |||
| boxes_category = np.array(boxes_category, dtype=np.int32) | |||
| target.append(boxes_category) | |||
| elif k == "info": | |||
| info = self.imgs[img_id] | |||
| info = [info["height"], info["width"], info["file_name"]] | |||
| target.append(info) | |||
| else: | |||
| raise NotImplementedError | |||
| return tuple(target) | |||
| def __len__(self): | |||
| return len(self.ids) | |||
def get_img_info(self, index):
    """Return the image record from ``self.imgs`` for the id at position ``index`` of ``self.ids``."""
    return self.imgs[self.ids[index]]
| class_names = ( | |||
| "person", | |||
| "sneakers", | |||
| "chair", | |||
| "hat", | |||
| "lamp", | |||
| "bottle", | |||
| "cabinet/shelf", | |||
| "cup", | |||
| "car", | |||
| "glasses", | |||
| "picture/frame", | |||
| "desk", | |||
| "handbag", | |||
| "street lights", | |||
| "book", | |||
| "plate", | |||
| "helmet", | |||
| "leather shoes", | |||
| "pillow", | |||
| "glove", | |||
| "potted plant", | |||
| "bracelet", | |||
| "flower", | |||
| "tv", | |||
| "storage box", | |||
| "vase", | |||
| "bench", | |||
| "wine glass", | |||
| "boots", | |||
| "bowl", | |||
| "dining table", | |||
| "umbrella", | |||
| "boat", | |||
| "flag", | |||
| "speaker", | |||
| "trash bin/can", | |||
| "stool", | |||
| "backpack", | |||
| "couch", | |||
| "belt", | |||
| "carpet", | |||
| "basket", | |||
| "towel/napkin", | |||
| "slippers", | |||
| "barrel/bucket", | |||
| "coffee table", | |||
| "suv", | |||
| "toy", | |||
| "tie", | |||
| "bed", | |||
| "traffic light", | |||
| "pen/pencil", | |||
| "microphone", | |||
| "sandals", | |||
| "canned", | |||
| "necklace", | |||
| "mirror", | |||
| "faucet", | |||
| "bicycle", | |||
| "bread", | |||
| "high heels", | |||
| "ring", | |||
| "van", | |||
| "watch", | |||
| "sink", | |||
| "horse", | |||
| "fish", | |||
| "apple", | |||
| "camera", | |||
| "candle", | |||
| "teddy bear", | |||
| "cake", | |||
| "motorcycle", | |||
| "wild bird", | |||
| "laptop", | |||
| "knife", | |||
| "traffic sign", | |||
| "cell phone", | |||
| "paddle", | |||
| "truck", | |||
| "cow", | |||
| "power outlet", | |||
| "clock", | |||
| "drum", | |||
| "fork", | |||
| "bus", | |||
| "hanger", | |||
| "nightstand", | |||
| "pot/pan", | |||
| "sheep", | |||
| "guitar", | |||
| "traffic cone", | |||
| "tea pot", | |||
| "keyboard", | |||
| "tripod", | |||
| "hockey", | |||
| "fan", | |||
| "dog", | |||
| "spoon", | |||
| "blackboard/whiteboard", | |||
| "balloon", | |||
| "air conditioner", | |||
| "cymbal", | |||
| "mouse", | |||
| "telephone", | |||
| "pickup truck", | |||
| "orange", | |||
| "banana", | |||
| "airplane", | |||
| "luggage", | |||
| "skis", | |||
| "soccer", | |||
| "trolley", | |||
| "oven", | |||
| "remote", | |||
| "baseball glove", | |||
| "paper towel", | |||
| "refrigerator", | |||
| "train", | |||
| "tomato", | |||
| "machinery vehicle", | |||
| "tent", | |||
| "shampoo/shower gel", | |||
| "head phone", | |||
| "lantern", | |||
| "donut", | |||
| "cleaning products", | |||
| "sailboat", | |||
| "tangerine", | |||
| "pizza", | |||
| "kite", | |||
| "computer box", | |||
| "elephant", | |||
| "toiletries", | |||
| "gas stove", | |||
| "broccoli", | |||
| "toilet", | |||
| "stroller", | |||
| "shovel", | |||
| "baseball bat", | |||
| "microwave", | |||
| "skateboard", | |||
| "surfboard", | |||
| "surveillance camera", | |||
| "gun", | |||
| "life saver", | |||
| "cat", | |||
| "lemon", | |||
| "liquid soap", | |||
| "zebra", | |||
| "duck", | |||
| "sports car", | |||
| "giraffe", | |||
| "pumpkin", | |||
| "piano", | |||
| "stop sign", | |||
| "radiator", | |||
| "converter", | |||
| "tissue ", | |||
| "carrot", | |||
| "washing machine", | |||
| "vent", | |||
| "cookies", | |||
| "cutting/chopping board", | |||
| "tennis racket", | |||
| "candy", | |||
| "skating and skiing shoes", | |||
| "scissors", | |||
| "folder", | |||
| "baseball", | |||
| "strawberry", | |||
| "bow tie", | |||
| "pigeon", | |||
| "pepper", | |||
| "coffee machine", | |||
| "bathtub", | |||
| "snowboard", | |||
| "suitcase", | |||
| "grapes", | |||
| "ladder", | |||
| "pear", | |||
| "american football", | |||
| "basketball", | |||
| "potato", | |||
| "paint brush", | |||
| "printer", | |||
| "billiards", | |||
| "fire hydrant", | |||
| "goose", | |||
| "projector", | |||
| "sausage", | |||
| "fire extinguisher", | |||
| "extension cord", | |||
| "facial mask", | |||
| "tennis ball", | |||
| "chopsticks", | |||
| "electronic stove and gas stove", | |||
| "pie", | |||
| "frisbee", | |||
| "kettle", | |||
| "hamburger", | |||
| "golf club", | |||
| "cucumber", | |||
| "clutch", | |||
| "blender", | |||
| "tong", | |||
| "slide", | |||
| "hot dog", | |||
| "toothbrush", | |||
| "facial cleanser", | |||
| "mango", | |||
| "deer", | |||
| "egg", | |||
| "violin", | |||
| "marker", | |||
| "ship", | |||
| "chicken", | |||
| "onion", | |||
| "ice cream", | |||
| "tape", | |||
| "wheelchair", | |||
| "plum", | |||
| "bar soap", | |||
| "scale", | |||
| "watermelon", | |||
| "cabbage", | |||
| "router/modem", | |||
| "golf ball", | |||
| "pine apple", | |||
| "crane", | |||
| "fire truck", | |||
| "peach", | |||
| "cello", | |||
| "notepaper", | |||
| "tricycle", | |||
| "toaster", | |||
| "helicopter", | |||
| "green beans", | |||
| "brush", | |||
| "carriage", | |||
| "cigar", | |||
| "earphone", | |||
| "penguin", | |||
| "hurdle", | |||
| "swing", | |||
| "radio", | |||
| "CD", | |||
| "parking meter", | |||
| "swan", | |||
| "garlic", | |||
| "french fries", | |||
| "horn", | |||
| "avocado", | |||
| "saxophone", | |||
| "trumpet", | |||
| "sandwich", | |||
| "cue", | |||
| "kiwi fruit", | |||
| "bear", | |||
| "fishing rod", | |||
| "cherry", | |||
| "tablet", | |||
| "green vegetables", | |||
| "nuts", | |||
| "corn", | |||
| "key", | |||
| "screwdriver", | |||
| "globe", | |||
| "broom", | |||
| "pliers", | |||
| "volleyball", | |||
| "hammer", | |||
| "eggplant", | |||
| "trophy", | |||
| "dates", | |||
| "board eraser", | |||
| "rice", | |||
| "tape measure/ruler", | |||
| "dumbbell", | |||
| "hamimelon", | |||
| "stapler", | |||
| "camel", | |||
| "lettuce", | |||
| "goldfish", | |||
| "meat balls", | |||
| "medal", | |||
| "toothpaste", | |||
| "antelope", | |||
| "shrimp", | |||
| "rickshaw", | |||
| "trombone", | |||
| "pomegranate", | |||
| "coconut", | |||
| "jellyfish", | |||
| "mushroom", | |||
| "calculator", | |||
| "treadmill", | |||
| "butterfly", | |||
| "egg tart", | |||
| "cheese", | |||
| "pig", | |||
| "pomelo", | |||
| "race car", | |||
| "rice cooker", | |||
| "tuba", | |||
| "crosswalk sign", | |||
| "papaya", | |||
| "hair drier", | |||
| "green onion", | |||
| "chips", | |||
| "dolphin", | |||
| "sushi", | |||
| "urinal", | |||
| "donkey", | |||
| "electric drill", | |||
| "spring rolls", | |||
| "tortoise/turtle", | |||
| "parrot", | |||
| "flute", | |||
| "measuring cup", | |||
| "shark", | |||
| "steak", | |||
| "poker card", | |||
| "binoculars", | |||
| "llama", | |||
| "radish", | |||
| "noodles", | |||
| "yak", | |||
| "mop", | |||
| "crab", | |||
| "microscope", | |||
| "barbell", | |||
| "bread/bun", | |||
| "baozi", | |||
| "lion", | |||
| "red cabbage", | |||
| "polar bear", | |||
| "lighter", | |||
| "seal", | |||
| "mangosteen", | |||
| "comb", | |||
| "eraser", | |||
| "pitaya", | |||
| "scallop", | |||
| "pencil case", | |||
| "saw", | |||
| "table tennis paddle", | |||
| "okra", | |||
| "starfish", | |||
| "eagle", | |||
| "monkey", | |||
| "durian", | |||
| "game board", | |||
| "rabbit", | |||
| "french horn", | |||
| "ambulance", | |||
| "asparagus", | |||
| "hoverboard", | |||
| "pasta", | |||
| "target", | |||
| "hotair balloon", | |||
| "chainsaw", | |||
| "lobster", | |||
| "iron", | |||
| "flashlight", | |||
| ) | |||
| @@ -0,0 +1,89 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import hashlib | |||
| import os | |||
| import tarfile | |||
| from ....distributed.group import is_distributed | |||
| from ....logger import get_logger | |||
| from ....utils.http_download import download_from_url | |||
| IMG_EXT = (".jpg", ".png", ".jpeg", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") | |||
| logger = get_logger(__name__) | |||
| def _default_dataset_root(): | |||
| default_dataset_root = os.path.expanduser( | |||
| os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "megengine") | |||
| ) | |||
| return default_dataset_root | |||
def load_raw_data_from_url(
    url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int
):
    """Make sure ``raw_data_dir/filename`` exists and matches ``target_md5``.

    A file that is already present is only checksummed; otherwise it is
    fetched with ``download_from_url``. On a checksum mismatch the file is
    deleted.

    :param url: address the file is downloaded from when it is missing
    :param filename: file name inside ``raw_data_dir``
    :param target_md5: expected md5 hex digest
    :param raw_data_dir: directory holding the raw files
    :param timeout: passed to ``download_from_url`` as ``http_read_timeout``
    :raises RuntimeError: if the md5 of the file differs from ``target_md5``
    """
    cached_file = os.path.join(raw_data_dir, filename)
    logger.debug(
        "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
    )

    if os.path.exists(cached_file):
        md5 = calculate_md5(cached_file)
    else:
        if is_distributed():
            logger.warning(
                "Downloading raw data in DISTRIBUTED mode\n"
                "    File may be downloaded multiple times. We recommend\n"
                "    users to download in single process first."
            )
        md5 = download_from_url(url, cached_file, http_read_timeout=timeout)

    if target_md5 != md5:
        os.remove(cached_file)
        raise RuntimeError("{} exists but fail to match md5".format(filename))

    logger.debug("%s exists with correct md5: %s", filename, target_md5)
def calculate_md5(filename):
    """Return the md5 hex digest of the file at ``filename``.

    The file is read in 4096-byte chunks, so it is never loaded as a whole.
    """
    digest = hashlib.md5()
    with open(filename, "rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF
        for chunk in iter(lambda: stream.read(4096), b""):
            digest.update(chunk)
    return digest.hexdigest()
def is_img(filename):
    """Return whether ``filename`` ends with one of the extensions in ``IMG_EXT``, ignoring case."""
    lowered = filename.lower()
    return lowered.endswith(IMG_EXT)
def untar(path, to=None, remove=False):
    """Extract the tar archive at ``path``.

    :param path: archive to extract
    :param to: destination directory; defaults to the directory containing ``path``
    :param remove: if true, delete the archive after extraction
    """
    destination = os.path.dirname(path) if to is None else to
    with tarfile.open(path, "r") as archive:
        archive.extractall(path=destination)
    if remove:
        os.remove(path)
def untargz(path, to=None, remove=False):
    """Extract the gzip-compressed tar archive at ``path``.

    :param path: archive to extract; its name must end with ``.tar.gz``
    :param to: destination directory; defaults to the directory containing ``path``
    :param remove: if true, delete the archive after extraction
    :raises ValueError: if ``path`` does not end with ``.tar.gz``
    """
    if not path.endswith(".tar.gz"):
        # the message names the suffix that is actually checked
        raise ValueError("path %s does not end with .tar.gz" % path)
    if to is None:
        to = os.path.dirname(path)
    with tarfile.open(path, "r:gz") as tar:
        tar.extractall(path=to)
    if remove:
        os.remove(path)
| @@ -0,0 +1,195 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # --------------------------------------------------------------------- | |||
| # Part of the following code in this file refs to torchvision | |||
| # BSD 3-Clause License | |||
| # | |||
| # Copyright (c) Soumith Chintala 2016, | |||
| # All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| import collections.abc | |||
| import os | |||
| import xml.etree.ElementTree as ET | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| class PascalVOC(VisionDataset): | |||
| r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset. | |||
| """ | |||
| supported_order = ( | |||
| "image", | |||
| "boxes", | |||
| "boxes_category", | |||
| "mask", | |||
| "info", | |||
| ) | |||
def __init__(self, root, image_set, *, order=None):
    """Collect the image and annotation (or mask) paths of one image set.

    :param root: dataset directory; it must contain ``JPEGImages`` plus
        either ``Annotations`` and ``ImageSets/Main`` (detection) or a
        ``SegmentationClass``/``SegmentationClass_aug`` folder and
        ``ImageSets/Segmentation`` (segmentation)
    :param image_set: name of the split file, without the ``.txt`` suffix
    :param order: field names an item is made of; it has to contain
        ``boxes`` and/or ``boxes_category``, or ``mask``
    :raises ValueError: if ``order`` mixes box fields with ``mask``
    :raises RuntimeError: if ``root`` is not a directory
    :raises NotImplementedError: if ``order`` requests neither boxes nor masks
    """
    # resolve the default before the membership tests below: testing
    # ``"boxes" in None`` would fail with an unrelated TypeError
    if order is None:
        order = ("image",)

    wants_boxes = "boxes" in order or "boxes_category" in order
    wants_mask = "mask" in order
    if wants_boxes and wants_mask:
        raise ValueError(
            "PascalVOC only supports boxes & boxes_category or mask, not both."
        )

    super().__init__(root, order=order, supported_order=self.supported_order)

    if not os.path.isdir(self.root):
        raise RuntimeError("Dataset not found or corrupted.")

    self.image_set = image_set
    image_dir = os.path.join(self.root, "JPEGImages")

    if wants_boxes:
        annotation_dir = os.path.join(self.root, "Annotations")
        splitdet_dir = os.path.join(self.root, "ImageSets/Main")
        split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt")
        with open(split_f, "r") as f:
            self.file_names = [x.strip() for x in f.readlines()]
        self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
        self.annotations = [
            os.path.join(annotation_dir, x + ".xml") for x in self.file_names
        ]
        assert len(self.images) == len(self.annotations)
    elif wants_mask:
        # image sets whose name contains "aug" read masks from the "_aug" folder
        if "aug" in image_set:
            mask_dir = os.path.join(self.root, "SegmentationClass_aug")
        else:
            mask_dir = os.path.join(self.root, "SegmentationClass")
        splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation")
        split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt")
        with open(split_f, "r") as f:
            self.file_names = [x.strip() for x in f.readlines()]
        self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
        assert len(self.images) == len(self.masks)
    else:
        raise NotImplementedError(
            "order must contain 'boxes', 'boxes_category' or 'mask'"
        )
| def __getitem__(self, index): | |||
| target = [] | |||
| for k in self.order: | |||
| if k == "image": | |||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
| target.append(image) | |||
| elif k == "boxes": | |||
| anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot()) | |||
| boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]] | |||
| # boxes type xyxy | |||
| boxes = [ | |||
| (bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"]) for bb in boxes | |||
| ] | |||
| boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||
| target.append(boxes) | |||
| elif k == "boxes_category": | |||
| anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot()) | |||
| boxes_category = [obj["name"] for obj in anno["annotation"]["object"]] | |||
| boxes_category = [ | |||
| self.class_names.index(bc) + 1 for bc in boxes_category | |||
| ] | |||
| boxes_category = np.array(boxes_category, dtype=np.int32) | |||
| target.append(boxes_category) | |||
| elif k == "mask": | |||
| if "aug" in self.image_set: | |||
| mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE) | |||
| else: | |||
| mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR) | |||
| mask = self._trans_mask(mask) | |||
| mask = mask[:, :, np.newaxis] | |||
| target.append(mask) | |||
| elif k == "info": | |||
| if image is None: | |||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
| info = [image.shape[0], image.shape[1], self.file_names[index]] | |||
| target.append(info) | |||
| else: | |||
| raise NotImplementedError | |||
| return tuple(target) | |||
| def __len__(self): | |||
| return len(self.images) | |||
| def _trans_mask(self, mask): | |||
| label = np.ones(mask.shape[:2]) * 255 | |||
| for i in range(len(self.class_colors)): | |||
| b, g, r = self.class_colors[i] | |||
| label[ | |||
| (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r) | |||
| ] = i | |||
| return label.astype(np.uint8) | |||
| def parse_voc_xml(self, node): | |||
| voc_dict = {} | |||
| children = list(node) | |||
| if children: | |||
| def_dic = collections.defaultdict(list) | |||
| for dc in map(self.parse_voc_xml, children): | |||
| for ind, v in dc.items(): | |||
| def_dic[ind].append(v) | |||
| if node.tag == "annotation": | |||
| def_dic["object"] = [def_dic["object"]] | |||
| voc_dict = { | |||
| node.tag: { | |||
| ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items() | |||
| } | |||
| } | |||
| if node.text: | |||
| text = node.text.strip() | |||
| if not children: | |||
| voc_dict[node.tag] = text | |||
| return voc_dict | |||
| class_names = ( | |||
| "aeroplane", | |||
| "bicycle", | |||
| "bird", | |||
| "boat", | |||
| "bottle", | |||
| "bus", | |||
| "car", | |||
| "cat", | |||
| "chair", | |||
| "cow", | |||
| "diningtable", | |||
| "dog", | |||
| "horse", | |||
| "motorbike", | |||
| "person", | |||
| "pottedplant", | |||
| "sheep", | |||
| "sofa", | |||
| "train", | |||
| "tvmonitor", | |||
| ) | |||
| class_colors = [ | |||
| [0, 0, 128], | |||
| [0, 128, 0], | |||
| [0, 128, 128], | |||
| [128, 0, 0], | |||
| [128, 0, 128], | |||
| [128, 128, 0], | |||
| [128, 128, 128], | |||
| [0, 0, 64], | |||
| [0, 0, 192], | |||
| [0, 128, 64], | |||
| [0, 128, 192], | |||
| [128, 0, 64], | |||
| [128, 0, 192], | |||
| [128, 128, 64], | |||
| [128, 128, 192], | |||
| [0, 64, 0], | |||
| [0, 64, 128], | |||
| [0, 192, 0], | |||
| [0, 192, 128], | |||
| [128, 64, 0], | |||
| ] | |||
| @@ -0,0 +1,274 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections.abc | |||
| import math | |||
| from abc import ABC | |||
| from typing import Any, Generator, Iterator, List, Union | |||
| import numpy as np | |||
| import megengine.distributed as dist | |||
class Sampler(ABC):
    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        r"""
        An abstract class for all samplers

        :type dataset: `dataset`
        :param dataset: dataset to sample from
        :type batch_size: positive integer
        :param batch_size: batch size for batch method
        :type drop_last: bool
        :param drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. (default: ``False``)
        :type num_samples: positive integer
        :param num_samples: total number of samples; it is divided (rounded up)
            by ``world_size`` to get the number of samples assigned to one rank.
            Defaults to ``len(dataset)``
        :type world_size: positive integer
        :param world_size: number of ranks
        :type rank: non-negative integer within 0 and world_size
        :param rank: rank id, non-negative integer within 0 and ``world_size``
        :type seed: non-negative integer
        :param seed: seed for random operators
        :raises ValueError: if ``batch_size``, ``drop_last`` or ``num_samples``
            has the wrong type or range
        """
        # bool is a subclass of int, so it has to be rejected explicitly
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                "drop_last should be a boolean value, but got "
                "drop_last={}".format(drop_last)
            )
        if num_samples is not None and (
            not isinstance(num_samples, int)
            or isinstance(num_samples, bool)
            or num_samples <= 0
        ):
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(num_samples)
            )

        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last

        if world_size is None:
            world_size = dist.get_world_size() if dist.is_distributed() else 1
        self.world_size = world_size
        if rank is None:
            rank = dist.get_rank() if dist.is_distributed() else 0
        self.rank = rank

        if num_samples is None:
            num_samples = len(self.dataset)
        self.num_samples = int(math.ceil(num_samples / self.world_size))

        # Make sure seeds are the same at each rank
        if seed is None and self.world_size > 1:
            seed = 0
        self.rng = np.random.RandomState(seed)

    def __iter__(self) -> Union[Generator, Iterator]:
        return self.batch()

    def __len__(self) -> int:
        """Number of batches produced per rank."""
        if self.drop_last:
            return self.num_samples // self.batch_size
        else:
            return int(math.ceil(self.num_samples / self.batch_size))

    def sample(self):
        """
        return a list contains all sample indices
        """
        raise NotImplementedError

    def scatter(self, indices) -> List:
        r"""
        scatter method is used for splitting indices into subsets, each subset
        will be assigned to a rank. Indices are evenly split by default.
        If a customized indices assignment method is needed, please rewrite this method.

        The input is not modified; missing entries are filled by cycling through
        ``indices`` from the start until ``num_samples * world_size`` is reached.
        """
        total_size = self.num_samples * self.world_size

        # work on a copy so the caller's list is left untouched
        indices = list(indices)

        # add extra indices to make it evenly divisible; repeat the sequence as
        # often as needed so a shortfall larger than len(indices) is covered too
        shortfall = total_size - len(indices)
        if shortfall > 0 and indices:
            repeats = -(-shortfall // len(indices))  # ceiling division
            indices = indices + (indices * repeats)[:shortfall]
        assert len(indices) == total_size

        # subsample
        indices = indices[self.rank : total_size : self.world_size]
        assert len(indices) == self.num_samples
        return indices

    def batch(self) -> Iterator[List[Any]]:
        r"""
        batch method provides a batch indices generator
        """
        indices = list(self.sample())

        # user might pass the world_size parameter without dist,
        # so dist.is_distributed() should not be used
        if self.world_size > 1:
            indices = self.scatter(indices)

        step, length = self.batch_size, len(indices)
        batch_index = [indices[i : i + step] for i in range(0, length, step)]

        # guard against an empty batch list before looking at the last batch
        if self.drop_last and batch_index and len(batch_index[-1]) < self.batch_size:
            batch_index.pop()

        return iter(batch_index)
class SequentialSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
    ):
        r"""
        Sample elements in order.

        :param indices: optional sequence of indices to yield instead of
            ``0 .. len(dataset) - 1``
        :raises ValueError: if ``indices`` is neither ``None`` nor a sequence
        """
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
        acceptable = indices is None or isinstance(indices, collections.abc.Sequence)
        if not acceptable:
            message = "indices should be None or a sequence, but got indices={}"
            raise ValueError(message.format(indices))
        self.indices = indices

    def sample(self) -> Iterator[Any]:
        r"""
        Return the user-supplied indices, or an iterator over the dataset positions.
        """
        if self.indices is not None:
            return self.indices
        dataset_size = len(self.dataset)
        return iter(range(dataset_size))
class RandomSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        r"""
        Draw every element once, in random order (no replacement).

        :param indices: optional sequence of indices to shuffle instead of
            ``0 .. len(dataset) - 1``
        :raises ValueError: if ``indices`` is neither ``None`` nor a sequence
        """
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
        acceptable = indices is None or isinstance(indices, collections.abc.Sequence)
        if not acceptable:
            message = "indices should be None or a sequence, but got indices={}"
            raise ValueError(message.format(indices))
        self.indices = indices

    def sample(self) -> List:
        """Return a shuffled list of indices drawn with the sampler's own RNG."""
        # permutation() accepts either a length or the sequence to shuffle
        population = len(self.dataset) if self.indices is None else self.indices
        shuffled = self.rng.permutation(population)
        return shuffled.tolist()
class ReplacementSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        weights=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        r"""
        Sample elements randomly with replacement

        :type weights: List
        :param weights: weights for sampling indices, it could be unnormalized weights
        :raises ValueError: if ``weights`` is not a sequence, its length differs
            from ``len(dataset)``, or its sum is not positive
        """
        super().__init__(
            dataset, batch_size, drop_last, num_samples, world_size, rank, seed
        )
        if weights is not None:
            if not isinstance(weights, collections.abc.Sequence):
                raise ValueError(
                    "weights should be None or a sequence, "
                    "but got weights={}".format(weights)
                )
            if len(weights) != len(dataset):
                raise ValueError(
                    "len(dataset)={} should be equal to "
                    "len(weights)={}".format(len(dataset), len(weights))
                )
            total = sum(weights)
            if total <= 0:
                # normalizing by a non-positive sum cannot give probabilities
                raise ValueError(
                    "weights should have a positive sum, "
                    "but got sum(weights)={}".format(total)
                )
            # normalize so the weights can be used as probabilities
            weights = np.array(weights) / total
        self.weights = weights

    def sample(self) -> List:
        """Return ``self.num_samples`` dataset indices drawn with replacement."""
        n = len(self.dataset)
        if self.weights is None:
            return self.rng.randint(n, size=self.num_samples).tolist()
        else:
            # choice() returns indices distributed according to ``p``;
            # multinomial() would return per-index count vectors instead.
            return self.rng.choice(
                n, size=self.num_samples, replace=True, p=self.weights
            ).tolist()
class Infinite(Sampler):
    r"""Wrapper that turns a basic sampler into an endless stream of batches."""

    def sample(self):
        raise NotImplementedError("sample method not supported in Infinite")

    def __init__(self, sampler):
        # The base-class initializer is deliberately not called here; only the
        # wrapped sampler and its current iterator are stored.
        self.sampler = sampler
        self.sampler_iter = iter(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.sampler_iter)
        except StopIteration:
            # wrapped sampler is exhausted: restart it and hand out its first item
            self.sampler_iter = iter(self.sampler)
            return next(self.sampler_iter)

    def __len__(self):
        # largest int64 value stands in for "no end"
        return np.iinfo(np.int64).max
| @@ -0,0 +1,10 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .meta_transform import PseudoTransform, Transform | |||
| from .vision import * | |||
| @@ -0,0 +1,31 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Sequence, Tuple | |||
class Transform(ABC):
    """
    Base class for transforms; subclasses must implement :meth:`apply`.
    """

    def apply_batch(self, inputs: Sequence[Tuple]):
        """Apply :meth:`apply` to every element of ``inputs`` and return a tuple."""
        results = []
        for sample in inputs:
            results.append(self.apply(sample))
        return tuple(results)

    @abstractmethod
    def apply(self, input: Tuple):
        """Transform a single sample."""

    def __repr__(self):
        return type(self).__name__
class PseudoTransform(Transform):
    """Identity transform: samples pass through unchanged."""

    def apply(self, input: Tuple):
        unchanged = input
        return unchanged
| @@ -0,0 +1,9 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .transform import * | |||
| @@ -0,0 +1,111 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections.abc | |||
| import functools | |||
| import random | |||
| import cv2 | |||
| import numpy as np | |||
def wrap_keepdims(func):
    """Decorator that keeps the image rank at three.

    The wrapped function must receive a 3-dim image; if it returns a 2-dim
    array, a trailing axis is appended to the result.
    """

    @functools.wraps(func)
    def wrapper(image, *args, **kwargs):
        ndim = len(image.shape)
        if ndim != 3:
            raise ValueError("image must have 3 dims, but got {} dims".format(ndim))
        result = func(image, *args, **kwargs)
        # restore the trailing axis when the wrapped function dropped it
        return result[:, :, np.newaxis] if len(result.shape) == 2 else result

    return wrapper
@wrap_keepdims
def to_gray(image):
    r"""
    Convert a BGR image to grayscale.

    :param image: Input BGR format image, with (H, W, C) shape
    :return: Gray format image, with (H, W, C) shape
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return gray
@wrap_keepdims
def to_bgr(image):
    r"""
    Convert a grayscale image to BGR.

    :param image: input Gray format image, with (H, W, C) shape
    :return: BGR format image, with (H, W, C) shape
    """
    bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    return bgr
@wrap_keepdims
def pad(input, size, value):
    r"""
    Pad input data with a constant *value* border of the given *size*.

    :param input: Input data, with (H, W, C) shape
    :param size: Padding size; an integer or a sequence.
        An integer pads all four sides by that amount.
        A sequence of two integers pads only the bottom and right sides.
        A sequence of four integers gives the top, bottom, left and right
        padding respectively.
    :param value: Padding value, may be a sequence of int or float.
        If it is a float value, the input is cast to float32 as well.
    :return: Padded image
    """
    if isinstance(size, int):
        size = (size,) * 4
    elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
        bottom, right = size[0], size[1]
        size = (0, bottom, 0, right)
    fill_is_float = np.array(value).dtype == float
    if fill_is_float:
        input = input.astype(np.float32)
    return cv2.copyMakeBorder(input, *size, cv2.BORDER_CONSTANT, value=value)
@wrap_keepdims
def flip(image, flipCode):
    r"""
    Flip the input image according to ``flipCode``.

    :param image: Input image, with (H, W, C) shape
    :param flipCode: code that indicates the type of flip.
        1 : Flip horizontally
        0 : Flip vertically
        -1 : Flip horizontally and vertically
    :return: BGR format image, with (H, W, C) shape
    """
    flipped = cv2.flip(image, flipCode=flipCode)
    return flipped
@wrap_keepdims
def resize(input, size, interpolation=cv2.INTER_LINEAR):
    r"""
    Resize the input data to the given size.

    :param input: Input data, could be image or masks, with (H, W, C) shape
    :param size: Target size of input data, given as (height, width).
    :param interpolation: Interpolation method; when a sequence is given, one
        entry is picked at random for this call.
    :return: Resized data, with (H, W, C) shape
    :raises ValueError: if ``size`` does not have exactly two entries
    """
    if len(size) != 2:
        raise ValueError("resize needs (h, w), but got {}".format(size))
    picks_randomly = isinstance(interpolation, collections.abc.Sequence)
    if picks_randomly:
        interpolation = random.choice(interpolation)
    # size is (h, w); it is passed to cv2.resize in reversed order
    dsize = size[::-1]
    return cv2.resize(input, dsize, interpolation=interpolation)
| @@ -0,0 +1,89 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| from .core._imperative_rt.common import CompNode, DeviceType | |||
# Public names exported by this module.
__all__ = [
    "is_cuda_available",
    "get_device_count",
    "get_default_device",
    "set_default_device",
]
# Module-wide default device name; read once at import time from the
# MGE_DEFAULT_DEVICE environment variable, falling back to "xpux".
_default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux")
| def _valid_device(inp): | |||
| if isinstance(inp, str) and len(inp) == 4: | |||
| if inp[0] in {"x", "c", "g"} and inp[1:3] == "pu": | |||
| if inp[3] == "x" or inp[3].isdigit(): | |||
| return True | |||
| return False | |||
def _str2device_type(type_str: str, allow_unspec: bool = True):
    """Translate a device type name into a ``DeviceType`` value.

    :param type_str: case-insensitive name: ``"cpu"``, ``"gpu"``/``"cuda"`` or ``"xpu"``
    :param allow_unspec: whether ``"xpu"`` (mapped to ``DeviceType.UNSPEC``) is accepted
    :raises AssertionError: for an unknown name, or ``"xpu"`` when not allowed
    """
    type_str = type_str.upper()
    if type_str == "CPU":
        return DeviceType.CPU
    elif type_str == "GPU" or type_str == "CUDA":
        return DeviceType.CUDA
    else:
        # compare the argument, not the builtin ``str`` type
        assert allow_unspec and type_str == "XPU", "bad device type"
        return DeviceType.UNSPEC
def get_device_count(device_type: str) -> int:
    """Gets number of devices installed on this system.

    :param device_type: device type, one of 'gpu' or 'cpu'
    :raises AssertionError: if ``device_type`` is any other value
    """
    device_type_set = ("cpu", "gpu")
    is_known = device_type in device_type_set
    assert is_known, "device must be one of {}".format(device_type_set)
    return CompNode._get_device_count(_str2device_type(device_type), False)
def is_cuda_available() -> bool:
    """Returns whether at least one cuda device is reported on this system.
    """
    gpu_count = CompNode._get_device_count(_str2device_type("gpu"), False)
    return gpu_count > 0
def set_default_device(device: str = "xpux"):
    r"""Sets default computing node.

    :param device: default device name, e.g. 'cpu0', 'cpu1', ... or 'gpu0',
        'gpu1', ... to pick a particular cpu or gpu; 'cpux' and 'gpux' leave the
        index unspecified. The name must be accepted by ``_valid_device``.

        The default value is 'xpux' to specify any device available.

    The initial default can also be set through the environment variable
    `MGE_DEFAULT_DEVICE`.

    :raises AssertionError: if ``device`` is not a valid device name
    """
    global _default_device  # pylint: disable=global-statement
    name_ok = _valid_device(device)
    assert name_ok, "Invalid device name {}".format(device)
    _default_device = device
def get_default_device() -> str:
    r"""Gets default computing node.

    Returns the name last stored by :func:`~.set_default_device`, or the
    import-time default if it was never called.
    """
    current = _default_device
    return current
| @@ -0,0 +1,25 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .group import ( | |||
| WORLD, | |||
| get_backend, | |||
| get_client, | |||
| get_mm_server_addr, | |||
| get_py_server_addr, | |||
| get_rank, | |||
| get_world_size, | |||
| group_barrier, | |||
| init_process_group, | |||
| is_distributed, | |||
| new_group, | |||
| ) | |||
| from .helper import synchronized | |||
| from .launcher import launcher | |||
| from .server import Client, Server | |||
| from .util import get_free_ports | |||
| @@ -0,0 +1,176 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import List, Optional, Tuple | |||
| from ..device import set_default_device | |||
| from .server import Client, Server | |||
class StaticData:
    # Plain container for the per-process distributed state; every field starts
    # as None and is filled in by init_process_group.
    server = None
    client = None  # Client connected to (master_ip, py_server_port)
    master_ip = None
    py_server_port = None
    mm_server_port = None  # obtained from client.get_mm_server_port()
    world_size = None
    proc_rank = None  # rank of the current process
    device = None
    backend = None
    next_stream = None  # next stream id handed out by Group.reset
# Module-level singleton; None until init_process_group has been called.
_sd = None
class Group:
    """A set of process ranks that communicate on their own stream.

    A group built from an empty list is a placeholder: ``proc_ranks`` and
    ``stream`` are ``None`` until :meth:`reset` is called, and its properties
    fail with "invalid group".
    """

    def __init__(self, proc_ranks):
        if len(proc_ranks) == 0:  # empty group
            self.proc_ranks = None
            self.stream = None
        else:
            self.reset(proc_ranks)

    def reset(self, proc_ranks):
        """Validate ``proc_ranks``, store them and take the next free stream id."""
        self.check(proc_ranks)
        self.proc_ranks = proc_ranks
        self.stream = _sd.next_stream
        _sd.next_stream += 1

    def check(self, proc_ranks):
        """Assert that every rank is an int in range and includes the current process."""
        assert _sd is not None, "please call init_process_group first"
        for rank in proc_ranks:
            assert isinstance(rank, int)
            assert rank >= 0 and rank < _sd.world_size
        assert _sd.proc_rank in proc_ranks

    def _assert_valid(self):
        # proc_ranks is None for a placeholder group, so test truthiness rather
        # than len(), which would raise TypeError on None.
        assert self.proc_ranks, "invalid group"

    @property
    def size(self):
        """Number of ranks in the group."""
        self._assert_valid()
        return len(self.proc_ranks)

    @property
    def key(self):
        """Comma-separated list of the group's ranks."""
        self._assert_valid()
        return ",".join(map(str, self.proc_ranks))

    @property
    def rank(self):
        """Position of the current process within the group."""
        self._assert_valid()
        return self.proc_ranks.index(_sd.proc_rank)

    @property
    def comp_node(self):
        """Computing node name in the form ``gpu<device>:<stream>``."""
        self._assert_valid()
        return "gpu{}:{}".format(_sd.device, self.stream)


# Placeholder for the group of all processes; filled in by init_process_group.
WORLD = Group([])
def init_process_group(
    master_ip: str,
    port: int,
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = "nccl",
) -> None:
    """Initialize the distributed process group and specify the device used in the current process

    :param master_ip: IP address of the master node
    :param port: Port available for all processes to communicate
    :param world_size: Total number of processes participating in the job
    :param rank: Rank of the current process
    :param device: The GPU device id to bind this process to
    :param backend: Communicator backend, currently support 'nccl' and 'ucx'
    :raises TypeError: if an argument has the wrong type
    :raises AssertionError: if called more than once, or if ``world_size``,
        ``rank`` or ``port`` is out of range
    """
    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(port, int):
        raise TypeError("Expect type int but got {}".format(type(port)))
    if not isinstance(world_size, int):
        raise TypeError("Expect type int but got {}".format(type(world_size)))
    if not isinstance(rank, int):
        raise TypeError("Expect type int but got {}".format(type(rank)))
    if not isinstance(device, int):
        # report the type of the offending argument (``device``, not ``backend``)
        raise TypeError("Expect type int but got {}".format(type(device)))
    if not isinstance(backend, str):
        raise TypeError("Expect type str but got {}".format(type(backend)))

    global _sd
    assert _sd is None, "init_process_group should be called only once"

    assert world_size > 1
    assert rank >= 0 and rank < world_size
    assert port > 0

    # Build the state in a local object and publish it only once everything
    # that can fail has succeeded, so a failed call can be retried.
    sd = StaticData()
    sd.client = Client(master_ip, port)
    sd.master_ip = master_ip
    sd.py_server_port = port
    sd.mm_server_port = sd.client.get_mm_server_port()
    sd.world_size = world_size
    sd.proc_rank = rank
    sd.device = device
    sd.backend = backend
    sd.next_stream = 1
    _sd = sd

    # WORLD.reset reads the module-level state, so it must run after publishing
    WORLD.reset(list(range(world_size)))

    set_default_device("gpu{}".format(device))
def is_distributed() -> bool:
    """Return True if the distributed process group has been initialized"""
    initialized = _sd is not None
    return initialized
def get_rank() -> int:
    """Get the rank of the current process (0 when not initialized)"""
    if _sd is None:
        return 0
    return _sd.proc_rank
def get_world_size() -> int:
    """Get the total number of processes participating in the job (1 when not initialized)"""
    if _sd is None:
        return 1
    return _sd.world_size
def get_backend() -> str:
    """Get the backend str"""
    assert _sd is not None, "please call init_process_group first"
    if _sd is None:
        # only reachable when assertions are disabled
        return None
    return _sd.backend
def get_py_server_addr() -> Tuple[str, int]:
    """Get master_ip and port of python XML RPC server"""
    assert _sd is not None, "please call init_process_group first"
    address = (_sd.master_ip, _sd.py_server_port)
    return address
def get_mm_server_addr() -> Tuple[str, int]:
    """Get master_ip and port of C++ mm_server"""
    assert _sd is not None, "please call init_process_group first"
    address = (_sd.master_ip, _sd.mm_server_port)
    return address
def get_client() -> Client:
    """Get client of python XML RPC server"""
    assert _sd is not None, "please call init_process_group first"
    client = _sd.client
    return client
def new_group(proc_ranks: List[int]) -> Group:
    """Build a subgroup containing certain ranks"""
    subgroup = Group(proc_ranks)
    return subgroup
def group_barrier(group: Optional[Group] = WORLD) -> None:
    """Block until all ranks in the group reach this barrier

    :param group: group to synchronize; defaults to ``WORLD``
    :raises AssertionError: if ``group`` is not a :class:`Group` or the process
        group has not been initialized
    """
    assert isinstance(group, Group)
    # same guard as the other accessors, instead of an AttributeError on None
    assert _sd is not None, "please call init_process_group first"
    _sd.client.group_barrier(group.key, group.size)
| @@ -0,0 +1,28 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| from typing import Callable | |||
| from .group import group_barrier, is_distributed | |||
def synchronized(func: Callable):
    """Decorator: after ``func`` returns, wait on a group barrier when running
    distributed. Outside distributed mode ``func`` is simply called.

    Specifically, we use this to prevent data race during hub.load"""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # decide before the call, exactly as the barrier-free fast path requires
        needs_barrier = is_distributed()
        result = func(*args, **kwargs)
        if needs_barrier:
            group_barrier()
        return result

    return wrapper
| @@ -0,0 +1,68 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import multiprocessing as mp | |||
| from ..device import get_device_count | |||
| from .group import init_process_group | |||
| from .server import Server | |||
| from .util import get_free_ports | |||
def _get_device_count():
    """use subprocess to avoid cuda environment initialization in the main process

    :return: the value of ``get_device_count("gpu")`` as computed in a child process
    """

    def run(q):
        count = get_device_count("gpu")
        q.put(count)

    q = mp.Queue()
    p = mp.Process(target=run, args=(q,))
    p.start()
    # Drain the queue before joining: the multiprocessing guidelines require
    # queued items to be consumed before the producing process is joined.
    count = q.get()
    p.join()
    return count
def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs):
    """Entry point of each worker: set up the process group, then call ``func``.

    ``args`` and ``kwargs`` are forwarded to ``func`` unchanged.
    """
    group_settings = dict(
        master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
    )
    init_process_group(**group_settings)
    func(*args, **kwargs)
def launcher(n_gpus):
    """Decorator factory for single-machine multi-gpu training.

    The decorated function is run in ``n_gpus`` processes, one per rank, with
    rank ``i`` bound to device ``i``. The wrapper returns nothing.

    :param n_gpus: number of processes to start; must be an int greater than 1
        and not larger than the number of gpus found
    :raises AssertionError: on an invalid ``n_gpus``, or when a worker exits
        with a non-zero code
    """
    available = _get_device_count()
    assert isinstance(n_gpus, int) and n_gpus > 1, "invalid n_gpus"
    assert n_gpus <= available, "{} gpus required, {} gpus provided".format(
        n_gpus, available
    )

    def decorator(func):
        def wrapper(*args, **kwargs):
            master_ip = "localhost"
            port = get_free_ports(1)[0]
            # NOTE(review): presumably kept in a local so the server stays alive
            # while the workers run - confirm against Server
            server = Server(port)

            workers = []
            for rank in range(n_gpus):
                worker = mp.Process(
                    target=_run_wrapped,
                    args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs),
                )
                worker.start()
                workers.append(worker)

            # wait for the workers in rank order and stop at the first failure
            for rank, worker in enumerate(workers):
                worker.join()
                code = worker.exitcode
                assert code == 0, "subprocess {} exit with code {}".format(rank, code)

        return wrapper

    return decorator
| @@ -0,0 +1,170 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import multiprocessing as mp | |||
| import threading | |||
| import time | |||
| from collections import defaultdict | |||
| from functools import partial | |||
| from socketserver import ThreadingMixIn | |||
| from xmlrpc.client import ServerProxy | |||
| from xmlrpc.server import SimpleXMLRPCServer | |||
| from ..core._imperative_rt.utils import create_mm_server | |||
| from .util import get_free_ports | |||
class Future:
    """A one-shot value slot built on ``threading.Event``.

    :param ack: when true, ``set`` blocks until a ``get`` call has acknowledged
        the value; when false, ``set`` returns right after publishing it.
    """

    def __init__(self, ack=True):
        self.ready = threading.Event()
        if ack:
            self.ack = threading.Event()
        else:
            self.ack = None

    def set(self, value):
        """Publish ``value`` and, in ack mode, wait until it has been read."""
        self.value = value
        self.ready.set()
        if self.ack is not None:
            self.ack.wait()

    def get(self):
        """Block until a value is published, acknowledge it, and return it."""
        self.ready.wait()
        if self.ack is not None:
            self.ack.set()
        return self.value
class Methods:
    """Method table registered on the XML-RPC server.

    Each ``dict_*`` attribute is a ``defaultdict`` keyed by a caller-supplied
    key. ``self.lock`` is held only for the dictionary accesses shown below;
    every blocking wait happens after the lock has been released.

    :param mm_server_port: value returned by ``get_mm_server_port``.
    """

    def __init__(self, mm_server_port):
        self.lock = threading.Lock()
        self.mm_server_port = mm_server_port
        acked_future = partial(Future, True)
        plain_future = partial(Future, False)
        self.dict_is_grad = defaultdict(acked_future)
        self.dict_remote_tracer = defaultdict(acked_future)
        self.dict_pack_list = defaultdict(plain_future)
        self.dict_barrier_counter = defaultdict(int)
        self.dict_barrier_event = defaultdict(threading.Event)

    def connect(self):
        """Always returns ``True``; the client calls it to probe the server."""
        return True

    def get_mm_server_port(self):
        """Return the port given at construction time."""
        return self.mm_server_port

    def set_is_grad(self, rank_peer, is_grad):
        """Publish ``is_grad`` under ``rank_peer``; returns ``True`` once ``Future.set`` returns."""
        with self.lock:
            slot = self.dict_is_grad[rank_peer]
        slot.set(is_grad)
        return True

    def check_is_grad(self, rank_peer):
        """Wait for the value published under ``rank_peer``, drop the entry, return the value."""
        with self.lock:
            slot = self.dict_is_grad[rank_peer]
        is_grad = slot.get()
        with self.lock:
            del self.dict_is_grad[rank_peer]
        return is_grad

    def set_remote_tracer(self, rank_peer, tracer_set):
        """Publish ``tracer_set`` under ``rank_peer``; returns ``True`` once ``Future.set`` returns."""
        with self.lock:
            slot = self.dict_remote_tracer[rank_peer]
        slot.set(tracer_set)
        return True

    def check_remote_tracer(self, rank_peer):
        """Wait for the tracer set published under ``rank_peer``, drop the entry, return it."""
        with self.lock:
            slot = self.dict_remote_tracer[rank_peer]
        tracer_set = slot.get()
        with self.lock:
            del self.dict_remote_tracer[rank_peer]
        return tracer_set

    def set_pack_list(self, key, pack_list):
        """Publish ``pack_list`` under ``key`` (no acknowledgement is awaited)."""
        with self.lock:
            slot = self.dict_pack_list[key]
        slot.set(pack_list)
        return True

    def get_pack_list(self, key):
        """Wait for and return the pack list stored under ``key``; the entry is kept."""
        with self.lock:
            slot = self.dict_pack_list[key]
        return slot.get()

    def group_barrier(self, key, size):
        """Block until ``size`` calls with the same ``key`` have arrived.

        The call that brings the counter to ``size`` removes the bookkeeping
        for ``key`` and sets the event; every other call waits on that event.
        """
        with self.lock:
            self.dict_barrier_counter[key] += 1
            arrived = self.dict_barrier_counter[key]
            event = self.dict_barrier_event[key]
        if arrived != size:
            event.wait()
        else:
            del self.dict_barrier_counter[key]
            del self.dict_barrier_event[key]
            event.set()
        return True
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    """``SimpleXMLRPCServer`` combined with ``ThreadingMixIn``; adds no members of its own."""
def start_server(py_server_port, mm_server_port):
    """Serve a ``Methods`` instance over XML-RPC until the process ends.

    :param py_server_port: TCP port to listen on (bound on ``0.0.0.0``).
    :param mm_server_port: passed through to ``Methods``.
    """
    listen_addr = ("0.0.0.0", py_server_port)
    rpc_server = ThreadXMLRPCServer(listen_addr, logRequests=False)
    rpc_server.register_instance(Methods(mm_server_port))
    rpc_server.serve_forever()
class Server:
    """Starts ``start_server`` in a daemon process.

    :param port: port for the XML-RPC server; ``0`` means pick one with
        ``get_free_ports``.
    """

    def __init__(self, port):
        if port == 0:
            self.py_server_port = get_free_ports(1)[0]
        else:
            self.py_server_port = port
        self.mm_server_port = create_mm_server("0.0.0.0", 0)
        server_args = (self.py_server_port, self.mm_server_port)
        self.proc = mp.Process(target=start_server, args=server_args, daemon=True)
        self.proc.start()
class Client:
    """Thin XML-RPC client; every method forwards to the same-named server method.

    :param master_ip: host of the XML-RPC server.
    :param port: port of the XML-RPC server.
    """

    def __init__(self, master_ip, port):
        self.master_ip = master_ip
        self.port = port
        self.connect()

    def connect(self):
        """Retry until the server answers ``connect()`` with a true value.

        A failed attempt sleeps one second before retrying; there is no upper
        bound on the number of attempts.
        """
        while True:
            try:
                self.proxy = ServerProxy(
                    "http://{}:{}".format(self.master_ip, self.port)
                )
                if self.proxy.connect():
                    break
            except Exception:
                # Only ordinary errors are retried. A bare ``except`` would also
                # swallow KeyboardInterrupt/SystemExit and make this loop
                # impossible to interrupt while the server is unreachable.
                time.sleep(1)

    def get_mm_server_port(self):
        """Return the server's ``get_mm_server_port()`` result."""
        return self.proxy.get_mm_server_port()

    def set_is_grad(self, rank_peer, is_grad):
        """Forward to the server's ``set_is_grad``; the reply is discarded."""
        self.proxy.set_is_grad(rank_peer, is_grad)

    def check_is_grad(self, rank_peer):
        """Return the server's ``check_is_grad(rank_peer)`` result."""
        return self.proxy.check_is_grad(rank_peer)

    def set_remote_tracer(self, rank_peer, tracer_set):
        """Forward to the server's ``set_remote_tracer``; the reply is discarded."""
        self.proxy.set_remote_tracer(rank_peer, tracer_set)

    def check_remote_tracer(self, rank_peer):
        """Return the server's ``check_remote_tracer(rank_peer)`` result."""
        return self.proxy.check_remote_tracer(rank_peer)

    def set_pack_list(self, key, pack_list):
        """Forward to the server's ``set_pack_list``; the reply is discarded."""
        self.proxy.set_pack_list(key, pack_list)

    def get_pack_list(self, key):
        """Return the server's ``get_pack_list(key)`` result."""
        return self.proxy.get_pack_list(key)

    def group_barrier(self, key, size):
        """Forward to the server's ``group_barrier``; the reply is discarded."""
        self.proxy.group_barrier(key, size)
| @@ -0,0 +1,25 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import socket | |||
| from typing import List | |||
def get_free_ports(num: int) -> List[int]:
    """Get one or more free ports.

    All ``num`` sockets are kept open until every port has been collected, so
    the returned ports are distinct from each other. The sockets are closed
    before returning, so a port can be taken by another process afterwards.

    :param num: number of ports to find.
    :return: list of ``num`` port numbers.
    """
    socks, ports = [], []
    try:
        for _ in range(num):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Track the socket before bind() so it is closed even if bind fails.
            socks.append(sock)
            sock.bind(("", 0))
            ports.append(sock.getsockname()[1])
    finally:
        for sock in socks:
            sock.close()
    return ports
| @@ -0,0 +1,32 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=redefined-builtin | |||
| from . import distributed | |||
| from .elemwise import * | |||
| from .graph import add_update | |||
| from .loss import ( | |||
| binary_cross_entropy, | |||
| cross_entropy, | |||
| cross_entropy_with_softmax, | |||
| hinge_loss, | |||
| l1_loss, | |||
| nll_loss, | |||
| smooth_l1_loss, | |||
| square_loss, | |||
| triplet_margin_loss, | |||
| ) | |||
| from .math import * | |||
| from .nn import * | |||
| from .quantized import conv_bias_activation | |||
| from .tensor import * | |||
| from .utils import accuracy, zero_grad | |||
| # delete namespace | |||
| # pylint: disable=undefined-variable | |||
| # del elemwise, graph, loss, math, nn, tensor # type: ignore[name-defined] | |||
| @@ -0,0 +1,49 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| _conv_execution_strategy = os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY", "HEURISTIC") | |||
def get_conv_execution_strategy() -> str:
    """Return the currently configured execution strategy of :class:`~.Conv2d`.

    The value is the module-level setting; see
    :func:`~.set_conv_execution_strategy` for the values that function accepts.
    """
    return _conv_execution_strategy
def set_conv_execution_strategy(option: str):
    """Set the execution strategy of :class:`~.Conv2d`.

    :param option: Decides how :class:`~.Conv2d` algorithm is chosen.
        Available values:

        * 'HEURISTIC' uses heuristic to choose the fastest algorithm.
        * 'PROFILE' runs possible algorithms on real device to find the best.
        * 'PROFILE_HEURISTIC' uses profile result and heuristic to choose the fastest algorithm.
        * 'PROFILE_REPRODUCIBLE' uses the fastest of profile result that is also reproducible.
        * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible.

        The default strategy is 'HEURISTIC'.

        It can also be set through the environmental variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'.

    :raises ValueError: if ``option`` is not one of the values listed above.
    """
    global _conv_execution_strategy  # pylint: disable=global-statement

    valid_option = (
        "HEURISTIC",
        "PROFILE",
        "PROFILE_HEURISTIC",
        "PROFILE_REPRODUCIBLE",
        "HEURISTIC_REPRODUCIBLE",
    )
    if option not in valid_option:
        raise ValueError("Valid option can only be one of {}".format(valid_option))
    _conv_execution_strategy = option
| @@ -0,0 +1,299 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Optional, Tuple | |||
| from ..core._imperative_rt.ops import CollectiveCommDefModeEnum | |||
| from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||
| from ..core.autodiff.grad import ( | |||
| Tracer, | |||
| check_backward_allow_noinput, | |||
| get_grad_managers, | |||
| get_op_has_grad_fn, | |||
| tracer_apply, | |||
| ) | |||
| from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
| from ..core.tensor.core import apply | |||
| from ..core.tensor.tensor import Tensor, tensor_apply | |||
| from ..distributed.group import ( | |||
| WORLD, | |||
| Group, | |||
| get_backend, | |||
| get_client, | |||
| get_mm_server_addr, | |||
| get_rank, | |||
| ) | |||
| from ..tensor import tensor | |||
| __all__ = [ | |||
| "reduce_sum", | |||
| "broadcast", | |||
| "all_gather", | |||
| "reduce_scatter_sum", | |||
| "all_reduce_sum", | |||
| "all_reduce_max", | |||
| "all_reduce_min", | |||
| "gather", | |||
| "scatter", | |||
| "all_to_all", | |||
| "remote_send", | |||
| "remote_recv", | |||
| ] | |||
@apply.add
def _(op: RemoteSend, *args: Tensor):
    """Apply hook for ``RemoteSend``: run the op, then publish tracer names.

    The names collected from the ``_extra_data`` of every ``Tensor`` argument
    are sent to the server under ``op.key``; ``remote_recv`` reads them back.
    """
    outputs = tensor_apply(op, *args)

    # set extra information
    tracers = set()
    for arg in args:
        if isinstance(arg, Tensor):
            tracers.update(arg._extra_data)
    tracer_set = {tracer.name: True for tracer in tracers}

    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, tracer_set)
    return outputs
@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    """Backward rule for ``RemoteSend``.

    The returned ``backward`` ignores its arguments and produces a single
    gradient by calling ``remote_recv`` from ``op.rank_to`` with the shape,
    dtype and device of the first forward input.
    """

    def backward(*args):
        sent = inputs[0]
        received = remote_recv(op.rank_to, sent.shape, sent.dtype, str(sent.device))
        return [received]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    """Gradient presence for ``RemoteSend`` is whatever the server reports for ``op.key``."""

    def has_grad(opnode, reached):
        client = get_client()
        return client.check_is_grad(op.key)

    return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    """Registered answer for ``RemoteSend`` is always ``True``."""
    return True
@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    """Backward rule for ``RemoteRecv``: send the first output gradient to ``op.rank_from``."""

    def backward(*output_grads):
        sent = remote_send(output_grads[0], op.rank_from)
        return [sent]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    """Gradient presence for ``RemoteRecv``.

    True when any dereferenced output of the op node is in ``reached``; the
    result is also published to the server under ``op.key``.
    """

    def has_grad(opnode, reached):
        found = any(out_ref() in reached for out_ref in opnode.outputs)
        get_client().set_is_grad(op.key, found)
        return found

    return has_grad
def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions

    :param inp: input tensor.
    :param mode: value assigned to the operator's ``mode`` attribute.
    :param group: communication group; when ``None`` the input is returned
        unchanged and no operator is applied.
    :param device: value assigned to the operator's ``comp_node`` attribute.
    :return: the first output of the applied ``CollectiveComm`` operator, or
        ``inp`` itself when ``group`` is ``None``.
    """
    # The None check must come before the type assertion; in the opposite order
    # the assertion rejects None and this early return can never be reached.
    if group is None:
        return inp
    assert isinstance(group, Group)
    op = CollectiveComm()
    op.key = group.key
    op.nr_devices = group.size
    op.rank = group.rank
    op.is_root = op.rank == 0
    op.local_grad = False
    op.addr, op.port = get_mm_server_addr()
    op.mode = mode
    op.dtype = inp.dtype
    op.backend = get_backend()
    op.comp_node = device
    return apply(op, inp)[0]
def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``REDUCE_SUM`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.REDUCE_SUM, group, device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``BROADCAST`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.BROADCAST, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``ALL_GATHER`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.ALL_GATHER, group, device)


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``REDUCE_SCATTER_SUM`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(
        inp, CollectiveCommDefModeEnum.REDUCE_SCATTER_SUM, group, device
    )


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``ALL_REDUCE_SUM`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.ALL_REDUCE_SUM, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``ALL_REDUCE_MAX`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.ALL_REDUCE_MAX, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``ALL_REDUCE_MIN`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.ALL_REDUCE_MIN, group, device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``GATHER`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.GATHER, group, device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``SCATTER`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.SCATTER, group, device)


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply collective communication in ``ALL_TO_ALL`` mode.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommDefModeEnum.ALL_TO_ALL, group, device)
def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """Send a Tensor to a remote process

    The operator key is ``"<this rank>-><dest_rank>"``.

    :param inp: tensor to send
    :param dest_rank: destination process rank
    :return: the first output of the applied ``RemoteSend`` operator.
    """
    send_op = RemoteSend()
    send_op.key = "{}->{}".format(get_rank(), dest_rank)
    send_op.addr, send_op.port = get_mm_server_addr()
    send_op.rank_to = dest_rank
    outputs = apply(send_op, inp)
    return outputs[0]
def remote_recv(
    src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0"
) -> Tensor:
    """Receive a Tensor from a remote process

    The operator key is ``"<src_rank>-><this rank>"``. Before the operator is
    applied, the tracer names published for that key are fetched from the
    server and every grad manager whose name is among them is attached to the
    placeholder input via ``wrt``.

    :param src_rank: source process rank
    :param shape: the shape of the tensor to receive
    :param dtype: the data type of the tensor to receive
    :param cn: the comp node to place the received tensor
    :return: the first output of the applied ``RemoteRecv`` operator.
    """
    key = "{}->{}".format(src_rank, get_rank())

    # dummy input
    placeholder = tensor([0])
    remote_tracers = get_client().check_remote_tracer(key)
    for manager in get_grad_managers():
        if manager.name in remote_tracers:
            manager.wrt(placeholder)

    recv_op = RemoteRecv()
    recv_op.key = key
    recv_op.cn = cn
    recv_op.shape = shape
    recv_op.dtype = dtype
    recv_op.addr, recv_op.port = get_mm_server_addr()
    recv_op.rank_from = src_rank
    outputs = apply(recv_op, placeholder)
    return outputs[0]
| @@ -0,0 +1,481 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
| import functools | |||
| from ..core.ops import builtin | |||
| from ..core.tensor import utils | |||
| from ..core.tensor.core import apply | |||
| from ..tensor import Tensor | |||
| __all__ = [ | |||
| "abs", | |||
| "add", | |||
| "acos", | |||
| "asin", | |||
| "atan", | |||
| "atan2", | |||
| "asinh", | |||
| "acosh", | |||
| "atanh", | |||
| "bitwise_and", # TODO | |||
| "bitwise_not", # TODO | |||
| "bitwise_or", # TODO | |||
| "bitwise_xor", # TODO | |||
| "ceil", | |||
| "clamp", | |||
| "cos", | |||
| "cosh", | |||
| "div", | |||
| "eq", | |||
| "exp", | |||
| "expm1", | |||
| "floor", | |||
| "floor_div", | |||
| "gt", | |||
| "ge", | |||
| "hswish", | |||
| "hsigmoid", | |||
| "left_shift", | |||
| "lt", | |||
| "le", | |||
| "log", | |||
| "log1p", | |||
| "logical_and", | |||
| "logical_not", | |||
| "logical_or", | |||
| "logical_xor", | |||
| "maximum", | |||
| "minimum", | |||
| "mod", | |||
| "mul", | |||
| "neg", | |||
| "ne", | |||
| "pow", | |||
| "relu", | |||
| "relu6", | |||
| "right_shift", | |||
| "round", | |||
| "sigmoid", | |||
| "sin", | |||
| "sinh", | |||
| "sqrt", | |||
| "square", | |||
| "sub", | |||
| "tan", | |||
| "tanh", | |||
| "fast_tanh", | |||
| ] | |||
def _elwise(*args, mode):
    """Apply ``builtin.Elemwise`` with the given ``mode`` and return its single output."""
    elemwise_op = builtin.Elemwise(mode=mode)
    converted = utils.convert_inputs(*args)
    (out,) = apply(elemwise_op, *converted)
    return out


def _logical(*args, mode):
    """Apply ``builtin.CondExecPredLogical`` with the given ``mode`` and return its single output."""
    logical_op = builtin.CondExecPredLogical(mode=mode)
    converted = utils.convert_inputs(*args)
    (out,) = apply(logical_op, *converted)
    return out


def _elemwise_multi_type(*args, mode, **kwargs):
    """Apply ``builtin.ElemwiseMultiType`` (``mode`` plus extra ``kwargs``) and return its single output."""
    multi_type_op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
    converted = utils.convert_inputs(*args)
    (out,) = apply(multi_type_op, *converted)
    return out
| # math operations | |||
def add(x, y):
    """Element-wise addition.

    At least one operand should be tensor.

    The same holds for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minimum.
    """
    return _elwise(x, y, mode="add")


def sub(x, y):
    """Element-wise subtraction (``x - y``)."""
    return _elwise(x, y, mode="sub")


def mul(x, y):
    """Element-wise multiplication (``x * y``)."""
    return _elwise(x, y, mode="mul")


def div(x, y):
    """Element-wise true division (``x / y``)."""
    return _elwise(x, y, mode="true_div")


def floor_div(x, y):
    """Element-wise ``floor(x / y)``."""
    return _elwise(x, y, mode="floor_divide")


def neg(x):
    """Element-wise negation (``-x``)."""
    return _elwise(x, mode="negate")


def pow(x, y):
    """Element-wise power (``x ** y``)."""
    return _elwise(x, y, mode="pow")


def mod(x, y):
    """Element-wise remainder of division."""
    return _elwise(x, y, mode="mod")


def abs(x):
    """Element-wise absolute value."""
    return _elwise(x, mode="abs")


def exp(x):
    """Element-wise exponential."""
    return _elwise(x, mode="exp")


def expm1(x):
    """Element-wise ``exp(x) - 1``."""
    return _elwise(x, mode="expm1")


def log(x):
    """Element-wise natural logarithm (base `e`)."""
    return _elwise(x, mode="log")


def log1p(x):
    """Element-wise ``log(x + 1)`` (base `e`)."""
    return _elwise(x, mode="log1p")
def sqrt(inp: Tensor) -> Tensor:
    """
    Return a new tensor holding the square root of each element of ``inp``,
    computed as ``inp ** 0.5``.

    For negative value, return nan.

    :param inp: The input tensor
    :return: The computed tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.sqrt(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[0. 1. 1.4142]
         [1.7321 2. 2.2361 ]]

    """
    root = inp ** 0.5
    return root


def square(inp: Tensor) -> Tensor:
    """
    Return a new tensor holding the square of each element of ``inp``,
    computed as ``inp ** 2``.

    :param inp: The input tensor
    :return: The computed tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.square(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[0. 1. 4.]
         [9. 16. 25.]]

    """
    squared = inp ** 2
    return squared
def round(x):
    """Element-wise rounding (Elemwise mode ``round``)."""
    return _elwise(x, mode="round")


def ceil(x):
    """Element-wise ceiling of the input."""
    return _elwise(x, mode="ceil")


def floor(x):
    """Element-wise floor of the input."""
    return _elwise(x, mode="floor")
| # trigonometric functions | |||
def cos(x):
    """Element-wise cosine."""
    return _elwise(x, mode="cos")


def sin(x):
    """Element-wise sine."""
    return _elwise(x, mode="sin")


def tan(x):
    """Element-wise tangent, computed as ``sin(x) / cos(x)``."""
    sine = sin(x)
    cosine = cos(x)
    return sine / cosine


def acos(x):
    """Element-wise inverse cosine."""
    return _elwise(x, mode="acos")


def asin(x):
    """Element-wise inverse sine."""
    return _elwise(x, mode="asin")


def atan(x):
    """Element-wise inverse tangent, computed as ``atan2(x, 1)``."""
    return _elwise(x, 1, mode="atan2")


def atan2(y, x):
    """Element-wise two-argument inverse tangent; ``y`` is passed first, ``x`` second."""
    return _elwise(y, x, mode="atan2")
def cosh(x):
    r"""Element-wise hyperbolic cosine, computed as ``0.5 * (exp(x) + exp(-x))``."""
    return 0.5 * (exp(x) + exp(-x))


def sinh(x):
    r"""Element-wise hyperbolic sine.

    With ``u = expm1(x)`` the result is ``0.5 * u / (u + 1) * (u + 2)``.
    """
    u = expm1(x)
    return 0.5 * u / (u + 1) * (u + 2)


def tanh(x):
    r"""Element-wise hyperbolic tangent."""
    return _elwise(x, mode="tanh")


def asinh(x):
    r"""Element-wise inverse hyperbolic sine, computed as ``log(x + (x ** 2 + 1) ** 0.5)``."""
    return log(x + (x ** 2 + 1) ** 0.5)


def acosh(x):
    r"""Element-wise inverse hyperbolic cosine, computed as ``log(x + (x ** 2 - 1) ** 0.5)``."""
    return log(x + (x ** 2 - 1) ** 0.5)


def atanh(x):
    r"""Element-wise inverse hyperbolic tangent, computed as ``log1p(2 * x / (1 - x)) / 2``."""
    return log1p(2 * x / (1 - x)) / 2


def fast_tanh(x):
    r"""Element-wise fast tanh; this is an approximation:

    .. math::
        \text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x)
    """
    return _elwise(x, mode="fast_tanh")
| # bit-twiddling functions | |||
def left_shift(x, y):
    """Element-wise left shift (Elemwise mode ``shl``) of ``x`` by ``y``."""
    return _elwise(x, y, mode="shl")
def right_shift(x, y):
    """Element-wise right shift (Elemwise mode ``shr``) of ``x`` by ``y``."""
    # The mode was "shl", which made right_shift behave exactly like left_shift.
    return _elwise(x, y, mode="shr")
def bitwise_and(x, y):
    """Placeholder: not implemented yet; always raises ``NotImplementedError``."""
    raise NotImplementedError


def bitwise_not(x):
    """Placeholder: not implemented yet; always raises ``NotImplementedError``."""
    raise NotImplementedError


def bitwise_or(x, y):
    """Placeholder: not implemented yet; always raises ``NotImplementedError``."""
    raise NotImplementedError


def bitwise_xor(x, y):
    """Placeholder: not implemented yet; always raises ``NotImplementedError``."""
    raise NotImplementedError
| # logical functions | |||
def logical_and(x, y):
    """Element-wise logical AND (Elemwise mode ``AND``)."""
    return _elwise(x, y, mode="AND")


def logical_not(x):
    """Element-wise logical NOT (Elemwise mode ``NOT``)."""
    return _elwise(x, mode="NOT")


def logical_or(x, y):
    """Element-wise logical OR (Elemwise mode ``OR``)."""
    return _elwise(x, y, mode="OR")


def logical_xor(x, y):
    """Element-wise logical XOR (Elemwise mode ``XOR``)."""
    return _elwise(x, y, mode="XOR")
| # comparison functions | |||
def eq(x, y):
    """Element-wise ``x == y``."""
    return _elwise(x, y, mode="eq")


def ne(x, y):
    """Return ``x != y`` using the operands' own ``!=`` operator."""
    return x != y


def lt(x, y):
    """Element-wise ``x < y``."""
    return _elwise(x, y, mode="lt")


def le(x, y):
    """Element-wise ``x <= y``."""
    return _elwise(x, y, mode="leq")


def gt(x, y):
    """Element-wise ``x > y``, evaluated as ``y < x``."""
    return _elwise(y, x, mode="lt")


def ge(x, y):
    """Element-wise ``x >= y``, evaluated as ``y <= x``."""
    return _elwise(y, x, mode="leq")
def hswish(x):
    """Element-wise ``x * relu6(x + 3) / 6`` (Elemwise mode ``h_swish``)."""
    return _elwise(x, mode="h_swish")


def hsigmoid(x):
    """Element-wise ``relu6(x + 3) / 6``."""
    shifted = x + 3
    return relu6(shifted) / 6


def relu(x):
    """Element-wise `max(x, 0)`."""
    return _elwise(x, mode="relu")


def relu6(x):
    """Element-wise ``min(max(x, 0), 6)``."""
    lower_bounded = maximum(x, 0)
    return minimum(lower_bounded, 6)


def sigmoid(x):
    """Element-wise ``1 / (1 + exp(-x))``."""
    return _elwise(x, mode="sigmoid")


def maximum(x, y):
    """Element-wise maximum of the two operands."""
    return _elwise(x, y, mode="max")


def minimum(x, y):
    """Element-wise minimum of the two operands."""
    return _elwise(x, y, mode="min")
def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
    r"""
    Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return
    a resulting tensor:

    .. math::
        y_i = \begin{cases}
            \text{lower} & \text{if } x_i < \text{lower} \\
            x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
            \text{upper} & \text{if } x_i > \text{upper}
        \end{cases}

    Either bound may be omitted, but not both.

    :param inp: the input tensor.
    :param lower: lower-bound of the range to be clamped to
    :param upper: upper-bound of the range to be clamped to

    Example:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        a = tensor(np.arange(5).astype(np.int32))
        print(F.clamp(a, 2, 4).numpy())
        print(F.clamp(a, lower=3).numpy())
        print(F.clamp(a, upper=3).numpy())

    .. testoutput::

        [2 2 2 3 4]
        [3 3 3 3 4]
        [0 1 2 3 3]

    """
    assert (
        lower is not None or upper is not None
    ), "At least one of 'lower' or 'upper' must not be None"

    # Only an upper bound was given.
    if lower is None:
        return minimum(inp, upper)
    # Only a lower bound was given.
    if upper is None:
        return maximum(inp, lower)

    assert lower <= upper, "clamp lower bound is bigger that upper bound"
    return minimum(maximum(inp, lower), upper)
| @@ -0,0 +1,44 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=too-many-lines | |||
| from typing import List | |||
| from ..core import Tensor | |||
def cambricon_subgraph(
    inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool,
) -> List[Tensor]:
    """Load a serialized Cambricon subgraph (i.e. cnrtModel_t) and
    execute the operations defined in the subgraph.

    Not implemented yet: calling this function always raises
    ``NotImplementedError``.

    :param inputs: List of input tensors of the subgraph.
    :param data: The serialized subgraph.
    :param symbol: The name of the function in the subgraph.
        The function is corresponding to a cnmlFusionOp
        which is added to the cnmlModel_t/cnrtModel_t.
    :param tensor_dim_mutable: Whether the input tensors' shapes are mutable
        in cnrtModel_t
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def extern_opr_subgraph(
    inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes,
) -> List[Tensor]:
    """Load a serialized extern opr subgraph and fake execute the operator

    Not implemented yet: calling this function always raises
    ``NotImplementedError``.

    :param inputs: Tensor or list of input tensors.
    :param output_shapes: The output shapes.
    :param dump_name: The serialized subgraph name.
    :param dump_data: The serialized subgraph.
    :return: List of tensors
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
| @@ -0,0 +1,41 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from typing import Iterable, Optional, Union | |||
| from ..core.tensor import Tensor | |||
def add_update(
    dest: Tensor,
    delta: Tensor,
    *,
    alpha: Union[Tensor, float, int] = 1.0,
    beta: Union[Tensor, float, int] = 1.0,
    bias: Union[Tensor, float, int] = 0.0
):
    r"""Update ``dest`` in place and return it:

    .. math::
        dest = alpha * dest + beta * delta + bias

    :param dest: input data that is modified in place (via ``*=`` / ``+=``).
    :param delta: update value that is added to ``dest``.
    :param alpha: weight ratio of ``dest``. Default: 1.0
    :param beta: weight ratio of ``delta``. Default: 1.0
    :param bias: bias value added to the result. Default: 0.0
    :return: ``dest`` after the update.
    """
    # Build the full increment (beta * delta + bias) first; ``delta`` itself is
    # never modified because only out-of-place operators are used on it.
    increment = delta
    if beta is not None and beta != 1.0:
        increment = increment * beta
    if bias is not None and bias != 0.0:
        increment = increment + bias

    # Scaling and accumulation use the augmented operators on ``dest``.
    if alpha is not None and alpha != 1.0:
        dest *= alpha
    dest += increment
    return dest
| @@ -0,0 +1,388 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..tensor import Tensor | |||
| from .elemwise import abs, eq, exp, log, maximum, pow, relu | |||
| from .nn import assert_equal, indexing_one_hot | |||
| from .tensor import where | |||
| from .utils import zero_grad | |||
def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Mean absolute error (MAE) between the elements of ``pred`` (:math:`x`)
    and ``label`` (:math:`y`).

    The loss is

    .. math:: \ell(x,y) = mean\left(L \right)

    where

    .. math::

        L = \{l_1,\dots,l_N\}, \quad
        l_n = \left| x_n - y_n \right|,

    and :math:`N` is the total number of elements; the mean is taken over
    all of them.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.
    :return: the mean of the element-wise absolute differences.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.l1_loss(ipt,tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [2.75]

    """
    return abs(pred - label).mean()
def square_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Mean squared error (squared L2 norm) between the elements of ``pred``
    (:math:`x`) and ``label`` (:math:`y`).

    The loss is

    .. math:: \ell(x, y) = mean\left( L \right)

    where

    .. math::

        L = \{l_1,\dots,l_N\}, \quad
        l_n = \left( x_n - y_n \right)^2,

    and :math:`N` is the total number of elements; the mean is taken over
    all of them.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.
    :return: the mean of the element-wise squared differences.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`. Same shape as ``pred``

    """
    residual = pred - label
    squared = residual ** 2
    return squared.mean()
def cross_entropy(
    inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    Cross entropy loss for a classification problem.

    .. math:: \textrm{CrossEntropy}(x, y) = - \sum_{i} y_i\log(x_i)

    This function is currently a placeholder: every call raises
    :class:`NotImplementedError`. The disabled implementation is kept in the
    comments at the end of the body.

    :param inp: The input tensor representing the predicted probability.
    :param target: The input tensor representing the classification label.
    :param axis: An axis along which cross_entropy will be applied. Default: 1
    :param ignore_index: Specifies a target value that is ignored and does not
        contribute to the input gradient. Default: -1
    :raises NotImplementedError: always.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )

        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
        loss = F.cross_entropy(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.69]

    """
    raise NotImplementedError
    # Disabled implementation, kept for reference:
    #
    # n0 = inp.ndim
    # n1 = target.ndim
    # assert n0 == n1 + 1, (
    #     "target ndim must be one less than input ndim; input_ndim={} "
    #     "target_ndim={}".format(n0, n1)
    # )
    # if ignore_index != -1:
    #     mask = 1 - equal(target, ignore_index)
    #     target = target * mask
    #     loss = -log(indexing_one_hot(inp, target, axis)) * mask
    #     return loss.sum() / maximum(mask.sum(), 1.0)
    # else:
    #     return -log(indexing_one_hot(inp, target, axis)).mean()
def cross_entropy_with_softmax(
    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
) -> Tensor:
    r"""
    Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.

    It has better numerical stability compared with sequential calls to
    :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K

    where :math:`y^{LS}` and :math:`y` are new label distribution and origin
    label distribution respectively. k is the index of label distribution.
    :math:`\alpha` is label_smooth and :math:`K` is the number of classes.

    :param pred: The input tensor holding the per-class scores; softmax is
        applied to it inside this function.
    :param label: The input tensor representing the classification label. It
        must have exactly one dimension less than ``pred``.
    :param axis: An axis along which softmax will be applied. Default: 1.
    :param label_smooth: A label smoothing of parameter that can re-distribute
        target distribution. Default: 0.
    :return: the mean of ``log(sum(exp(pred))) - pred[label]`` over the samples.
    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )

    num_classes = pred.shape[axis]

    # Shift by the maximum along ``axis`` before exponentiating. The reduced
    # axis is kept so that the offset broadcasts against ``pred`` for any
    # batch size; without ``keepdims`` an (N, C) - (N,) subtraction only lines
    # up when N == C. ``detach`` is called so the shift is treated as a constant.
    offset = pred.max(axis=axis, keepdims=True).detach()
    pred = pred - offset

    # Denominator of the softmax.
    down = exp(pred).sum(axis=axis)

    # Score of the labelled class for every sample.
    # NOTE(review): this indexing pairs the first dimension with ``label`` and
    # therefore looks specific to 2-D ``pred`` with axis=1 - confirm.
    up = pred[np.arange(pred.shape[0]), label]

    if label_smooth != 0:
        factor = label_smooth / num_classes
        up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor

    return (log(down) - up).mean()
def triplet_margin_loss(
    anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2
) -> Tensor:
    r"""
    Triplet loss computed from anchor, positive and negative samples.

    .. math::
        L(a, p, n) = max\left\{d\left(a_{i},p_{i}\right)-d\left(a_{i}, n_{i}\right)+margin, 0\right\},\
        d\left(x_{i},y_{i}\right)=\left\|x_{i}-y_{i}\right\|_{p}

    :param anchor: The input tensor representing the anchor samples.
    :param positive: The input tensor representing the positive samples.
    :param negative: The input tensor representing the negative samples.
    :param margin: Default: 1.0
    :param p: The norm degree for pairwise distance; must be positive. Default: 2
    :return: the mean of the per-sample losses.
    :raises AssertionError: if any input is not 2-dimensional or ``p <= 0``.
    """
    # All three inputs must have the same shape.
    s0 = anchor.shapeof()
    s1 = positive.shapeof()
    s2 = negative.shapeof()
    assert_equal(s0, s1)
    assert_equal(s1, s2)

    n0 = anchor.ndim
    n1 = positive.ndim
    n2 = negative.ndim
    assert n0 == 2 and n1 == 2 and n2 == 2, (
        "anchor ndim, positive ndim, and negative ndim must be 2; "
        "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2)
    )
    assert p > 0, "p must be greater than 0; p={}".format(p)

    diff0 = abs(anchor - positive)
    diff1 = abs(anchor - negative)

    # p-norm distance along axis 1. The ``**`` operator is used because no
    # function named ``power`` is imported in this module.
    d1 = (diff0 ** p).sum(axis=1, keepdims=True) ** (1 / p)
    d2 = (diff1 ** p).sum(axis=1, keepdims=True) ** (1 / p)

    loss = maximum(d1 - d2 + margin, 0)
    return loss.mean()
def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
    r"""Binary cross entropy between the prediction and the target.

    Computes ``-mean(label * log(pred) + (1 - label) * log(1 - pred))``.

    :param pred: (N,*) where * means, any number of additional dimensions.
    :param label: (N,*), same shape as the input.
    :return: the mean loss over all elements.
    :raises AssertionError: if ``pred`` and ``label`` differ in shape.
    """
    assert pred.shape == label.shape

    positive_term = label * log(pred)
    negative_term = (1.0 - label) * log(1 - pred)
    return -1.0 * (positive_term + negative_term).mean()
def nll_loss(
    pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    The negative log likelihood loss.

    This function is currently a placeholder: every call raises
    :class:`NotImplementedError`. The disabled implementation is kept in the
    comments at the end of the body.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.
    :param axis: Default: 1
    :param ignore_index: Default: -1
    :raises NotImplementedError: always.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F
        data_shape = (2, 2)
        label_shape = (2, )

        data = tensor(
            np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape),
        )
        label = tensor(
            np.ones(label_shape, dtype=np.int32)
        )
        pred = F.log(F.softmax(data))
        loss1 = F.nll_loss(pred, label)
        loss2 = F.cross_entropy_with_softmax(data, label)
        print(loss1.numpy(), loss2.numpy())

    Outputs:

    .. testoutput::

        [0.6576154] [0.6576154]

    """
    raise NotImplementedError
    # Disabled implementation, kept for reference:
    #
    # n0 = pred.ndim
    # n1 = label.ndim
    # assert n0 == n1 + 1, (
    #     "target ndim must be one less than input ndim; input_ndim={} "
    #     "target_ndim={}".format(n0, n1)
    # )
    # mask = 1.0 - equal(label, ignore_index)
    # label = label * mask
    # loss = indexing_one_hot(pred, label, axis) * mask
    # return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)
def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    r"""
    Calculate the hinge loss which is often used in SVMs.

    The hinge loss can be described as:

    .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij}*y_{ij}))

    With ``norm="L2"`` every hinge term is squared before the summation.

    :param pred: The input tensor representing the predicted probability, shape is (N, C).
    :param label: The input tensor representing the binary classification label, shape is (N, C).
        The labels are used as given (the example below uses -1/1 labels).
    :param norm: Specify the norm to calculate the loss, should be "L1" or "L2".
    :return: the per-sample sums over axis 1, averaged over the samples.
    :raises AssertionError: if ``norm`` is neither "L1" nor "L2".

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
        label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")

        loss = F.hinge_loss(pred, label)

        print(loss.numpy())

    Outputs:

    .. testoutput::

        [1.5]

    """
    assert norm in ["L1", "L2"], "norm must be L1 or L2"

    hinge = relu(1.0 - pred * label)
    per_element = hinge if norm == "L1" else hinge ** 2
    return per_element.sum(axis=1).mean()
def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Calculate the smooth l1 loss proposed in `Fast R-CNN paper by Ross Girshick`.

    The smooth l1 loss can be described as:

    .. math::
        \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i}

    where :math:`l_{i}` is given by:

    .. math::
        l_{i} =
        \begin{cases}
        0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
        |x_i - y_i| - 0.5, & \text{otherwise }
        \end{cases}

    This function is currently a placeholder: every call raises
    :class:`NotImplementedError`. The disabled implementation is kept in the
    comments at the end of the body.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.
    :raises NotImplementedError: always.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]])

        loss = F.smooth_l1_loss(pred, label)

        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.5608334]

    """
    raise NotImplementedError
    # Disabled implementation, kept for reference:
    #
    # diff = abs(pred - label)
    # l2_loss = 0.5 * (diff ** 2)
    # l1_loss = diff - 0.5
    # mask = diff < 1
    # loss = where(mask, l2_loss, l1_loss)
    # return loss.mean()
| @@ -0,0 +1,696 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import functools | |||
| import math | |||
| import numbers | |||
| from typing import Optional, Sequence, Tuple, Union | |||
| from ..core.ops import builtin | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..core.tensor import utils | |||
| from ..core.tensor.core import apply | |||
| from ..tensor import Tensor | |||
| from .elemwise import clamp, exp, log, log1p | |||
| from .tensor import remove_axis, reshape | |||
| __all__ = [ | |||
| "all", # TODO | |||
| "all_close", # TODO | |||
| "any", # TODO | |||
| "argmax", | |||
| "argmin", | |||
| "argsort", | |||
| "isinf", | |||
| "isnan", # TODO | |||
| "max", | |||
| "mean", | |||
| "median", # TODO | |||
| "min", | |||
| "norm", | |||
| "normalize", | |||
| "prod", | |||
| "sign", # TODO | |||
| "sort", | |||
| "std", | |||
| "sum", | |||
| "topk", | |||
| "unique", # TODO | |||
| "var", | |||
| ] | |||
def all(inp):
    """Placeholder for ``all``; not implemented yet.

    :param inp: the input.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def all_close(inp):
    """Placeholder for ``all_close``; not implemented yet.

    :param inp: the input.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def any(inp):
    """Placeholder for ``any``; not implemented yet.

    :param inp: the input.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def unique(inp):
    """Placeholder for ``unique``; not implemented yet.

    :param inp: the input.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def isnan(inp: Tensor) -> Tensor:
    r"""Returns a new tensor representing if each element is NaN or not.

    This function is currently a placeholder: every call raises
    :class:`NotImplementedError`. The disabled one-line implementation is kept
    as a comment at the end of the body.

    :param inp: the input tensor.
    :return: a new tensor representing if each element in :attr:`inp` is NaN or not.
    :raises NotImplementedError: always.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("nan"), 0])

        print(F.isnan(x))

    .. testoutput::

        Tensor([0 1 0], dtype=uint8)

    """
    raise NotImplementedError
    # Disabled implementation, kept for reference:
    # return (inp != inp).astype("uint8")
def isinf(inp: Tensor) -> Tensor:
    r"""Returns a new tensor representing if each element is Inf or not.

    Both positive and negative infinity are reported, because the absolute
    value is compared against ``float("inf")``.

    :param inp: the input tensor.
    :return: a ``uint8`` tensor that is 1 where :attr:`inp` is infinite and 0 elsewhere.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("inf"), 0])

        print(F.isinf(x))

    .. testoutput::

        Tensor([0 1 0], dtype=uint8)

    """
    magnitude = abs(inp).astype("float32")
    is_infinite = magnitude == float("inf")
    return is_infinite.astype("uint8")
def sign(inp: Tensor):
    """Placeholder for ``sign``; not implemented yet.

    :param inp: the input tensor.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def _reduce(
    data,
    *,
    mode,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False
):
    """Apply the builtin ``Reduce`` op with the given ``mode`` to ``data``.

    :param data: the input; it is passed through ``utils.convert_inputs`` first.
    :param mode: reduction mode forwarded to ``builtin.Reduce`` (callers in
        this module pass e.g. ``"SUM"`` or ``"MAX"``).
    :param axis: ``None`` to reduce the flattened input, a single axis, or an
        iterable of axes that are reduced one after another.
    :param keepdims: when false, every reduced axis is removed from the result
        with ``remove_axis``. Must be false when ``axis`` is ``None``.
    :return: the reduced tensor.
    :raises AssertionError: if ``axis is None`` and ``keepdims`` is true.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    from collections.abc import Iterable

    (data,) = utils.convert_inputs(data)
    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        # Flatten and reduce the single remaining axis.
        data = data.reshape(-1)
        op = builtin.Reduce(mode=mode, axis=0)
        (result,) = apply(op, data)
    elif isinstance(axis, Iterable):
        # Go from the highest axis down so that removing an axis does not
        # shift the indices of the axes still to be reduced.
        for ai in sorted(axis, reverse=True):
            op = builtin.Reduce(mode=mode, axis=ai)
            (data,) = apply(op, data)
            if not keepdims:
                data = remove_axis(data, ai)
        result = data
    else:
        op = builtin.Reduce(mode=mode, axis=axis)
        (result,) = apply(op, data)
        if not keepdims:
            result = remove_axis(result, axis)
    return result
def sum(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the sum of the elements of ``inp`` along the given ``axis``.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced.
        Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not.
        Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.sum(data)
        print(out.numpy())

    .. testoutput::

        [21]

    """
    # Thin wrapper: the shared reduction helper does the work in "SUM" mode.
    reduced = _reduce(inp, axis=axis, keepdims=keepdims, mode="SUM")
    return reduced
def prod(
    inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
) -> Tensor:
    r"""
    Returns the product of the elements of the input tensor along given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False``
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.prod(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [720]

    """
    # Thin wrapper: the shared reduction helper does the work in "PRODUCT" mode.
    reduced = _reduce(inp, axis=axis, keepdims=keepdims, mode="PRODUCT")
    return reduced
def mean(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the mean value of the elements of ``inp`` along the given
    ``axis``. If axis is a list of dimensions, reduce over all of them.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.mean(data)
        print(out.numpy())

    .. testoutput::

        [3.5]

    """
    # Thin wrapper: the shared reduction helper does the work in "MEAN" mode.
    reduced = _reduce(inp, axis=axis, keepdims=keepdims, mode="MEAN")
    return reduced
def median(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Placeholder for ``median``; not implemented yet.

    :param inp: the input tensor.
    :param axis: Default: None
    :param keepdims: Default: False
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
def var(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the variance of the input tensor along given ``axis``,
    computed as the mean of the squared deviations from the mean.
    If axis is a list of dimensions, reduce over all of them.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``.
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``.
    :return: The output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
        out = F.var(data)
        print(out.numpy())

    .. testoutput::

        [2.9166667]

    """
    # When reducing along specific axes, keep them in the mean so that it
    # broadcasts against ``inp``; with axis=None keepdims must stay False.
    center = mean(inp, axis=axis, keepdims=axis is not None)
    squared_deviation = (inp - center) ** 2
    return mean(squared_deviation, axis=axis, keepdims=keepdims)
def std(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the standard deviation of the input tensor along given
    ``axis``, i.e. the square root of :func:`var`.
    If axis is a list of dimensions, reduce over all of them.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``.
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``.
    :return: The output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
        out = F.std(data, axis=1)
        print(out.numpy())

    .. testoutput::

        [0.8164966 0.8164966]

    """
    variance = var(inp, axis=axis, keepdims=keepdims)
    return variance ** 0.5
def min(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""
    Returns the minimum value of the input tensor along given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.min(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [1]

    """
    # Thin wrapper: the shared reduction helper does the work in "MIN" mode.
    reduced = _reduce(inp, axis=axis, keepdims=keepdims, mode="MIN")
    return reduced
def max(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the maximum value of the input tensor along given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.max(x)
        print(y.numpy())

    .. testoutput::

        [6]

    """
    # Thin wrapper: the shared reduction helper does the work in "MAX" mode.
    reduced = _reduce(inp, axis=axis, keepdims=keepdims, mode="MAX")
    return reduced
def norm(
    inp: Tensor,
    p: int = 2,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims=False,
):
    """Calculate ``p``-norm of input tensor along certain axis.

    Special values of ``p``:

    * ``0``: the number of non-zero elements;
    * ``math.inf``: the largest absolute value;
    * ``-math.inf``: the smallest absolute value.

    Any other ``p`` gives ``sum(abs(inp) ** p) ** (1 / p)``. ``axis`` and
    ``keepdims`` are honoured for every value of ``p``.

    :param inp: The input tensor
    :param p: power of value ``p`` applied to ``inp``. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2,3))
        y = F.norm(x)
        print(y.numpy())

    .. testoutput::

        [4.358899]

    """
    if p == 0:
        return sum(inp != 0, axis=axis, keepdims=keepdims)
    # The infinity norms reduce along the requested axis like every other
    # order, instead of always collapsing the whole tensor.
    if p == math.inf:
        return max(abs(inp), axis=axis, keepdims=keepdims)
    if p == -math.inf:
        return min(abs(inp), axis=axis, keepdims=keepdims)
    return sum(abs(inp) ** p, axis=axis, keepdims=keepdims) ** (1.0 / p)
def argmin(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the indices of the minimum values along an axis

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, the input is flattened
        first. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not.
        Must be False when ``axis`` is None. Default: False
    :return: The output tensor
    :raises AssertionError: if ``axis is None`` and ``keepdims`` is true.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.argmin(x)
        print(y.numpy())

    .. testoutput::

        [0]

    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    from collections.abc import Iterable

    if isinstance(axis, Iterable):
        # Go from the highest axis down so that removing an axis does not
        # shift the indices of the axes still to be processed.
        # NOTE(review): from the second axis on, the op runs on the indices
        # produced by the previous pass, not on the original values - confirm
        # this is the intended multi-axis behaviour.
        for ai in sorted(axis, reverse=True):
            op = builtin.Argmin(axis=ai)
            (inp,) = apply(op, inp)

            if not keepdims:
                inp = remove_axis(inp, ai)

        return inp

    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        inp = inp.flatten()
        axis = 0

    op = builtin.Argmin(axis=axis)
    (result,) = apply(op, inp)
    if not keepdims:
        result = remove_axis(result, axis)
    return result
def argmax(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the indices of the maximum values along an axis

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, the input is flattened
        first. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not.
        Must be False when ``axis`` is None. Default: False
    :return: The output tensor
    :raises AssertionError: if ``axis is None`` and ``keepdims`` is true.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.argmax(x)
        print(y.numpy())

    .. testoutput::

        [5]

    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    from collections.abc import Iterable

    if isinstance(axis, Iterable):
        # Go from the highest axis down so that removing an axis does not
        # shift the indices of the axes still to be processed.
        # NOTE(review): from the second axis on, the op runs on the indices
        # produced by the previous pass, not on the original values - confirm
        # this is the intended multi-axis behaviour.
        for ai in sorted(axis, reverse=True):
            op = builtin.Argmax(axis=ai)
            (inp,) = apply(op, inp)

            if not keepdims:
                inp = remove_axis(inp, ai)

        return inp

    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        inp = inp.flatten()
        axis = 0

    op = builtin.Argmax(axis=axis)
    (result,) = apply(op, inp)
    if not keepdims:
        result = remove_axis(result, axis)
    return result
def normalize(
    inp: Tensor,
    p: int = 2,
    axis: Optional[Union[int, Sequence[int]]] = None,
    eps: float = 1e-12,
) -> Tensor:
    r"""Perform :math:`L_p` normalization of input tensor along certain axis.

    For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:

    .. math::
        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    :param inp: the input tensor
    :param p: power of value ``p`` applied to ``inp``. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced
        to calculate the norm. Default: None
    :param eps: a small value to avoid division by zero. Default: 1e-12
    :return: the normalized output tensor

    """
    # Keep the reduced axes when normalizing along specific axes so that the
    # norm broadcasts against ``inp``; with axis=None the norm is computed
    # without keepdims.
    retain_axes = axis is not None
    denominator = clamp(norm(inp, p, axis, keepdims=retain_axes), lower=eps)
    return inp / denominator
def argsort(inp: Tensor, descending: bool = False) -> Tensor:
    r"""
    Returns the indices that sort a 1d tensor, or each row of a 2d matrix.

    Only the indices are returned; use :func:`sort` to also get the sorted values.

    :param inp: The input tensor, if 2d, each row will be sorted
    :param descending: Sort in descending order, where the largest comes first. Default: ``False``
    :return: the tensor of sorting indices
    :raises AssertionError: if ``inp`` has more than two dimensions.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([1,2], dtype=np.float32))
        indices = F.argsort(data)
        print(indices.numpy())

    Outputs:

    .. testoutput::

        [0 1]

    """
    assert len(inp.shape) <= 2, "Input should be 1d or 2d"

    order = P.Argsort.Order.DESCENDING if descending else P.Argsort.Order.ASCENDING
    op = builtin.Argsort(order=order)

    # A 1d input is run as a single-row matrix and unwrapped afterwards.
    is_vector = len(inp.shape) == 1
    if is_vector:
        inp = inp.reshape(1, -1)

    _, indices = apply(op, inp)
    return indices[0] if is_vector else indices
def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
    r"""
    Sort a 1d tensor, or each row of a 2d matrix, with the builtin ``Argsort`` op.

    :param inp: The input tensor, if 2d, each row will be sorted
    :param descending: Sort in descending order. Default: ``False``
    :return: the two outputs of the op as a tuple; the names used here
        suggest (sorted_tensor, indices).
    :raises AssertionError: if ``inp`` has more than two dimensions.
    """
    assert len(inp.shape) <= 2, "Input should be 1d or 2d"

    order = P.Argsort.Order.DESCENDING if descending else P.Argsort.Order.ASCENDING
    op = builtin.Argsort(order=order)

    # A 1d input is run as a single-row matrix and unwrapped afterwards.
    is_vector = len(inp.shape) == 1
    if is_vector:
        inp = inp.reshape(1, -1)

    values, indices = apply(op, inp)
    if is_vector:
        return values[0], indices[0]
    return values, indices
def topk(
    inp: Tensor,
    k: int,
    descending: bool = False,
    kth_only: bool = False,
    no_sort: bool = False,
) -> Tuple[Tensor, Tensor]:
    r"""
    Selects the Top-K (by default smallest) elements of a 1d tensor, or of
    each row of a 2d matrix.

    :param inp: The input tensor, if 2d, each row will be sorted
    :param k: The number of elements needed
    :param descending: If true, return the largest elements instead. Default: ``False``
    :param kth_only: If true, only the k-th element will be returned. Default: ``False``
    :param no_sort: If true, the returned elements can be unordered. Default: ``False``
    :return: Tuple of two tensors (topk_tensor, indices_of_int32). When
        ``kth_only`` is true no indices are produced and the second item is
        ``None``.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
        top, indices = F.topk(data, 5)
        print(top.numpy(), indices.numpy())

    Outputs:

    .. testoutput::

        [1. 2. 3. 4. 5.] [7 0 6 1 5]

    """
    if descending:
        # Select on the negated input and negate the values back at the end.
        inp = -inp

    Mode = P.TopK.Mode
    if kth_only:
        mode = Mode.KTH_ONLY
    elif no_sort:
        mode = Mode.VALUE_IDX_NOSORT
    else:
        mode = Mode.VALUE_IDX_SORTED
    op = builtin.TopK(mode=mode)

    # A 1d input is run as a single-row matrix and unwrapped afterwards.
    is_vector = len(inp.shape) == 1
    if is_vector:
        inp = inp.reshape(1, -1)

    res = apply(op, inp, Tensor(k, dtype="int32"))

    # ``ind`` stays None in kth_only mode so the final return never touches an
    # unbound name.
    ind = None
    if kth_only:
        # NOTE(review): assumes the first output of the op holds the k-th
        # values in KTH_ONLY mode, as the 1d path already did - confirm.
        tns = res[0]
    else:
        tns, ind = res[0], res[1]
        if is_vector:
            tns, ind = tns[0], ind[0]

    if descending:
        tns = -tns
    return tns, ind
| @@ -0,0 +1,83 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=too-many-lines | |||
| from typing import Tuple, Union | |||
| from ..core.ops import builtin | |||
| from ..core.tensor.core import apply | |||
| from ..tensor import Tensor | |||
| from .debug_param import get_conv_execution_strategy | |||
| from .types import _pair, _pair_nonzero | |||
def conv_bias_activation(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    format="NCHW",
    nonlinear_mode="IDENTITY",
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
    """ Convolution + bias + activation as a single operator, only for inference.

    :param inp: The feature map of the convolution operation
    :param weight: The convolution kernel
    :param bias: The bias added to the result of convolution
    :param dtype: Forwarded unchanged to the ``ConvBiasForward`` operator. Default: None
    :param stride: Stride of the 2D convolution operation (must be non-zero). Default: 1
    :param padding: Size of the paddings added to the input on both sides of its
        spatial dimensions. Default: 0
    :param dilation: Dilation of the 2D convolution operation (must be non-zero). Default: 1
    :param groups: Number of groups; ``1`` selects the "DENSE" sparse type, any
        other value selects "GROUP". Default: 1
    :param format: Forwarded as the operator's ``format``. Default: "NCHW"
    :param nonlinear_mode: Forwarded as the operator's ``nonlineMode``. Default: "IDENTITY"
    :param conv_mode: Forwarded as the operator's ``mode``. Default: "CROSS_CORRELATION"
    :param compute_mode: Forwarded as the operator's ``compute_mode``. Default: "DEFAULT"
    :return: The single output of the ``ConvBiasForward`` operator
    """
    # Normalise scalar-or-pair arguments; stride and dilation reject zero.
    pad_h, pad_w = _pair(padding)
    stride_h, stride_w = _pair_nonzero(stride)
    dilate_h, dilate_w = _pair_nonzero(dilation)

    if groups == 1:
        sparse_type = "DENSE"
    else:
        sparse_type = "GROUP"

    op_kwargs = dict(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        dtype=dtype,
        format=format,
        strategy=get_conv_execution_strategy(),
        nonlineMode=nonlinear_mode,
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    op = builtin.ConvBiasForward(**op_kwargs)
    (result,) = apply(op, inp, weight, bias)
    return result
| @@ -0,0 +1,934 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import math | |||
| from itertools import accumulate | |||
| from typing import Iterable, List, Optional, Sequence, Tuple, Union | |||
| import numpy as np | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core.ops import builtin | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||
| from ..core.tensor.utils import ( | |||
| astensor1d, | |||
| convert_inputs, | |||
| convert_single_value, | |||
| dtype_promotion, | |||
| get_device, | |||
| ) | |||
| from ..device import get_default_device | |||
| from ..tensor import Tensor | |||
| from .elemwise import ceil | |||
| # Names exported by ``from <this module> import *``. | |||
| __all__ = [ | |||
| "add_axis", # expand_dims | |||
| "arange", | |||
| "broadcast", | |||
| "concat", | |||
| "cond_take", | |||
| "dimshuffle", # transpose, permute | |||
| "expand_dims", | |||
| "full", | |||
| "full_like", | |||
| "gather", | |||
| "eye", | |||
| "linspace", | |||
| "ones", | |||
| "ones_like", | |||
| "remove_axis", # squeeze | |||
| "split", | |||
| "squeeze", | |||
| "stack", | |||
| "reshape", | |||
| "scatter", | |||
| "where", | |||
| "zeros", | |||
| "zeros_like", | |||
| "param_pack_split", | |||
| "param_pack_concat", | |||
| ] | |||
def eye(n: int, *, dtype=None, device: Optional[CompNode] = None) -> Tensor:
    """
    Build a tensor with the builtin ``Eye`` operator (diagonal offset ``k=0``).

    ``n`` is wrapped in an int32 tensor and passed to the operator as its only input.

    :param n: The size passed to the operator
        (presumably the number of rows and columns -- confirm against the operator)
    :param dtype: The data type forwarded to the operator. Default: None
    :param device: Compute node forwarded to the operator and used for the
        size tensor. Default: None
    :return: The single output of the ``Eye`` operator

    NOTE(review): this signature accepts a single positional size only; a
    column count ``m`` is not supported here.
    """
    eye_op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
    size = Tensor(n, dtype="int32", device=device)
    (out,) = apply(eye_op, size)
    return out
def full(shape, value, dtype="float32", device=None):
    """
    Create a tensor of ``shape`` in which every element equals ``value``.

    A constant holding ``value`` is created first and then broadcast to ``shape``.

    :param shape: The target shape, forwarded to :func:`broadcast`
    :param value: The fill value
    :param dtype: The data type of the constant. Default: "float32"
    :param device: The device to use; the default device is used when None
    :return: The broadcast tensor
    """
    target = get_default_device() if device is None else device
    make_const = Const(value, dtype=dtype, device=target)
    seed = Tensor(value, dtype=dtype, device=target)
    (scalar,) = make_const(seed)
    return broadcast(scalar, shape)
def ones(shape, dtype="float32", device=None):
    """Return a tensor of ``shape`` filled with ``1.0`` (delegates to :func:`full`)."""
    fill_value = 1.0
    return full(shape, fill_value, dtype=dtype, device=device)
def zeros(shape, dtype="float32", device=None):
    """Return a tensor of ``shape`` filled with ``0.0`` (delegates to :func:`full`)."""
    fill_value = 0.0
    return full(shape, fill_value, dtype=dtype, device=device)
def zeros_like(inp: Tensor) -> Tensor:
    r"""
    Returns a zero tensor with the same shape, dtype and device as the input tensor

    :param inp: input tensor
    :return: tensor produced by :func:`zeros`

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        out = F.zeros_like(inp)
        print(out.numpy())

    .. testoutput::

        [[0 0 0]
         [0 0 0]]

    """
    shape, dtype, device = inp.shape, inp.dtype, inp.device
    return zeros(shape, dtype=dtype, device=device)
def ones_like(inp: Tensor) -> Tensor:
    r"""
    Returns a tensor filled with ones, with the same shape, dtype and device
    as the input tensor

    :param inp: input tensor
    :return: tensor produced by :func:`ones`
    """
    shape, dtype, device = inp.shape, inp.dtype, inp.device
    return ones(shape, dtype=dtype, device=device)
def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
    r"""
    Returns a tensor filled with ``value``, with the same shape, dtype and
    device as the input tensor

    :param inp: input tensor
    :param value: the fill value
    :return: tensor produced by :func:`full`
    """
    shape, dtype, device = inp.shape, inp.dtype, inp.device
    return full(shape, value, dtype=dtype, device=device)
def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
    """
    Broadcast a tensor to ``shape``

    :param inp: The input tensor
    :param shape: The target shape; converted to a 1d int32 tensor on the
        input's device before being handed to the ``Broadcast`` operator
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.broadcast(data, (4, 2, 3))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]

         [[0. 1. 2.]
          [3. 4. 5.]]]

    """
    target = astensor1d(shape, inp, dtype="int32", device=inp.device)
    outputs = apply(builtin.Broadcast(), inp, target)
    (out,) = outputs
    return out
def concat(
    inps: Iterable[Tensor], axis: int = 0, device: Optional[CompNode] = None,
) -> Tensor:
    r"""
    Concat some tensors

    :param inps: Input tensors to concat
    :param axis: the dimension over which the tensors are concatenated. Default: 0
    :param device: The comp node output on. When None, the device is inferred
        from ``inps``. Default: None
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
        data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
        out = F.concat([data1, data2])
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[ 0.  1.  2.]
         [ 3.  4.  5.]
         [ 6.  7.  8.]
         [ 9. 10. 11.]]

    """
    dtype = dtype_promotion(inps)
    # Only infer the device when the caller did not request one; previously an
    # explicit ``device`` argument was unconditionally overwritten.
    if device is None:
        device = get_device(inps)

    def convert(x):
        return convert_single_value(x, inps, dtype=dtype)

    inps = tuple(map(convert, inps))
    # NOTE(review): an explicitly passed ``device`` must provide ``to_c()``,
    # as the inferred device does -- confirm for all accepted device types.
    (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
    return result
def stack(inps, axis=0):
    """Concats a sequence of tensors along a new axis.
    The input tensors must have the same shape.

    :param inps: The input tensors.
    :param axis: Position of the new axis along which the tensors are joined.
    :return: The output concatenated tensor.
    :raises ValueError: if the inputs do not all share exactly one shape.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
        x2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
        out = F.stack([x1, x2], axis=0)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[ 0.  1.  2.]
          [ 3.  4.  5.]]

         [[ 6.  7.  8.]
          [ 9. 10. 11.]]]

    """
    distinct_shapes = set()
    for arr in inps:
        distinct_shapes.add(arr.shape)
    if len(distinct_shapes) != 1:
        raise ValueError("All input tensors must have the same shape")

    # Give every input a new axis, then join along it.
    expanded = []
    for inp in inps:
        expanded.append(add_axis(inp, axis=axis))
    return concat(expanded, axis=axis)
def split(inp, nsplits_or_sections, axis=0):
    """Splits the input tensor into several smaller tensors.
    When nsplits_or_sections is int, the last tensor may be smaller than others.

    :param inp: The input tensor.
    :param nsplits_or_sections: Number of sub tensors (int), or a list of
        split positions along ``axis``.
    :param axis: Which axis will be split.
    :return: The output tensor list.
    :raises ValueError: if an integer ``nsplits_or_sections`` is not positive.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.random.random((2,3,4,5)), dtype=np.float32)
        out = F.split(x, 2, axis=3)
        print(out[0].shape, out[1].shape)

    Outputs:

    .. testoutput::

        (2, 3, 4, 3) (2, 3, 4, 2)

    """

    def swapaxis(inp, src, dst):
        # Exchange two axes with a transpose; a no-op when they coincide.
        if src == dst:
            return inp
        perm = list(range(len(inp.shape)))
        perm[src] = dst
        perm[dst] = src
        return inp.transpose(perm)

    # Move the split axis to the front so plain slicing can be used.
    inp = swapaxis(inp, 0, axis)
    length = inp.shape[0]

    if isinstance(nsplits_or_sections, int):
        if nsplits_or_sections <= 0:
            raise ValueError(
                "nsplits_or_sections must be positive, got {}".format(
                    nsplits_or_sections
                )
            )
        # Boundaries advance by the chunk size.  The previous code advanced by
        # the number of splits, producing too many pieces (e.g. 10 elements
        # split in 2 gave 4 tensors).
        chunk = max(1, math.ceil(length / nsplits_or_sections))
        sections = list(range(chunk, length, chunk))
    else:
        sections = nsplits_or_sections

    sub_tensors = []
    st = 0
    for se in sections:
        sub_tensors.append(swapaxis(inp[st:se], axis, 0))
        st = se
    # Whatever remains after the last boundary forms the final piece.
    if st < length:
        sub_tensors.append(swapaxis(inp[st:], axis, 0))
    return sub_tensors
def _get_idx(index, axis):
    """
    Build one flattened index tensor per dimension of ``index``.

    For dimension ``axis`` the flattened ``index`` itself is used; for every
    other dimension a 0..size-1 ramp is broadcast to ``index.shape``,
    flattened and cast to int32.

    :param index: the index tensor
    :param axis: the dimension that is addressed through ``index``
    :return: tuple with one flattened tensor per dimension
    """
    ndim = len(index.shape)
    coords = []
    for dim in range(ndim):
        if dim == axis:
            coords.append(index.reshape(-1))
            continue
        # Shape that is 1 everywhere except along ``dim``.
        view_shape = [1] * ndim
        view_shape[dim] = index.shape[dim]
        ramp = linspace(
            0, index.shape[dim] - 1, index.shape[dim], device=index.device,
        )
        ramp = ramp.reshape(*view_shape)
        ramp = ramp.broadcast(index.shape)
        ramp = ramp.reshape(-1)
        ramp = ramp.astype(np.int32)
        coords.append(ramp)
    return tuple(coords)
def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
    r"""
    Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`.

    For a 3-D tensor, the output is specified by::

        out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
        out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
        out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2

    :attr:`index` must have the same number of dimensions as :attr:`inp` and
    the same size in every dimension except ``axis``; the output has the same
    shape as :attr:`index`.

    :param inp: the source tensor
    :param axis: the axis along which to index; must lie in ``[0, ndim)``
    :param index: the indices of elements to gather
    :raises ValueError: if the ranks differ, ``axis`` is out of range, or the
        shapes differ outside ``axis``

    Examples:

    .. testcode::

        import megengine.functional as F
        from megengine import tensor

        inp = tensor([
            [1,2], [3,4], [5,6],
        ])
        index = tensor([[0,2], [1,0]])
        oup = F.gather(inp, 0, index)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[1 6]
         [3 2]]

    """
    src_shape, idx_shape = inp.shape, index.shape
    src_ndim, idx_ndim = len(src_shape), len(idx_shape)

    if src_ndim != idx_ndim:
        raise ValueError(
            "The index tensor must have same dimensions as input tensor, "
            "But the input dims:{}, the index dims:{}".format(src_ndim, idx_ndim)
        )
    if not 0 <= axis < src_ndim:
        raise ValueError(
            "Index axis {} is output of bounds, should in range [0 {})".format(
                axis, src_ndim
            )
        )

    # Every dimension other than ``axis`` has to agree between input and index.
    mismatch = any(
        dim != axis and src_shape[dim] != idx_shape[dim] for dim in range(src_ndim)
    )
    if mismatch:
        raise ValueError(
            "The input {} and index {} must have the same size apart from axis {}".format(
                src_shape, idx_shape, axis
            )
        )

    flat_idx = _get_idx(index, axis)
    picked = inp[flat_idx]
    return picked.reshape(index.shape)  # pylint: disable=no-member
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
    r"""
    Writes all values from the tensor :attr:`source` into :attr:`inp` at the
    indices specified in the :attr:`index` tensor.

    For a 3-D tensor, :attr:`inp` is updated as::

        inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0
        inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1
        inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2

    :attr:`inp`, :attr:`index` and :attr:`source` should have same number of
    dimensions. It is also required that ``source.shape(d) <= inp.shape(d)``
    and ``index.shape(d) == source.shape(d)`` for all dimensions ``d``.

    :attr:`inp` is modified through item assignment and then returned.

    .. note::
        The original documentation warns that on GPU the result is uncertain
        when several source positions are scattered to the same destination.

    :param inp: the inp tensor which to be scattered
    :param axis: the axis along which to index; must lie in ``[0, ndim)``
    :param index: the indices of elements to scatter
    :param source: the source element(s) to scatter
    :raises ValueError: if any of the rank/shape requirements above is violated

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine import tensor

        inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
        source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
        index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
        oup = F.scatter(inp, 0, index,source)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[0.9935 0.0718 0.2256 0.     0.    ]
         [0.     0.     0.5939 0.357  0.4396]
         [0.7723 0.9465 0.     0.8926 0.4576]]

    """
    dst_shape, idx_shape, src_shape = inp.shape, index.shape, source.shape
    dst_ndim, idx_ndim, src_ndim = len(dst_shape), len(idx_shape), len(src_shape)

    if not (dst_ndim == idx_ndim and dst_ndim == src_ndim):
        raise ValueError("The input, source and index tensor must have same dimensions")

    if not 0 <= axis < dst_ndim:
        raise ValueError(
            "Index axis {} is output of bounds, should in range [0 {})".format(
                axis, dst_ndim
            )
        )

    # source must fit inside the destination in every dimension.
    if any(src_shape[d] > dst_shape[d] for d in range(src_ndim)):
        raise ValueError(
            "The each shape size for source {} must be less than or equal to input {} ".format(
                src_shape, dst_shape
            )
        )

    # index and source must have exactly the same shape.
    if any(idx_shape[d] != src_shape[d] for d in range(idx_ndim)):
        raise ValueError(
            "The each shape size for index {} must be equal to source {} ".format(
                idx_shape, src_shape
            )
        )

    # Outside ``axis`` the index must not exceed the destination.
    if any(d != axis and idx_shape[d] > dst_shape[d] for d in range(idx_ndim)):
        raise ValueError(
            "The index {} must be less than or equal to input {} size apart from axis {}".format(
                idx_shape, dst_shape, axis
            )
        )

    flat_idx = _get_idx(index, axis)
    inp[flat_idx] = source.flatten()
    return inp
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    r"""
    Select elements either from Tensor x or Tensor y, according to mask.

    .. math::

        \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i

    .. note::
        This function is not implemented yet: calling it always raises
        :class:`NotImplementedError`. The example below shows the intended usage.

    :param mask: a mask used for choosing x or y
    :param x: the first choice
    :param y: the second choice
    :raises NotImplementedError: always

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F
        mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]],
            dtype=np.float32))
        y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
        out = F.where(mask, x, y)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1. 6.]
         [7. 4.]]
    """
    # Sketch of the previous graph-based implementation, kept for reference:
    # v0, index0 = mgb.opr.cond_take(
    #     x, mask, mode=P.CondTake.Mode.EQ, val=1
    # )
    # v1, index1 = mgb.opr.cond_take(
    #     y, mask, mode=P.CondTake.Mode.EQ, val=0
    # )
    # out = x.flatten()
    # index = mgb.opr.concat(index0, index1, axis=0)
    # v = mgb.opr.concat(v0, v1, axis=0)
    # out = mgb.opr.set_advanced_indexing(out, v)[index]
    # out = out.reshape(x.shape)
    # return out
    raise NotImplementedError
def cond_take(mask: Tensor, x: Tensor) -> Tensor:
    r"""
    Take elements from data if specific condition is satisfied on mask.
    This operator has two outputs: the first is the elements taken, and the
    second is the indices corresponding to those elements; they are both
    1-dimensional. High-dimension input would first be flattened.

    :param mask: condition param; must be a bool tensor on the same device as ``x``
    :param x: input tensor from which to take elements
    :return: the pair ``(values, indices)`` produced by the ``CondTake`` operator
    :raises TypeError: if ``x`` or ``mask`` is not a tensor
    :raises ValueError: if ``mask`` is not bool or the devices differ

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F
        mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]],
            dtype=np.float32))
        v, index = F.cond_take(mask, x)
        print(v.numpy(), index.numpy())

    Outputs:

    .. testoutput::

        [1. 4.] [0 3]

    """
    tensor_types = (TensorWrapperBase, TensorBase)
    if not isinstance(x, tensor_types):
        raise TypeError("input must be a tensor")
    if not isinstance(mask, tensor_types):
        raise TypeError("mask must be a tensor")
    if mask.dtype != np.bool_:
        raise ValueError("mask must be bool")
    if x.device != mask.device:
        raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))

    taken, positions = apply(builtin.CondTake(), x, mask)
    return taken, positions
def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    r"""
    Swap shapes and strides according to given pattern

    :param inp: Input tensor
    :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any
        number of ``'x'`` char in dimensions where this tensor should be
        broadcasted. For examples:

        * (``'x'``) -> make a 0d (scalar) into a 1d vector
        * (0, 1) -> identity for 2d vectors
        * (1, 0) -> inverts the first and second dimensions
        * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
        * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
        * (2, 0, 1) -> AxBxC to CxAxB
        * (0, ``'x'``, 1) -> AxB to Ax1xB
        * (1, ``'x'``, 0) -> AxB to Bx1xA
        * (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A)

    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F
        x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
        out = F.dimshuffle(x, (1, 0))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1 0]
         [1 0]]

    """
    shuffle_op = builtin.Dimshuffle(pattern)
    converted = convert_inputs(inp)
    (tensor_in,) = converted
    (out,) = apply(shuffle_op, tensor_in)
    return out
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
    r"""
    Reshape a tensor to given target shape; total number of logical elements must
    remain unchanged

    :param inp: Input tensor
    :param target_shape: target shape; may be a tensor (read through ``numpy()``)
        or an iterable of ints, and may contain a single -1 marking the
        unspecified axis.
    :return: The reshaped tensor
    :raises ValueError: if an entry is below -1 or more than one entry is -1

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F
        x = tensor(np.arange(12, dtype=np.int32))
        out = F.reshape(x, (3, 2, 2))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[ 0  1]
          [ 2  3]]

         [[ 4  5]
          [ 6  7]]

         [[ 8  9]
          [10 11]]]

    """
    if isinstance(target_shape, (TensorBase, TensorWrapperBase)):
        target_shape = target_shape.numpy()
    dims = tuple(int(d) for d in target_shape)

    # Locate the (at most one) -1 entry and reject any other negative size.
    wildcard = None
    for pos, dim in enumerate(dims):
        if dim >= 0:
            continue
        if dim != -1:
            raise ValueError("expect shape[{}] >= -1, got {}".format(pos, dim))
        if wildcard is not None:
            raise ValueError("multiple -1 in shape: {} & {}".format(wildcard, pos))
        wildcard = pos

    # TODO: device should be None (cpu)
    (shape_var,) = Const(dims, dtype="int32", device=inp.device)(inp)
    if wildcard is None:
        reshape_op = builtin.Reshape()
    else:
        reshape_op = builtin.Reshape(unspec_axis=wildcard)
    (out,) = apply(reshape_op, inp, shape_var)
    return out
| # ``transpose`` is an alias of ``dimshuffle``. | |||
| transpose = dimshuffle | |||
| # Short module-level names for the axis add/remove operator and its axis | |||
| # descriptor, used by ``add_axis`` and ``remove_axis``. | |||
| AxisAddRemove = builtin.AxisAddRemove | |||
| AxisDesc = AxisAddRemove.AxisDesc | |||
def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
    r"""
    Add dimension before given axis.

    :param inp: Input tensor
    :param axis: Place of new axes; a single int or a sequence of ints.
        Negative values are counted from the end of the *output* shape.
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F
        x = tensor([1, 2])
        out = F.add_axis(x, 0)
        print(out.shape)

    Outputs:

    .. testoutput::

        (1, 2)

    """
    Param = AxisAddRemove.Param

    # Accept either one int-like value or a sequence of them.
    try:
        axes = [int(axis)]
    except (TypeError, ValueError):
        axes = [int(a) for a in axis]

    out_ndim = inp.ndim + len(axes)
    axes = sorted(a + out_ndim if a < 0 else a for a in axes)
    param = Param(*(AxisDesc.make_add(a) for a in axes))
    (out,) = apply(AxisAddRemove(param=param), inp)
    return out


# ``expand_dims`` is an alias of ``add_axis``.
expand_dims = add_axis
def remove_axis(
    inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
) -> Tensor:
    r"""
    Remove dimension of shape 1.

    :param inp: Input tensor
    :param axis: Place of axis to be removed; a single int or a sequence of
        ints. Negative values are counted from the end. When None, every
        dimension whose size is 1 is removed. Default: None
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F
        x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
        out = F.remove_axis(x, 3)
        print(out.shape)

    Outputs:

    .. testoutput::

        (1, 1, 2)

    """
    Param = AxisAddRemove.Param

    def get_axes():
        # None selects every size-1 dimension; this path was already handled
        # here but could not be reached without passing ``None`` explicitly.
        if axis is None:
            return [i for i, s in enumerate(inp.shape) if s == 1]
        try:
            return [int(axis)]
        except (TypeError, ValueError):
            pass
        return list(map(int, axis))

    axis = get_axes()
    axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
    # Shift each axis left by the number of axes listed before it.
    # Presumably the operator removes them one after another -- TODO confirm.
    axis = [a - i for i, a in enumerate(axis)]
    param = Param(*map(AxisDesc.make_remove, axis))
    op = AxisAddRemove(param=param)
    (result,) = apply(op, inp)
    return result


# ``squeeze`` is an alias of ``remove_axis``.
squeeze = remove_axis
def linspace(
    start: Union[int, float, Tensor],
    stop: Union[int, float, Tensor],
    num: Union[int, Tensor],
    dtype="float32",
    device: Optional[CompNode] = None,
) -> Tensor:
    r"""
    Return equally spaced numbers over a specified interval

    :param start: Starting value of the sequence, should be scalar
    :param stop: The last value of the sequence, should be scalar
    :param num: number of values to generate
    :param dtype: result data type; the result is only cast when this is int32
    :param device: device for the inputs and the operator. Default: None
    :return: The generated tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F

        a = F.linspace(3,10,5)
        print(a.numpy())

    .. testoutput::

        [ 3.    4.75  6.5   8.25 10.  ]

    """
    start, stop, num = (Tensor(v, device=device) for v in (start, stop, num))

    comp_node = None if device is None else device.to_c()
    (values,) = apply(builtin.Linspace(comp_node=comp_node), start, stop, num)

    wants_int32 = np.dtype(dtype) == np.int32
    return values.astype(dtype) if wants_int32 else values
def arange(
    start: Union[int, float, Tensor],
    end: Union[int, float, Tensor],
    step: Union[int, float, Tensor] = 1,
    dtype="float32",
    device: Optional[CompNode] = None,
) -> Tensor:
    r"""
    Returns a Tensor with values from `start` to `end` with adjacent interval `step`

    The number of values is ``ceil((end - start) / step)``; the sequence itself
    is generated by :func:`linspace`.

    :param start: starting value of the sequence, should be scalar
    :param end: ending value of the sequence, should be scalar
    :param step: the gap between each pair of adjacent values. Default 1
    :param dtype: result data type; the result is only cast when this is int32
    :param device: device forwarded to the created tensors. Default: None
    :return: The generated tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F

        a = F.arange(1, 5, 1)
        print(a.numpy())

    .. testoutput::

        [1. 2. 3. 4.]

    """

    def as_float(v):
        # Tensor arguments are cast to float32; plain numbers pass through.
        return v.astype("float32") if isinstance(v, Tensor) else v

    start, end, step = as_float(start), as_float(end), as_float(step)

    span = (end - start) / step
    num = ceil(Tensor(span, device=device))
    # Last generated value, so that linspace reproduces the requested step.
    stop = start + step * (num - 1)
    values = linspace(start, stop, num, device=device)

    wants_int32 = np.dtype(dtype) == np.int32
    return values.astype(dtype) if wants_int32 else values
def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
    """
    Apply the builtin ``ParamPackSplit`` operator to ``inp``.

    :param inp: the input tensor
    :param offsets: assigned to the operator's ``offsets`` attribute
    :param shapes: assigned to the operator's ``shapes`` attribute
    :return: whatever ``apply`` returns for the operator (all of its outputs)
    """
    split_op = builtin.ParamPackSplit()
    split_op.offsets = offsets
    split_op.shapes = shapes
    outputs = apply(split_op, inp)
    return outputs
def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
    """
    Apply the builtin ``ParamPackConcat`` operator to ``inps``.

    :param inps: the input tensors, passed to the operator in order
    :param offsets: passed to the operator as its last input
    :param offsets_val: assigned to the operator's ``offsets`` attribute
    :return: the first output of the operator
    """
    concat_op = builtin.ParamPackConcat()
    concat_op.offsets = offsets_val
    outputs = apply(concat_op, *inps, offsets)
    return outputs[0]
| @@ -0,0 +1,37 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import functools | |||
def get_ndtuple(value, *, n, allow_zero=True):
    r"""Converts possibly 1D tuple to nd tuple

    A non-iterable ``value`` is converted with ``int`` and repeated ``n``
    times; an iterable must already contain exactly ``n`` items, each of which
    is converted with ``int``.

    :param value: a single number or an iterable of ``n`` numbers
    :type n: int
    :param n: required length of the result
    :type allow_zero: bool
    :param allow_zero: whether to allow zero tuple value
    :return: tuple of ``n`` ints
    :raises AssertionError: if an iterable has the wrong length, or any entry
        is below the minimum (0 when ``allow_zero`` else 1)
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.  Imported locally so this block is self-contained.
    from collections.abc import Iterable

    if isinstance(value, Iterable):
        assert len(value) == n, "tuple len is not equal to n: {}".format(value)
        value = tuple(map(int, value))
    else:
        value = (int(value),) * n

    minv = 0 if allow_zero else 1
    assert min(value) >= minv, "invalid value: {}".format(value)
    return value
| # Fixed-length variants of ``get_ndtuple``: the name gives the tuple length, | |||
| # and only ``_pair_nonzero`` rejects zero entries. | |||
| _single = functools.partial(get_ndtuple, n=1, allow_zero=True) | |||
| _pair = functools.partial(get_ndtuple, n=2, allow_zero=True) | |||
| _pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False) | |||
| _triple = functools.partial(get_ndtuple, n=3, allow_zero=True) | |||
| _quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True) | |||
| @@ -0,0 +1,80 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from typing import Iterable, Union | |||
| import numpy as np | |||
| from ..core.ops.builtin import Copy | |||
| from ..core.tensor import Tensor | |||
| from ..core.tensor.core import apply | |||
| from .math import topk as _topk | |||
| from .tensor import dimshuffle as _dimshuffle | |||
def accuracy(
    logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
) -> Union[Tensor, Iterable[Tensor]]:
    r"""
    Calculate the classification accuracy given predicted logits and ground-truth labels.

    :param logits: Model predictions of shape [batch_size, num_classes],
        representing the probability (likelihood) of each class.
    :param target: Ground-truth labels, 1d tensor of int32
    :param topk: Specifies the topk values, could be an int or tuple of ints. Default: 1
    :return: Tensor(s) of classification accuracy between 0.0 and 1.0; a single
        tensor when exactly one ``k`` is requested, otherwise a list with one
        entry per ``k``.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10))
        target = tensor(np.arange(8, dtype=np.int32))
        top1, top5 = F.accuracy(logits, target, (1, 5))
        print(top1.numpy(), top5.numpy())

    Outputs:

    .. testoutput::

        [0.] [0.375]
    """
    # A bare int is treated as a one-element request.
    if isinstance(topk, int):
        topk = (topk,)
    # One top-k query for the largest k; smaller k values slice its result.
    _, pred = _topk(logits, k=max(topk), descending=True)
    accs = []
    for k in topk:
        top_hits = pred[:, :k].detach()
        expanded_target = _dimshuffle(target, (0, "x")).broadcast(target.shape[0], k)
        correct = top_hits == expanded_target
        hit_count = correct.astype(np.float32).sum()
        accs.append(hit_count / target.shape[0])
    if len(topk) == 1:  # type: ignore[arg-type]
        return accs[0]
    return accs
def zero_grad(inp: Tensor) -> Tensor:
    r"""
    Obsolete entry point: prints a notice pointing to ``detach`` and then
    always raises :class:`NotImplementedError`.

    It was documented as returning a tensor which is treated as constant during
    backward gradient calculation, i.e. its gradient is zero.

    :param inp: Input tensor (unused).
    :raises NotImplementedError: unconditionally.
    """
    notice = "zero_grad is obsoleted, please use detach instead"
    print(notice)
    raise NotImplementedError
def copy(inp, cn):
    r"""
    Runs the builtin ``Copy`` operator, configured with ``comp_node=cn``, on ``inp``.

    :param inp: input passed to the operator.
    :param cn: value forwarded as the operator's ``comp_node`` argument.
    :return: the first element of the result of ``apply``.
    """
    copy_op = Copy(comp_node=cn)
    outputs = apply(copy_op, inp)
    return outputs[0]
| @@ -0,0 +1,16 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .hub import ( | |||
| help, | |||
| import_module, | |||
| list, | |||
| load, | |||
| load_serialized_obj_from_url, | |||
| pretrained, | |||
| ) | |||
| @@ -0,0 +1,17 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# --- repository defaults ---------------------------------------------------
DEFAULT_GIT_HOST = "github.com"
DEFAULT_BRANCH_NAME = "master"
DEFAULT_PROTOCOL = "HTTPS"

# --- names looked up inside a fetched repository ---------------------------
HUBCONF = "hubconf.py"
HUBDEPENDENCY = "dependencies"

# --- cache location: environment variable names and the fallback directory -
ENV_MGE_HOME = "MGE_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"

# --- networking -------------------------------------------------------------
HTTP_READ_TIMEOUT = 120
| @@ -0,0 +1,30 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
class FetcherError(Exception):
    """Common base of every error raised while fetching a repository."""


class InvalidRepo(FetcherError):
    """Raised when the given repo description is not valid."""


class InvalidGitHost(FetcherError):
    """Raised when the given git host is not valid."""


class GitPullError(FetcherError):
    """Raised when pulling (cloning) a repository fails."""


class GitCheckoutError(FetcherError):
    """Raised when checking out the requested commit fails."""


class InvalidProtocol(FetcherError):
    """Raised when the requested fetch protocol is not valid."""
| @@ -0,0 +1,300 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import hashlib | |||
| import os | |||
| import re | |||
| import shutil | |||
| import subprocess | |||
| from tempfile import NamedTemporaryFile | |||
| from typing import Tuple | |||
| from zipfile import ZipFile | |||
| import requests | |||
| from tqdm import tqdm | |||
| from megengine.utils.http_download import ( | |||
| CHUNK_SIZE, | |||
| HTTP_CONNECTION_TIMEOUT, | |||
| HTTPDownloadError, | |||
| ) | |||
| from ..distributed import is_distributed, synchronized | |||
| from ..logger import get_logger | |||
| from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT | |||
| from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo | |||
| from .tools import cd | |||
# Logger for this module, obtained through the package's ``get_logger``.
logger = get_logger(__name__)

# Pair of (connection timeout, read timeout) handed to ``requests.get`` below.
HTTP_TIMEOUT = (
    HTTP_CONNECTION_TIMEOUT,
    HTTP_READ_TIMEOUT,
)
# Source of the regex used to validate a git host given as a domain name.
_DOMAIN_REGEX = (
    r"^(?:[a-z0-9]"  # a label starts with a lowercase letter or a digit
    r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)"  # optional label body, then the dot
    r"+[a-z0-9][a-z0-9-_]{0,61}"  # one or more labels, then the final label
    r"[a-z]$"  # the final label ends with a lowercase letter
)
pattern = re.compile(_DOMAIN_REGEX)
class RepoFetcherBase:
    """Shared helpers for repository fetchers; subclasses implement :meth:`fetch`."""

    @classmethod
    def fetch(
        cls,
        git_host: str,
        repo_info: str,
        use_cache: bool = False,
        commit: str = None,
        silent: bool = True,
    ) -> str:
        """Fetches a repository; must be overridden.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @classmethod
    def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]:
        """Splits ``"owner/name[:branch]"`` into ``(owner, name, branch)``.

        The branch defaults to ``DEFAULT_BRANCH_NAME`` when no ``:`` is present.

        :raises InvalidRepo: if the string does not split into exactly one
            owner and one name (and at most one branch part).
        """
        try:
            branch_info = DEFAULT_BRANCH_NAME
            if ":" in repo_info:
                prefix_info, branch_info = repo_info.split(":")
            else:
                prefix_info = repo_info
            repo_owner, repo_name = prefix_info.split("/")
            return repo_owner, repo_name, branch_info
        except ValueError as err:
            # Tuple unpacking fails with ValueError on a wrong number of parts.
            raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info)) from err

    @classmethod
    def _check_git_host(cls, git_host):
        """Returns a truthy value if ``git_host`` is a valid domain or dotted IPv4 address."""
        return cls._is_valid_domain(git_host) or cls._is_valid_host(git_host)

    @classmethod
    def _is_valid_domain(cls, s):
        """Matches the IDNA-encoded form of ``s`` against the module-level ``pattern``.

        :return: the match object (or ``None``) from ``pattern.match``; ``False``
            if encoding raises ``UnicodeError``.
        """
        try:
            return pattern.match(s.encode("idna").decode("ascii"))
        except UnicodeError:
            return False

    @classmethod
    def _is_valid_host(cls, s):
        """Returns True iff ``s`` is four dot-separated ASCII decimal numbers in 0..255."""
        nums = s.split(".")
        if len(nums) != 4:
            return False
        # ``str.isdigit`` also accepts characters such as "²" that ``int``
        # rejects with ValueError, so restrict the check to ASCII digits.
        ascii_digits = "0123456789"
        if any(not part or any(ch not in ascii_digits for ch in part) for part in nums):
            return False
        return all(0 <= int(part) < 256 for part in nums)

    @classmethod
    def _gen_repo_dir(cls, repo_dir: str) -> str:
        """Returns the first 16 hex characters of the SHA-1 digest of ``repo_dir``."""
        return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]
class GitSSHFetcher(RepoFetcherBase):
    """Fetches a repository by running ``git clone`` against an SSH remote."""

    @classmethod
    @synchronized
    def fetch(
        cls,
        git_host: str,
        repo_info: str,
        use_cache: bool = False,
        commit: str = None,
        silent: bool = True,
    ) -> str:
        """
        Fetches git repo by SSH protocol

        :param git_host:
            host address of git repo.
            example: github.com
        :param repo_info:
            a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified.
            example: ``"brain_sdk/MegBrain[:hub]"``
        :param use_cache:
            whether to use locally fetched code or completely re-fetch
        :param commit:
            commit id on github or gitlab
        :param silent:
            whether to accept the stdout and stderr of the subprocess with PIPE, instead of
            displaying on the screen
        :return:
            directory where the repo code is stored
        :raises InvalidGitHost: if ``git_host`` fails validation.
        :raises GitPullError: if ``git clone`` exits with a non-zero status.
        :raises GitCheckoutError: if ``git checkout`` of ``commit`` exits with a non-zero status.
        """
        if not cls._check_git_host(git_host):
            raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))

        repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
        # The normalised form is only used to build the cache directory name;
        # git itself must be given the real branch/tag name.
        normalized_branch_info = branch_info.replace("/", "_")
        repo_dir_raw = "{}_{}_{}".format(
            repo_owner, repo_name, normalized_branch_info
        ) + ("_{}".format(commit) if commit else "")
        repo_dir = cls._gen_repo_dir(repo_dir_raw)
        git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name)

        if use_cache and os.path.exists(repo_dir):  # use cache
            logger.debug("Cache Found in %s", repo_dir)
            return repo_dir

        if is_distributed():
            logger.warning(
                "When using `hub.load` or `hub.list` to fetch git repositories\n"
                " in DISTRIBUTED mode for the first time, processes are synchronized to\n"
                " ensure that target repository is ready to use for each process.\n"
                " Users are expected to see this warning no more than ONCE, otherwise\n"
                " (very little chance) you may need to remove corrupt cache\n"
                " `%s` and fetch again.",
                repo_dir,
            )

        shutil.rmtree(repo_dir, ignore_errors=True)  # ignore and clear cache

        logger.debug(
            "Git Clone from Repo:%s Branch: %s to %s", git_url, branch_info, repo_dir,
        )

        # When silent, capture the child's output instead of showing it.
        kwargs = (
            {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {}
        )
        if commit is None:
            # shallow clone repo by branch/tag
            p = subprocess.Popen(
                ["git", "clone", "-b", branch_info, git_url, repo_dir, "--depth=1"],
                **kwargs,
            )
            cls._check_clone_pipe(p)
        else:
            # clone repo and checkout to commit_id
            p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs)
            cls._check_clone_pipe(p)

            with cd(repo_dir):
                logger.debug("git checkout to %s", commit)
                p = subprocess.Popen(["git", "checkout", commit], **kwargs)
                _, err = p.communicate()
            # Clean up only after leaving ``repo_dir``: it is a relative path,
            # so removing it from inside itself would target a nested,
            # non-existent directory and silently do nothing.
            if p.returncode:
                shutil.rmtree(repo_dir, ignore_errors=True)
                raise GitCheckoutError(
                    "Git checkout error, please check the commit id.\n"
                    + cls._stderr_text(err)
                )

        with cd(repo_dir):
            shutil.rmtree(".git")

        return repo_dir

    @classmethod
    def _stderr_text(cls, err):
        """Decodes captured stderr bytes; returns "" when nothing was captured.

        ``communicate`` yields ``None`` for stderr when it was not piped
        (``silent=False``).
        """
        return err.decode() if err else ""

    @classmethod
    def _check_clone_pipe(cls, p):
        """Waits for the clone process ``p``.

        :raises GitPullError: if the process exits with a non-zero status.
        """
        _, err = p.communicate()
        if p.returncode:
            raise GitPullError(
                "Repo pull error, please check repo info.\n" + cls._stderr_text(err)
            )
class GitHTTPSFetcher(RepoFetcherBase):
    """Fetches a repository by downloading and extracting its zip archive over HTTPS."""

    @classmethod
    @synchronized
    def fetch(
        cls,
        git_host: str,
        repo_info: str,
        use_cache: bool = False,
        commit: str = None,
        silent: bool = True,
    ) -> str:
        """
        Fetches git repo by HTTPS protocol

        :param git_host:
            host address of git repo
            example: github.com
        :param repo_info:
            a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified.
            example: ``"brain_sdk/MegBrain[:hub]"``
        :param use_cache:
            whether to use locally cached code or completely re-fetch
        :param commit:
            commit id on github or gitlab
        :param silent:
            whether to accept the stdout and stderr of the subprocess with PIPE, instead of
            displaying on the screen (not used by this fetcher)
        :return:
            directory where the repo code is stored
        :raises InvalidGitHost: if ``git_host`` fails validation.
        :raises HTTPDownloadError: if the archive request does not return status 200.
        """
        if not cls._check_git_host(git_host):
            raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))

        repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
        # Slashes are replaced only for the cache directory name.
        normalized_branch_info = branch_info.replace("/", "_")
        repo_dir_raw = "{}_{}_{}".format(
            repo_owner, repo_name, normalized_branch_info
        ) + ("_{}".format(commit) if commit else "")
        repo_dir = cls._gen_repo_dir(repo_dir_raw)
        archive_url = cls._git_archive_link(
            git_host, repo_owner, repo_name, branch_info, commit
        )

        if use_cache and os.path.exists(repo_dir):  # use cache
            logger.debug("Cache Found in %s", repo_dir)
            return repo_dir

        if is_distributed():
            # The message has a ``%s`` placeholder, so ``repo_dir`` must be
            # passed for the cache path to be shown.
            logger.warning(
                "When using `hub.load` or `hub.list` to fetch git repositories "
                "in DISTRIBUTED mode for the first time, processes are synchronized to "
                "ensure that target repository is ready to use for each process.\n"
                "Users are expected to see this warning no more than ONCE, otherwise "
                "(very little chance) you may need to remove corrupt hub cache %s and fetch again.",
                repo_dir,
            )

        shutil.rmtree(repo_dir, ignore_errors=True)  # ignore and clear cache

        logger.debug("Downloading from %s to %s", archive_url, repo_dir)
        cls._download_zip_and_extract(archive_url, repo_dir)

        return repo_dir

    @classmethod
    def _download_zip_and_extract(cls, url, target_dir):
        """Streams the zip at ``url`` to a temporary file, extracts it into the
        current directory and moves its top-level folder to ``target_dir``.

        :raises HTTPDownloadError: if the response status is not 200.
        """
        resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True)
        if resp.status_code != 200:
            raise HTTPDownloadError(
                "An error occured when downloading from {}".format(url)
            )

        total_size = int(resp.headers.get("Content-Length", 0))
        _bar = tqdm(total=total_size, unit="iB", unit_scale=True)

        with NamedTemporaryFile("w+b") as f:
            try:
                for chunk in resp.iter_content(CHUNK_SIZE):
                    if not chunk:
                        break
                    _bar.update(len(chunk))
                    f.write(chunk)
            finally:
                # Close the progress bar even when the download raises.
                _bar.close()
            f.seek(0)
            with ZipFile(f) as temp_zip_f:
                # The first archive entry names the archive's top-level folder.
                zip_dir_name = temp_zip_f.namelist()[0].split("/")[0]
                temp_zip_f.extractall(".")
                shutil.move(zip_dir_name, target_dir)

    @classmethod
    def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit):
        """Builds the archive URL; ``commit`` takes precedence over ``branch_info``."""
        archive_link = "https://{}/{}/{}/archive/{}.zip".format(
            git_host, repo_owner, repo_name, commit or branch_info
        )
        return archive_link
| @@ -0,0 +1,333 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import hashlib | |||
| import os | |||
| import sys | |||
| import types | |||
| from typing import Any, List | |||
| from urllib.parse import urlparse | |||
| from megengine.utils.http_download import download_from_url | |||
| from ..distributed import is_distributed | |||
| from ..logger import get_logger | |||
| from ..serialization import load as _mge_load_serialized | |||
| from .const import ( | |||
| DEFAULT_CACHE_DIR, | |||
| DEFAULT_GIT_HOST, | |||
| DEFAULT_PROTOCOL, | |||
| ENV_MGE_HOME, | |||
| ENV_XDG_CACHE_HOME, | |||
| HTTP_READ_TIMEOUT, | |||
| HUBCONF, | |||
| HUBDEPENDENCY, | |||
| ) | |||
| from .exceptions import InvalidProtocol | |||
| from .fetcher import GitHTTPSFetcher, GitSSHFetcher | |||
| from .tools import cd, check_module_exists, load_module | |||
# Logger for this module, obtained through the package's ``get_logger``.
logger = get_logger(__name__)

# Maps each supported protocol name to the fetcher class that implements it.
PROTOCOLS = dict(HTTPS=GitHTTPSFetcher, SSH=GitSSHFetcher)
def _get_megengine_home() -> str:
    """Resolves the MegEngine home directory.

    MGE_HOME setting complies with the XDG Base Directory Specification: the
    ``ENV_MGE_HOME`` variable wins if set; otherwise ``megengine`` under
    ``ENV_XDG_CACHE_HOME``, which itself falls back to ``DEFAULT_CACHE_DIR``.

    :return: the chosen path with ``~`` expanded.
    """
    xdg_cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(xdg_cache_root, "megengine")
    configured_home = os.getenv(ENV_MGE_HOME, fallback_home)
    return os.path.expanduser(configured_home)
def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """Fetches a repository into the hub cache using the chosen protocol.

    :param git_host: host address of git repo
    :param repo_info: repo description, forwarded unchanged to the fetcher
    :param use_cache: forwarded to the fetcher
    :param commit: forwarded to the fetcher
    :param protocol: key into ``PROTOCOLS`` selecting the fetcher class
    :return: the cache directory joined with the directory name the fetcher returned
    :raises InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``
    """
    if protocol not in PROTOCOLS:
        supported = ", ".join(PROTOCOLS.keys())
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(supported)
        )
    hub_root = os.path.join(_get_megengine_home(), "hub")
    cache_dir = os.path.expanduser(hub_root)
    # The fetcher is run with the cache directory as working directory.
    with cd(cache_dir):
        fetcher_cls = PROTOCOLS[protocol]
        repo_dir = fetcher_cls.fetch(git_host, repo_info, use_cache, commit)
        return os.path.join(cache_dir, repo_dir)
def _check_dependencies(module: types.ModuleType) -> None:
    """Verifies that every module named in the hub module's dependency list exists.

    The list is read from the attribute named by ``HUBDEPENDENCY``; a missing
    or empty attribute means there is nothing to check.

    :param module: the loaded hubconf module
    :raises RuntimeError: listing the dependencies for which
        ``check_module_exists`` returned false
    """
    if not hasattr(module, HUBDEPENDENCY):
        return
    declared = getattr(module, HUBDEPENDENCY)
    if not declared:
        return
    missing_deps = []
    for dep_name in declared:
        if not check_module_exists(dep_name):
            missing_deps.append(dep_name)
    if missing_deps:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    """Imports hubmodule like python import

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        hubconf.py as a python module
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # The repo directory is put on sys.path only while hubconf is executed,
    # so that hubconf can import modules that live next to it.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        # Restore sys.path even when loading hubconf raises; tolerate the
        # entry having been removed already while hubconf ran.
        if absolute_repo_dir in sys.path:
            sys.path.remove(absolute_repo_dir)

    return hubmodule
@functools.wraps(_init_hub)
def import_module(*args, **kwargs):
    # Public alias of ``_init_hub``: ``functools.wraps`` copies its metadata
    # (including the docstring), and every argument is forwarded unchanged.
    hubmodule = _init_hub(*args, **kwargs)
    return hubmodule
def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    """Lists all entrypoints available in repo hubconf

    An entrypoint is any attribute of the loaded hubconf module that is
    callable and whose name does not start with ``__``.

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        all entrypoint names of the model
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    entrypoints = []
    for attr_name in dir(hubmodule):
        if attr_name.startswith("__"):
            continue
        if callable(getattr(hubmodule, attr_name)):
            entrypoints.append(attr_name)
    return entrypoints
def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    """Loads model from github or gitlab repo, with pretrained weights.

    Extra positional and keyword arguments are forwarded to the entrypoint.

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param entry:
        an entrypoint defined in hubconf
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        a single model with corresponding pretrained weights.
    :raises RuntimeError: if ``entry`` is not a callable attribute of hubconf,
        or if a declared dependency is missing.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    entry_is_callable = hasattr(hubmodule, entry) and callable(
        getattr(hubmodule, entry)
    )
    if not entry_is_callable:
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))

    # Dependencies are checked only once the entrypoint is known to exist.
    _check_dependencies(hubmodule)

    entry_fn = getattr(hubmodule, entry)
    return entry_fn(*args, **kwargs)
def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """This function returns docstring of entrypoint ``entry`` by following steps:

    1. Pull the repo code specified by git and repo_info
    2. Load the entry defined in repo's hubconf.py
    3. Return docstring of function entry

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param entry:
        an entrypoint defined in hubconf.py
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        docstring of entrypoint ``entry``
    :raises RuntimeError: if ``entry`` is not a callable attribute of hubconf.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    entry_is_callable = hasattr(hubmodule, entry) and callable(
        getattr(hubmodule, entry)
    )
    if not entry_is_callable:
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))

    entry_fn = getattr(hubmodule, entry)
    return entry_fn.__doc__
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Loads MegEngine serialized object from the given URL.

    If the object is already present in ``model_dir``, it's deserialized and
    returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.

    The cached file is named after the last path component of the URL,
    prefixed with the first six hex characters of the URL's SHA-256 digest.

    :param url: url to serialized object
    :param model_dir: dir to cache target serialized file

    :return: loaded object
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    os.makedirs(model_dir, exist_ok=True)

    url_path = urlparse(url).path
    base_name = os.path.basename(url_path)

    # use hash as prefix to avoid filename conflict from different urls
    url_digest = hashlib.sha256(url.encode()).hexdigest()[:6]
    cached_file = os.path.join(model_dir, url_digest + "_" + base_name)
    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )

    already_cached = os.path.exists(cached_file)
    if not already_cached:
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                " File may be downloaded multiple times. We recommend\n"
                " users to download in single process first."
            )
        download_from_url(url, cached_file, HTTP_READ_TIMEOUT)

    return _mge_load_serialized(cached_file)
class pretrained:
    r"""
    Decorator which helps to download pretrained weights from the given url.

    For example, we can decorate a resnet18 function as follows

    .. code-block::

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

    When decorated function is called with ``pretrained=True``, MegEngine will automatically
    download and fill the returned model with pretrained weights.

    The wrapper forwards keyword arguments only; its single positional
    parameter is the ``pretrained`` flag.
    """

    def __init__(self, url):
        # Location of the serialized weights, fetched lazily on demand.
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def pretrained_model_func(
            pretrained=False, **kwargs
        ):  # pylint: disable=redefined-outer-name
            model = func(**kwargs)
            if not pretrained:
                return model
            weights = load_serialized_obj_from_url(self.url)
            model.load_state_dict(weights)
            return model

        return pretrained_model_func
# Public names of this module.
__all__ = [
    "list", "load", "help",
    "load_serialized_obj_from_url", "pretrained", "import_module",
]
| @@ -0,0 +1,48 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import importlib.util | |||
| import os | |||
| import types | |||
| from contextlib import contextmanager | |||
| from typing import Iterator | |||
def load_module(name: str, path: str) -> types.ModuleType:
    """
    Loads module specified by name and path

    The module is built from a file-location spec and executed; it is not
    registered in ``sys.modules`` by this function.

    :param name: module name
    :param path: module path
    :return: the executed module object
    """
    module_spec = importlib.util.spec_from_file_location(name, path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
def check_module_exists(module: str) -> bool:
    """Checks whether python module exists or not

    :param module: name of module
    :return: ``True`` if an import spec can be found for ``module``, else ``False``
    """
    try:
        return importlib.util.find_spec(module) is not None
    except (ImportError, ValueError):
        # find_spec imports the parent package of a dotted name and raises
        # ModuleNotFoundError when that parent is missing; it raises
        # ValueError for names it cannot resolve to a spec (e.g. an empty
        # name). Either way the module is not importable.
        return False
@contextmanager
def cd(target: str) -> Iterator[None]:
    """Changes current directory to target

    The previous working directory is restored when the ``with`` block exits,
    whether normally or through an exception.

    :param target: target directory; a leading ``~`` is expanded
    """
    origin = os.getcwd()
    destination = os.path.expanduser(target)
    os.chdir(destination)
    try:
        yield
    finally:
        os.chdir(origin)
| @@ -0,0 +1,237 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import contextlib | |||
| import logging | |||
| import os | |||
| import sys | |||
# Every logger configured by ``get_logger`` is recorded here.
_all_loggers = []

# Default level: taken from MEGENGINE_LOGGING_LEVEL (case-insensitive),
# falling back to "ERROR" when the variable is unset.
_default_level_name = os.getenv("MEGENGINE_LOGGING_LEVEL", "ERROR")
_default_level = logging.getLevelName(_default_level_name.upper())
def set_log_file(fout, mode="a"):
    r"""Sets log output file.

    The resulting file object is stored on the ``MegEngineLogFormatter``
    class (``log_fout``).

    :type fout: str or file-like
    :param fout: file-like object that supports write and flush, or string for
        the filename
    :type mode: str
    :param mode: specify the mode to open log file if *fout* is a string
    """
    sink = open(fout, mode) if isinstance(fout, str) else fout
    MegEngineLogFormatter.log_fout = sink
| class MegEngineLogFormatter(logging.Formatter): | |||
| log_fout = None | |||
| date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] " | |||
| date = "%(asctime)s " | |||
| msg = "%(message)s" | |||
| max_lines = 256 | |||
| def _color_exc(self, msg): | |||
| r"""Sets the color of message as the execution type. | |||
| """ | |||
| return "\x1b[34m{}\x1b[0m".format(msg) | |||
| def _color_dbg(self, msg): | |||
| r"""Sets the color of message as the debugging type. | |||
| """ | |||
| return "\x1b[36m{}\x1b[0m".format(msg) | |||
| def _color_warn(self, msg): | |||
| r"""Sets the color of message as the warning type. | |||
| """ | |||
| return "\x1b[1;31m{}\x1b[0m".format(msg) | |||
| def _color_err(self, msg): | |||
| r"""Sets the color of message as the error type. | |||
| """ | |||
| return "\x1b[1;4;31m{}\x1b[0m".format(msg) | |||
| def _color_omitted(self, msg): | |||
| r"""Sets the color of message as the omitted type. | |||
| """ | |||
| return "\x1b[35m{}\x1b[0m".format(msg) | |||
| def _color_normal(self, msg): | |||
| r"""Sets the color of message as the normal type. | |||
| """ | |||
| return msg | |||
| def _color_date(self, msg): | |||
| r"""Sets the color of message the same as date. | |||
| """ | |||
| return "\x1b[32m{}\x1b[0m".format(msg) | |||
def format(self, record):
    r"""Formats ``record`` for the console and mirrors it to the log file.

    If ``self.log_fout`` is truthy, an uncolored rendering (using
    ``self.date_full``) is written to it first; a record with at least
    ``self.max_lines`` lines has its body wrapped in
    ``BEGIN_LONG_LOG_<n>_LINES{`` / ``}END_LONG_LOG_<n>_LINES`` markers.
    The returned console rendering is colored by level, has any traceback
    indented and colored, and has its middle lines replaced by a notice
    when it reaches ``self.max_lines`` lines.

    ``log_fout``, ``date_full`` and ``date`` are attributes defined
    outside this method.

    :param record: the :class:`logging.LogRecord` to format.
    :return: the colored, possibly truncated, console string.
    """
    # Levels not listed here (e.g. INFO, CRITICAL) get no tag and no color.
    level_styles = {
        logging.DEBUG: (self._color_dbg, "DBG"),
        logging.WARNING: (self._color_warn, "WRN"),
        logging.ERROR: (self._color_err, "ERR"),
    }
    mcl, mtxt = level_styles.get(record.levelno, (self._color_normal, ""))

    if mtxt:
        mtxt += " "

    if self.log_fout:
        # The format string must be swapped before delegating to the base
        # class, which renders with whatever format is currently installed.
        self.__set_fmt(self.date_full + mtxt + self.msg)
        formatted = super(MegEngineLogFormatter, self).format(record)
        nr_line = formatted.count("\n") + 1
        # A single-line record has no body to wrap; without the
        # ``nr_line > 1`` guard, ``max_lines <= 1`` would make the two-way
        # unpacking of ``split`` below raise ValueError.
        if nr_line >= self.max_lines and nr_line > 1:
            head, body = formatted.split("\n", 1)
            formatted = "\n".join(
                [
                    head,
                    "BEGIN_LONG_LOG_{}_LINES{{".format(nr_line - 1),
                    body,
                    "}}END_LONG_LOG_{}_LINES".format(nr_line - 1),
                ]
            )
        self.log_fout.write(formatted)
        self.log_fout.write("\n")
        self.log_fout.flush()

    self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
    formatted = super(MegEngineLogFormatter, self).format(record)

    if record.exc_text or record.exc_info:
        # handle exception format: indent and color everything from the
        # start of the traceback onwards
        b = formatted.find("Traceback ")
        if b != -1:
            s = formatted[b:]
            s = self._color_exc(" " + s.replace("\n", "\n "))
            formatted = formatted[:b] + s

    nr_line = formatted.count("\n") + 1
    if nr_line >= self.max_lines:
        lines = formatted.split("\n")
        remain = self.max_lines // 2
        removed = len(lines) - remain * 2
        if removed > 0:
            mid_msg = self._color_omitted(
                "[{} log lines omitted (would be written to output file "
                "if set_log_file() has been called;\n"
                " the threshold can be set at "
                "MegEngineLogFormatter.max_lines)]".format(removed)
            )
            # ``lines[-0:]`` is the whole list, not an empty tail, so a
            # zero ``remain`` must be special-cased to really omit lines.
            tail = lines[-remain:] if remain > 0 else []
            formatted = "\n".join(lines[:remain] + [mid_msg] + tail)

    return formatted
# Install ``fmt`` as the active format string. Python 3 stores it on the
# formatter's ``_style`` object, Python 2 directly on the formatter, hence
# the two version-specific definitions.
if sys.version_info.major >= 3:

    def __set_fmt(self, fmt):
        self._style._fmt = fmt

else:

    def __set_fmt(self, fmt):
        self._fmt = fmt
def get_logger(name=None, formatter=MegEngineLogFormatter):
    r"""Gets megengine logger with given name.

    The first call for a given name configures the logger: propagation is
    turned off, its level is set to the current default level, all existing
    handlers are removed, and a single :class:`logging.StreamHandler` using
    ``formatter`` is attached. The logger is also recorded in
    ``_all_loggers``. Later calls return the already-configured logger.

    :param name: logger name passed to :func:`logging.getLogger`.
    :param formatter: formatter class (or factory) called with
        ``datefmt="%d %H:%M:%S"`` to build the handler's formatter.
    :return: the configured :class:`logging.Logger`.
    """
    logger = logging.getLogger(name)
    if not getattr(logger, "_init_done__", None):
        # Mark first so the logger is never configured twice.
        logger._init_done__ = True
        logger.propagate = False
        logger.setLevel(_default_level)

        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter(datefmt="%d %H:%M:%S"))
        # Level 0 on the handler: filtering is left to the logger's level.
        stream_handler.setLevel(0)

        del logger.handlers[:]
        logger.addHandler(stream_handler)
        _all_loggers.append(logger)
    return logger
def set_log_level(level, update_existing=True):
    """Sets default logging level.

    :type level: int e.g. logging.INFO
    :param level: logging level given by python :mod:`logging` module
    :param update_existing: whether to also apply ``level`` to every logger
        recorded in ``_all_loggers``
    """
    global _default_level  # pylint: disable=global-statement
    _default_level = level
    if not update_existing:
        return
    for existing_logger in _all_loggers:
        existing_logger.setLevel(level)
_logger = get_logger(__name__)

try:
    # Python 2 takes the same fallback path as a missing runtime binding.
    if sys.version_info.major < 3:
        raise ImportError()

    from .core._imperative_rt.utils import Logger as _imperative_rt_logger

    class MegBrainLogFormatter(MegEngineLogFormatter):
        """Formatter for the "megbrain" logger: adds an ``[mgb]`` tag after
        the timestamp and uses a different date color."""

        date = "%(asctime)s[mgb] "

        def _color_date(self, msg):
            begin, reset = "\x1b[33m", "\x1b[0m"
            return "{}{}{}".format(begin, msg, reset)

    _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
    _imperative_rt_logger.set_log_handler(_megbrain_logger)

    # Mirror the Python default level onto the runtime logger; anything
    # other than ERROR or INFO selects Debug.
    _imperative_rt_logger.set_log_level(
        _imperative_rt_logger.LogLevel.Error
        if _default_level == logging.getLevelName("ERROR")
        else _imperative_rt_logger.LogLevel.Info
        if _default_level == logging.getLevelName("INFO")
        else _imperative_rt_logger.LogLevel.Debug
    )

    def set_mgb_log_level(level):
        r"""Sets megbrain log level

        :type level: int e.g. logging.INFO
        :param level: new log level
        :return: original log level
        """
        previous_level = _megbrain_logger.getEffectiveLevel()
        _megbrain_logger.setLevel(level)
        return previous_level


except ImportError as exc:

    def set_mgb_log_level(level):
        """Fallback used when the runtime binding could not be imported;
        always raises :class:`NotImplementedError`."""
        raise NotImplementedError("imperative_rt has not been imported")
@contextlib.contextmanager
def replace_mgb_log_level(level):
    r"""Replaces megbrain log level in a block and restore after exiting.

    The previous level is restored even if the block raises.

    :type level: int e.g. logging.INFO
    :param level: new log level
    """
    previous_level = set_mgb_log_level(level)
    try:
        yield
    finally:
        set_mgb_log_level(previous_level)
def enable_debug_log():
    r"""Sets logging level to debug for all components.

    Applies :data:`logging.DEBUG` first through :func:`set_log_level` and
    then through :func:`set_mgb_log_level`.
    """
    for apply_level in (set_log_level, set_mgb_log_level):
        apply_level(logging.DEBUG)
| @@ -0,0 +1,24 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
| from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| from .dropout import Dropout | |||
| from .elemwise import Elemwise | |||
| from .embedding import Embedding | |||
| from .identity import Identity | |||
| from .linear import Linear | |||
| from .module import Module | |||
| from .parampack import ParamPack | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| from .sequential import Sequential | |||
| @@ -0,0 +1,231 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..functional import leaky_relu, prelu, relu, sigmoid, softmax | |||
| from ..tensor_nn import Parameter | |||
| from .module import Module | |||
class Softmax(Module):
    r"""
    Applies a softmax function. Softmax is defined as:

    .. math::
            \text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}

    It is applied to an n-dimensional input Tensor and rescales it so that the
    elements of the n-dimensional output Tensor lie in the range of `[0, 1]`
    and sum to 1.

    :param axis: An axis along which softmax will be applied. By default,
        softmax will apply along the highest ranked axis.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2]).astype(np.float32))
        softmax = M.Softmax()
        output = softmax(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [0.011656 0.031685 0.086129 0.234122 0.636409]

    """

    def __init__(self, axis=None):
        super().__init__()
        # Forwarded unchanged to the functional ``softmax`` on every call.
        self.axis = axis

    def forward(self, inputs):
        axis = self.axis
        return softmax(inputs, axis)
class Sigmoid(Module):
    r"""
    Applies the element-wise function:

    .. math::

        \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
        sigmoid = M.Sigmoid()
        output = sigmoid(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [0.119203 0.268941 0.5 0.731059 0.880797]

    """

    def forward(self, inputs):
        # Thin wrapper: delegates to the functional ``sigmoid``.
        activated = sigmoid(inputs)
        return activated
class ReLU(Module):
    r"""
    Applies the element-wise function:

    .. math::

        \text{ReLU}(x) = \max(x, 0)

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
        relu = M.ReLU()
        output = relu(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [0. 0. 0. 1. 2.]

    """

    def forward(self, x):
        # Thin wrapper: delegates to the functional ``relu``.
        activated = relu(x)
        return activated
class PReLU(Module):
    r"""
    Applies the element-wise function:

    .. math::

        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::

        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments,
    `PReLU()` uses a single parameter :math:`a` across all input channels. If
    called with `PReLU(num_of_channels)`, a separate :math:`a` is used for
    each input channel.

    :param num_parameters: number of :math:`a` to learn; only two values are
        legitimate: 1, or the number of channels at input. Default: 1
    :param init: the initial value of :math:`a`. Default: 0.25

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-1.2, -3.7, 2.7]).astype(np.float32))
        prelu = M.PReLU()
        output = prelu(data)
        print(output.numpy())

    Outputs:

    .. testoutput::

        [-0.3 -0.925 2.7 ]

    """

    def __init__(self, num_parameters: int = 1, init: float = 0.25):
        super().__init__()
        self.num_parameters = num_parameters
        if num_parameters > 1:
            # Assume format is NCHW: one slope per channel, broadcast over
            # the remaining axes.
            initial = np.full((1, num_parameters, 1, 1), init, dtype=np.float32)
        else:
            initial = [init]
        self.weight = Parameter(data=initial)

    def forward(self, inputs):
        # The weight must be either the single shared slope or one slope per
        # channel of ``inputs`` (channel count taken from ``inputs.shape[1]``).
        is_shared = self.weight.shape == (1,)
        assert is_shared or self.weight.shape == (
            1,
            int(inputs.shape[1]),
            1,
            1,
        ), "invalid weight's shape"
        return prelu(inputs, self.weight)
class LeakyReLU(Module):
    r"""
    Applies the element-wise function:

    .. math::

        \text{LeakyReLU}(x) = \max(0,x) + negative\_slope \times \min(0,x)

    or

    .. math::

        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        negative\_slope \times x, & \text{ otherwise }
        \end{cases}

    :param negative_slope: multiplier applied to negative inputs.
        Default: 0.01

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-8, -12, 6, 10]).astype(np.float32))
        leakyrelu = M.LeakyReLU(0.01)
        output = leakyrelu(data)
        print(output.numpy())

    Outputs:

    .. testoutput::

        [-0.08 -0.12 6. 10. ]

    """

    def __init__(self, negative_slope: float = 0.01):
        super().__init__()
        # Forwarded unchanged to the functional ``leaky_relu`` on every call.
        self.negative_slope = negative_slope

    def forward(self, inputs):
        slope = self.negative_slope
        return leaky_relu(inputs, slope)
| @@ -0,0 +1,281 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Optional | |||
| import numpy as np | |||
| from ..distributed.group import WORLD, Group | |||
| from ..functional import batch_norm2d, sync_batch_norm | |||
| from ..tensor_nn import Buffer, Parameter | |||
| from . import init | |||
| from .module import Module | |||
class _BatchNorm(Module):
    r"""Shared implementation of the batch-normalization modules.

    Subclasses implement :meth:`_check_input_ndim` to validate the rank of
    the input. Inputs with 2 or 3 dimensions are reshaped to 4 dimensions
    for the computation and the output is reshaped back to the input shape.

    :param num_features: size of the channel axis; sizes ``weight``, ``bias``
        and the running statistics.
    :param eps: value added to the variance for numerical stability.
    :param momentum: value passed to ``batch_norm2d`` as ``momentum`` while
        training with running statistics tracked.
    :param affine: whether to create learnable ``weight`` and ``bias``;
        when ``False`` both are ``None``.
    :param track_running_stats: whether to create ``running_mean`` and
        ``running_var`` buffers; when ``False`` both are ``None``.
    :param freeze: when ``True`` (and the module is training and was built
        with ``track_running_stats=True``), normalize with the running
        statistics through plain element-wise arithmetic on detached
        scale/bias instead of calling ``batch_norm2d``.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
    ):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Remembers the construction-time choice so forward() can reject a
        # later switch from False to True.
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        if self.affine:
            self.weight = Parameter(np.ones(num_features, dtype=np.float32))
            self.bias = Parameter(np.zeros(num_features, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None

        tshape = (1, self.num_features, 1, 1)

        if self.track_running_stats:
            self.running_mean = Buffer(np.zeros(tshape, dtype=np.float32))
            self.running_var = Buffer(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        """Resets the running mean to zeros and the running variance to ones,
        if running statistics are tracked."""
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        """Resets the running statistics and, if ``affine``, sets ``weight``
        to ones and ``bias`` to zeros."""
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        """Validates the rank of ``inp``; must be provided by subclasses."""
        raise NotImplementedError

    def forward(self, inp):
        """Normalizes ``inp`` and returns a tensor of the same shape.

        :param inp: input accepted by the subclass's ``_check_input_ndim``.
        :return: normalized output, reshaped back to the shape of ``inp``
            when ``inp`` was not 4-dimensional.
        """
        self._check_input_ndim(inp)
        if self._track_running_stats_saved == False:
            assert (
                self.track_running_stats == False
            ), "track_running_stats can not be initilized to False and changed to True later"

        _ndims = len(inp.shape)
        if _ndims != 4:
            # Pad trailing unit axes so the 4D code path can be used.
            origin_shape = inp.shapeof()
            if _ndims == 2:
                n, c = inp.shapeof(0), inp.shapeof(1)
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
                new_shape = (n, c, h, 1)
            inp = inp.reshape(new_shape)

        if self.freeze and self.training and self._track_running_stats_saved:
            inv_std = (self.running_var + self.eps) ** (-0.5)
            # With affine=False, weight and bias are None; treat them as
            # 1 and 0 instead of failing on ``None.reshape``.
            if self.weight is not None:
                scale = self.weight.reshape(1, -1, 1, 1) * inv_std
            else:
                scale = inv_std
            shift = self.running_mean * scale
            if self.bias is not None:
                bias = self.bias.reshape(1, -1, 1, 1) - shift
            else:
                bias = -shift
            output = inp * scale.detach() + bias.detach()
        else:
            if self.training and self.track_running_stats:
                exponential_average_factor = self.momentum
            else:
                exponential_average_factor = 0.0  # useless

            output = batch_norm2d(
                inp,
                self.running_mean if self.track_running_stats else None,
                self.running_var if self.track_running_stats else None,
                self.weight,
                self.bias,
                training=self.training
                or ((self.running_mean is None) and (self.running_var is None)),
                momentum=exponential_average_factor,
                eps=self.eps,
            )

        # Both paths fall through to here so that 2D/3D inputs get their
        # original shape back in the frozen case as well.
        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output
class SyncBatchNorm(_BatchNorm):
    r"""
    Applies Synchronization Batch Normalization.

    Accepts 2D, 3D or 4D inputs. Takes the same arguments as the other
    batch-normalization modules plus ``group``, which is forwarded to
    ``sync_batch_norm``.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        group: Optional[Group] = None,
    ) -> None:
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, freeze
        )
        self.group = group

    def _check_input_ndim(self, inp):
        ndim = len(inp.shape)
        if ndim != 2 and ndim != 3 and ndim != 4:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        ndim = len(inp.shape)
        needs_reshape = ndim != 4
        if needs_reshape:
            # Pad trailing unit axes so the 4D operator can be used.
            origin_shape = inp.shapeof()
            if ndim == 2:
                n, c = inp.shapeof(0), inp.shapeof(1)
                padded_shape = (n, c, 1, 1)
            elif ndim == 3:
                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
                padded_shape = (n, c, h, 1)
            inp = inp.reshape(padded_shape)

        # The factor only matters while training with tracked statistics.
        update_stats = self.training and self.track_running_stats
        avg_factor = self.momentum if update_stats else 0.0

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.training or not self.track_running_stats,
            avg_factor,
            self.eps,
            group=self.group,
        )

        if needs_reshape:
            output = output.reshape(origin_shape)

        return output
class BatchNorm1d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        # Only 2D and 3D inputs are accepted.
        ndim = len(inp.shape)
        if ndim != 2 and ndim != 3:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
            )
class BatchNorm2d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer will not
    keep running estimates, and batch statistics are instead used during
    evaluation time.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it's common terminology to call this
    Spatial Batch Normalization.

    :type num_features: int
    :param num_features: usually the :math:`C` from an input of size
        :math:`(N, C, H, W)` or the highest ranked dimension of an input with
        less than 4D.
    :type eps: float
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5.
    :type momentum: float
    :param momentum: the value used for the `running_mean` and `running_var`
        computation.
        Default: 0.9
    :type affine: bool
    :param affine: a boolean value that when set to ``True``, this module has
        learnable affine parameters. Default: ``True``
    :type track_running_stats: bool
    :param track_running_stats: when set to ``True``, this module tracks the
        running mean and variance. When set to ``False``, this module does not
        track such statistics and always uses batch statistics in both training
        and eval modes. Default: ``True``.
    :type freeze: bool
    :param freeze: when set to ``True``, this module does not update the
        running mean and variance, and uses the running mean and variance
        instead of the batch mean and batch variance to normalize the input.
        The parameter takes effect only when the module is initialized with
        ``track_running_stats`` as ``True`` and the module is in training mode.
        Default: ``False``.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        # With Learnable Parameters
        m = M.BatchNorm2d(4)
        inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        oup = m(inp)
        print(m.weight, m.bias)
        # Without Learnable Parameters
        m = M.BatchNorm2d(4, affine=False)
        oup = m(inp)
        print(m.weight, m.bias)

    .. testoutput::

        Tensor([1. 1. 1. 1.]) Tensor([0. 0. 0. 0.])
        None None

    """

    def _check_input_ndim(self, inp):
        # Only 4D inputs are accepted.
        if len(inp.shape) != 4:
            raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
| @@ -0,0 +1,22 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable | |||
| from ..functional import concat | |||
| from ..tensor import Tensor | |||
| from .module import Module | |||
class Concat(Module):
    r"""
    A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule`
    version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        """Concatenates ``inps`` along ``axis`` via the functional ``concat``.

        :param inps: tensors to concatenate.
        :param axis: axis passed to ``concat``. Default: 0
        """
        joined = concat(inps, axis)
        return joined
| @@ -0,0 +1,391 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import abstractmethod | |||
| from typing import Tuple, Union | |||
| import numpy as np | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..functional import conv2d, conv_transpose2d, local_conv2d, relu | |||
| from ..functional.types import _pair, _pair_nonzero | |||
| from ..tensor_nn import Parameter | |||
| from . import init | |||
| from .module import Module | |||
class _ConvNd(Module):
    """base class for convolution modules, including transposed conv

    Stores the convolution hyper-parameters, allocates ``weight`` (and
    ``bias`` when requested) with the shapes reported by the subclass, and
    initializes them through :meth:`reset_parameters`.

    Subclasses provide :meth:`_get_fanin`, :meth:`_infer_weight_shape` and
    :meth:`_infer_bias_shape`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int]],
        dilation: Union[int, Tuple[int, int]],
        groups: int,
        bias: bool = True,
    ):
        super().__init__()
        # Both channel counts must split evenly across the groups.
        for label, channels in (
            ("in_channels", in_channels),
            ("out_channels", out_channels),
        ):
            if channels % groups != 0:
                raise ValueError("{} must be divisible by groups".format(label))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Shapes come from the subclass, which reads the attributes set above.
        weight_shape = self._infer_weight_shape()
        self.weight = Parameter(np.zeros(weight_shape, dtype=np.float32))
        self.bias = None
        if bias:
            bias_shape = self._infer_bias_shape()
            self.bias = Parameter(np.zeros(bias_shape, dtype=np.float32))
        self.reset_parameters()

    @abstractmethod
    def _get_fanin(self):
        pass

    def reset_parameters(self) -> None:
        """Draws ``weight`` from a normal distribution with mean 0 and
        standard deviation ``sqrt(1 / fanin)``, and zeroes ``bias`` if any."""
        weight_std = np.sqrt(1 / self._get_fanin())
        init.normal_(self.weight, 0.0, weight_std)
        if self.bias is not None:
            init.zeros_(self.bias)

    @abstractmethod
    def _infer_weight_shape(self):
        pass

    @abstractmethod
    def _infer_bias_shape(self):
        pass
class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D cross-correlation operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.

    When ``groups == in_channels`` and ``out_channels == K * in_channels``,
    where `K` is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
    a depthwise convolution with a depthwise multiplier `K`, can be constructed
    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size``
        is an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``. Required.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        out_channel // groups, in_channels // groups, *kernel_size)``.
        Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
        `CROSS_CORRELATION`.
    :param compute_mode: When set to `DEFAULT`, no special requirements will be
        placed on the precision of intermediate results. When set to `FLOAT32`,
        float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    """

    _conv_mode_type = P.Convolution.Mode
    _compute_mode_type = P.Convolution.ComputeMode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        # Normalize scalar-or-pair arguments to pairs before anything else.
        kernel_hw = _pair_nonzero(kernel_size)
        stride_hw = _pair_nonzero(stride)
        padding_hw = _pair(padding)
        dilation_hw = _pair_nonzero(dilation)
        # Convert the mode names via the param-def enum types.
        self.conv_mode = self._conv_mode_type.convert(conv_mode)
        self.compute_mode = self._compute_mode_type.convert(compute_mode)
        super().__init__(
            in_channels,
            out_channels,
            kernel_hw,
            stride_hw,
            padding_hw,
            dilation_hw,
            groups,
            bias,
        )

    def _get_fanin(self):
        kernel_h, kernel_w = self.kernel_size
        return kernel_h * kernel_w * self.in_channels

    def _infer_weight_shape(self):
        groups = self.groups
        in_ch = self.in_channels
        out_ch = self.out_channels
        kernel_h, kernel_w = self.kernel_size
        if groups != 1:
            assert (
                in_ch % groups == 0 and out_ch % groups == 0
            ), "invalid config: input_channels={} output_channels={} group={}".format(
                in_ch, out_ch, groups
            )
            # Assume format is NCHW; grouped weights get a leading group axis.
            return (groups, out_ch // groups, in_ch // groups, kernel_h, kernel_w)
        # Assume format is NCHW
        return (out_ch, in_ch, kernel_h, kernel_w)

    def _infer_bias_shape(self):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def calc_conv(self, inp, weight, bias):
        """Runs the functional ``conv2d`` on ``inp`` with the given ``weight``
        and ``bias`` and this module's stored hyper-parameters."""
        hyper_params = (
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )
        return conv2d(inp, weight, bias, *hyper_params)

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)
class ConvTranspose2d(_ConvNd):
    r"""Applies a 2D transposed convolution over an input tensor.

    This module is also known as a deconvolution or a fractionally-strided convolution.
    :class:`ConvTranspose2d` can be seen as the gradient of :class:`Conv2d` operation
    with respect to its input.

    Convolution usually reduces the size of input, while transposed convolution works
    the opposite way, transforming a smaller input to a larger output while preserving the
    connectivity pattern.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size``
        is an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``. Required.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        in_channels // groups, out_channels // groups, *kernel_size)``
        (the shape computed by this class). Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
        `CROSS_CORRELATION`.
    :param compute_mode: When set to `DEFAULT`, no special requirements will be
        placed on the precision of intermediate results. When set to `FLOAT32`,
        float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    """

    _conv_mode_type = P.Convolution.Mode
    _compute_mode_type = P.Convolution.ComputeMode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        # Normalize scalar-or-pair arguments to pairs before anything else.
        kernel_hw = _pair_nonzero(kernel_size)
        stride_hw = _pair_nonzero(stride)
        padding_hw = _pair(padding)
        dilation_hw = _pair_nonzero(dilation)
        # Convert the mode names via the param-def enum types.
        self.conv_mode = self._conv_mode_type.convert(conv_mode)
        self.compute_mode = self._compute_mode_type.convert(compute_mode)
        super().__init__(
            in_channels,
            out_channels,
            kernel_hw,
            stride_hw,
            padding_hw,
            dilation_hw,
            groups,
            bias,
        )

    def _get_fanin(self):
        # Uses the output channel count, unlike Conv2d which uses the input
        # channel count.
        kernel_h, kernel_w = self.kernel_size
        return kernel_h * kernel_w * self.out_channels

    def _infer_weight_shape(self):
        groups = self.groups
        in_ch = self.in_channels
        out_ch = self.out_channels
        kernel_h, kernel_w = self.kernel_size
        if groups != 1:
            assert (
                in_ch % groups == 0 and out_ch % groups == 0
            ), "invalid config: input_channels={} output_channels={} group={}".format(
                in_ch, out_ch, groups
            )
            # Assume format is NCHW; grouped weights get a leading group axis.
            return (groups, in_ch // groups, out_ch // groups, kernel_h, kernel_w)
        # Assume format is NCHW; input channels come first for the
        # transposed convolution.
        return (in_ch, out_ch, kernel_h, kernel_w)

    def _infer_bias_shape(self):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def forward(self, inp):
        hyper_params = (
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )
        return conv_transpose2d(inp, self.weight, self.bias, *hyper_params)
class LocalConv2d(Conv2d):
    r"""Applies a spatial convolution with untied kernels over an input 4D tensor.
    It is also known as the locally connected layer.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param input_height: the height of the input images.
    :param input_width: the width of the input images.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
        The shape of weight is ``(groups, output_height, output_width,
        in_channels // groups, *kernel_size, out_channels // groups)``.
    :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
        `CROSS_CORRELATION`.
    """

    _conv_mode_type = P.Convolution.Mode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
        input_width: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        conv_mode: str = "CROSS_CORRELATION",
    ):
        # Stored before delegating to the base constructor because
        # ``_infer_weight_shape`` reads them.
        # NOTE(review): presumably ``Conv2d.__init__`` infers the weight shape
        # while it runs -- confirm against the base class.
        self.input_height = input_height
        self.input_width = input_width
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            False,  # bias: this layer never uses a bias term
            # Forward the caller's choice; it used to be dropped, so ``forward``
            # always ran with the base-class default mode.
            conv_mode,
        )

    def _infer_weight_shape(self):
        """Return the 7-D weight shape, with one kernel per output location."""
        group = self.groups
        # NOTE(review): dilation is not part of this output-size formula --
        # confirm that ``local_conv2d`` only supports dilation == 1.
        output_height = (
            self.input_height + self.padding[0] * 2 - self.kernel_size[0]
        ) // self.stride[0] + 1
        output_width = (
            self.input_width + self.padding[1] * 2 - self.kernel_size[1]
        ) // self.stride[1] + 1
        # Assume format is NCHW
        return (
            group,
            output_height,
            output_width,
            self.in_channels // group,
            self.kernel_size[0],
            self.kernel_size[1],
            self.out_channels // group,
        )

    def forward(self, inp):
        """Apply :func:`local_conv2d` to ``inp`` with this module's weight."""
        return local_conv2d(
            inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
        )
class ConvRelu2d(Conv2d):
    r"""
    Fused :class:`~.Module` that applies Conv2d followed by relu. Could be replaced
    with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        """Run the convolution on ``inp`` and rectify the result."""
        conv_out = self.calc_conv(inp, self.weight, self.bias)
        return relu(conv_out)
| @@ -0,0 +1,69 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Tuple, Union | |||
| from ..functional import relu | |||
| from .batchnorm import BatchNorm2d | |||
| from .conv import Conv2d | |||
| from .module import Module | |||
class _ConvBnActivation2d(Module):
    """Shared base for fused Conv2d + BatchNorm2d modules.

    Builds the two sub-modules ``conv`` and ``bn``. The first ten arguments are
    forwarded positionally to :class:`Conv2d`; ``out_channels`` and the last
    four arguments are forwarded to :class:`BatchNorm2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()
        conv_args = (
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
        )
        bn_args = (out_channels, eps, momentum, affine, track_running_stats)
        self.conv = Conv2d(*conv_args)
        self.bn = BatchNorm2d(*bn_args)
class ConvBn2d(_ConvBnActivation2d):
    r"""
    Fused :class:`~.Module` that applies Conv2d and then BatchNorm2d. Could be replaced
    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        """Convolve ``inp`` and batch-normalize the result."""
        conv_out = self.conv(inp)
        return self.bn(conv_out)
class ConvBnRelu2d(_ConvBnActivation2d):
    r"""
    Fused :class:`~.Module` that applies Conv2d, BatchNorm2d and relu in sequence.
    Could be replaced with :class:`~.QATModule` version
    :class:`~.qat.conv_bn.ConvBnRelu2d` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        """Convolve ``inp``, batch-normalize it, then rectify the result."""
        conv_out = self.conv(inp)
        normalized = self.bn(conv_out)
        return relu(normalized)
| @@ -0,0 +1,29 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..functional import dropout | |||
| from .module import Module | |||
class Dropout(Module):
    r"""Randomly sets input elements to zero with probability :math:`drop\_prob`
    during training. Commonly used in large networks to prevent overfitting.

    Dropout is only performed while training, and the output tensor is then
    rescaled (multiplied) by :math:`\frac{1}{1 - drop\_prob}`. During inference
    :class:`~.Dropout` is equal to :class:`~.Identity`.

    :param drop_prob: The probability to drop (set to zero) each single element
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, inputs):
        """Apply rescaled dropout in training mode; otherwise pass ``inputs`` through."""
        if not self.training:
            return inputs
        return dropout(inputs, self.drop_prob, rescale=True)
| @@ -0,0 +1,79 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..functional.elemwise import _elwise | |||
| from ..tensor import Tensor | |||
| from .module import Module | |||
class Elemwise(Module):
    r"""
    A :class:`~.Module` that applies an elementwise operator. Could be replaced with
    :class:`~.QATModule` version :class:`~.qat.elemwise.Elemwise` using
    :func:`~.quantize.quantize_qat`.

    :param method: the elemwise method, given as one of the strings below.
        The normal (float) elemwise operator is applied.

        * "ADD": x + y
        * "FUSE_ADD_RELU": max(x+y, 0)
        * "MUL": x * y
        * "MIN": min(x, y)
        * "MAX": max(x, y)
        * "SUB": x - y
        * "TRUE_DIV": x / y
        * "FUSE_ADD_SIGMOID": sigmoid(x + y)
        * "FUSE_ADD_TANH": tanh(x + y)
        * "RELU": x > 0 ? x : 0
        * "ABS": x > 0 ? x : -x
        * "SIGMOID": sigmoid(x)
        * "EXP": exp(x)
        * "TANH": tanh(x)
        * "FUSE_MUL_ADD3": x * y + z
        * "FAST_TANH": fast_tanh(x)
        * "NEGATE": -x
        * "ACOS": acos(x)
        * "ASIN": asin(x)
        * "CEIL": ceil(x)
        * "COS": cos(x)
        * "EXPM1": expm1(x)
        * "FLOOR": floor(x)
        * "LOG": log(x)
        * "LOG1P": log1p(x)
        * "SIN": sin(x)
        * "ROUND": round(x)
        * "ERF": erf(x)
        * "ERFINV": erfinv(x)
        * "ERFC": erfc(x)
        * "ERFCINV": erfcinv(x)
        * "ABS_GRAD": abs_grad
        * "FLOOR_DIV": floor_div
        * "MOD": mod
        * "SIGMOID_GRAD": sigmoid_grad
        * "SWITCH_GT0": switch_gt0
        * "TANH_GRAD": tanh_grad
        * "LT": lt
        * "LEQ": leq
        * "EQ": eq
        * "POW": pow
        * "LOG_SUM_EXP": log_sum_exp
        * "FAST_TANH_GRAD": fast_tanh_grad
        * "ATAN2": atan2
        * "COND_LEQ_MOV": cond_leq_mov
        * "H_SWISH": h_swish
        * "FUSE_ADD_H_SWISH": h_swish(x+y)
        * "H_SWISH_GRAD": h_swish_grad
    """

    _elemwise_mode_type = P.Elemwise.Mode

    def __init__(self, method):
        super().__init__()
        # Normalize the user-supplied method through the mode type's converter.
        mode = self._elemwise_mode_type.convert(method)
        self.method = mode

    def forward(self, *inps):
        """Apply the configured elemwise mode to all positional inputs."""
        result = _elwise(*inps, mode=self.method)
        return result
| @@ -0,0 +1,171 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Optional | |||
| import numpy as np | |||
| from ..functional import embedding as embedding_func | |||
| from ..tensor_nn import Parameter | |||
| from . import init | |||
| from .module import Module | |||
class Embedding(Module):
    r"""
    A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings. The indices should be less than num_embeddings.

    :param num_embeddings: size of embedding dictionary.
    :param embedding_dim: size of each embedding vector.
    :param padding_idx: should be set to None, not supported now.
    :param max_norm: should be set to None, not supported now.
    :param norm_type: should be set to None, not supported now.
    :param initial_weight: the learnable weights of the module of shape
        (num_embeddings, embedding_dim).

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M
        weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32))
        data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32))

        embedding = M.Embedding(2, 5, initial_weight=weight)
        output = embedding(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [[[1.2 2.3 3.4 4.5 5.6]
          [0.1 1.1 2.1 3.1 4.1]
          [0.1 1.1 2.1 3.1 4.1]]

         [[0.1 1.1 2.1 3.1 4.1]
          [1.2 2.3 3.4 4.5 5.6]
          [0.1 1.1 2.1 3.1 4.1]]

         [[1.2 2.3 3.4 4.5 5.6]
          [1.2 2.3 3.4 4.5 5.6]
          [0.1 1.1 2.1 3.1 4.1]]]

    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: Optional[float] = None,
        initial_weight: Parameter = None,
    ):
        super().__init__()
        # Reject the options that are accepted in the signature but unsupported.
        if padding_idx is not None:
            raise ValueError("Not support padding index now.")
        if not (max_norm is None and norm_type is None):
            raise ValueError("Not support weight normalize now.")
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        table_shape = (num_embeddings, embedding_dim)
        if initial_weight is not None:
            if initial_weight.shape != table_shape:
                raise ValueError(
                    "The weight shape should match num_embeddings and embedding_dim"
                )
            self.weight = Parameter(initial_weight.numpy())
            return

        # No weight supplied: allocate a float32 table, then re-initialize it
        # through ``reset_parameters``.
        placeholder = np.random.uniform(size=table_shape)
        self.weight = Parameter(placeholder.astype(np.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the weight via :func:`init.normal_` with its defaults."""
        init.normal_(self.weight)

    def forward(self, inputs):
        """Look up the embedding vectors addressed by ``inputs``."""
        return embedding_func(inputs, self.weight)

    @classmethod
    def from_pretrained(
        cls,
        embeddings: Parameter,
        freeze: Optional[bool] = True,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: Optional[float] = None,
    ):
        r"""
        Creates Embedding instance from given 2-dimensional FloatTensor.

        :param embeddings: Tensor containing the weight for the embedding.
        :param freeze: If ``True``, the weight does not get updated during the
            learning process. Default: ``True``.
        :param padding_idx: should be set to None, not supported now.
        :param max_norm: should be set to None, not supported now.
        :param norm_type: should be set to None, not supported now.

        Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.module as M
            weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32))
            data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32))

            embedding = M.Embedding.from_pretrained(weight, freeze=False)
            output = embedding(data)
            print(output.numpy())

        Outputs:

        .. testoutput::

            [[[1.2 2.3 3.4 4.5 5.6]
              [0.1 1.1 2.1 3.1 4.1]
              [0.1 1.1 2.1 3.1 4.1]]

             [[0.1 1.1 2.1 3.1 4.1]
              [1.2 2.3 3.4 4.5 5.6]
              [0.1 1.1 2.1 3.1 4.1]]

             [[1.2 2.3 3.4 4.5 5.6]
              [1.2 2.3 3.4 4.5 5.6]
              [0.1 1.1 2.1 3.1 4.1]]]

        """
        shape = embeddings.shape
        if len(shape) != 2:
            raise ValueError("Embeddings parameter is expected to be 2-dimensional")
        num_rows, num_cols = shape[0], shape[1]
        module = cls(
            num_embeddings=num_rows,
            embedding_dim=num_cols,
            initial_weight=embeddings,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
        )
        # A frozen table is excluded from gradient updates.
        module.weight.requires_grad = not freeze
        return module
| @@ -0,0 +1,56 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..functional import cambricon_subgraph, extern_opr_subgraph | |||
| from .module import Module | |||
class CambriconSubgraph(Module):
    r"""Load a serialized Cambricon subgraph.

    See :func:`~.cambricon_subgraph` for more details.
    """

    def __init__(
        self, data, symbol, tensor_dim_mutable,
    ):
        super().__init__()
        # Kept in its raw form; the ``data`` property converts to and from bytes.
        self._data = data
        self.symbol = symbol
        self.tensor_dim_mutable = tensor_dim_mutable

    @property
    def data(self):
        """The stored subgraph as ``bytes`` (via ``tobytes()``)."""
        raw = self._data
        return raw.tobytes()

    @data.setter
    def data(self, val):
        # Reinterpret the incoming buffer as a uint8 array.
        self._data = np.frombuffer(val, dtype=np.uint8)

    def forward(self, inputs):
        """Run the stored subgraph on ``inputs`` through :func:`cambricon_subgraph`."""
        return cambricon_subgraph(
            inputs, self._data, self.symbol, self.tensor_dim_mutable,
        )
class ExternOprSubgraph(Module):
    r"""Load a serialized extern opr subgraph.
    """

    def __init__(self, data, name, output_shapes):
        super().__init__()
        self.data = data
        self.name = name
        self.output_shapes = output_shapes

    def forward(self, inputs):
        """Run the stored subgraph on ``inputs`` through :func:`extern_opr_subgraph`."""
        return extern_opr_subgraph(
            inputs, self.output_shapes, self.name, self.data,
        )
| @@ -0,0 +1,17 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..functional import identity | |||
| from .module import Module | |||
class Identity(Module):
    r"""A placeholder module that passes its input through :func:`identity`."""

    def forward(self, x):
        """Return ``identity(x)``."""
        passthrough = identity(x)
        return passthrough
| @@ -0,0 +1,261 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import math | |||
| from functools import reduce | |||
| from typing import Optional, Tuple, Union | |||
| import numpy as np | |||
| from ..tensor import Tensor | |||
def fill_(tensor: Tensor, val: Union[float, int]) -> None:
    """Overwrite every element of ``tensor`` with ``val``.

    :param tensor: An n-dimensional tensor to be initialized
    :param val: The value to be filled throughout the tensor
    """
    # Build a constant array matching the tensor's shape and dtype.
    constant = np.full(tensor.shape, val, dtype=tensor.dtype)
    tensor.set_value(constant)
def zeros_(tensor: Tensor) -> None:
    """Set every element of ``tensor`` to the scalar `0`.

    :param tensor: An n-dimensional tensor to be initialized
    """
    fill_(tensor, val=0)
def ones_(tensor: Tensor) -> None:
    """Set every element of ``tensor`` to the scalar `1`.

    :param tensor: An n-dimensional tensor to be initialized
    """
    fill_(tensor, val=1)
def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
    r"""Overwrite ``tensor`` with samples drawn from the uniform distribution
    :math:`\mathcal{U}(\text{a}, \text{b})`.

    :param tensor: An n-dimensional tensor to be initialized
    :param a: Lower bound of the sampling interval
    :param b: Upper bound of the sampling interval
    """
    samples = np.random.uniform(a, b, tensor.shape)
    # Cast the draws to the tensor's own dtype before storing them.
    tensor.set_value(samples.astype(tensor.dtype))
def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
    r"""Fill the given ``tensor`` with random values sampled from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    :param tensor: An n-dimensional tensor to be initialized
    :param mean: The mean of the normal distribution
    :param std: The standard deviation of the normal distribution
    """
    samples = np.random.normal(mean, std, tensor.shape)
    # Cast to the tensor's own dtype, as ``fill_`` and ``uniform_`` do, instead
    # of forcing float32 onto tensors of another dtype.
    tensor.set_value(samples.astype(tensor.dtype))
def calculate_gain(
    nonlinearity: str, param: Optional[Union[int, float]] = None
) -> float:
    r"""Return the recommended gain value (see the table below) for the given
    nonlinearity function.

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative_{slope}}^2}}`
    ================= ====================================================

    :param nonlinearity: Name of the non-linear function
    :param param: Optional parameter for leaky_relu. Only effective when
        ``nonlinearity`` is "leaky_relu".
    :raises ValueError: if ``nonlinearity`` is not one of the supported names,
        or if ``param`` is neither ``None`` nor a non-bool int/float.
    """
    unit_gain_fns = (
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "sigmoid",
    )
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == "tanh":
        return 5.0 / 3
    if nonlinearity == "relu":
        return math.sqrt(2.0)
    if nonlinearity != "leaky_relu":
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    if param is None:
        slope = 0.01
    elif isinstance(param, float) or (
        isinstance(param, int) and not isinstance(param, bool)
    ):
        # bool is a subclass of int, so it has to be excluded explicitly
        slope = param
    else:
        raise ValueError("negative_slope {} not a valid number".format(param))
    return math.sqrt(2.0 / (1 + slope ** 2))
def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
    """
    Compute the fan_in / fan_out pair of a weight tensor. The tensor is assumed
    to be stored in NCHW format.

    For a 2-D (linear) weight, fan_in is ``shape[1]`` and fan_out is ``shape[0]``.
    For higher ranks both are additionally multiplied by the product of the
    trailing dimensions ``shape[2:]``.

    :param tensor: Weight tensor in NCHW format
    :raises ValueError: if the tensor has fewer than 2 dimensions.
    """
    dims = tensor.shape
    rank = len(dims)
    if rank < 2:
        raise ValueError(
            "fan_in and fan_out can not be computed for tensor with fewer than 2 "
            "dimensions"
        )

    if rank == 2:  # Linear
        return dims[1], dims[0]

    # Product of all dimensions after the first two.
    field_size = 1
    for extent in dims[2:]:
        field_size = field_size * extent
    return dims[1] * field_size, dims[0] * field_size
def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
    """
    Return either the fan_in or the fan_out of a weight tensor, as selected by
    ``mode`` (matched case-insensitively).

    See :func:`calculate_fan_in_and_fan_out` for details.

    :param tensor: Weight tensor in NCHW format
    :param mode: ``'fan_in'`` or ``'fan_out'``
    :raises ValueError: if ``mode`` is neither of the two supported values.
    """
    mode = mode.lower()
    valid_modes = ["fan_in", "fan_out"]
    if mode not in valid_modes:
        raise ValueError(
            "Mode {} not supported, please use one of {}".format(mode, valid_modes)
        )

    # Pair each mode name with the matching element of the (fan_in, fan_out) tuple.
    fans = dict(zip(valid_modes, calculate_fan_in_and_fan_out(tensor)))
    return fans[mode]
def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
    r"""Overwrite ``tensor`` with samples drawn from :math:`\mathcal{U}(-a, a)`
    where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}

    Also known as Glorot initialization. Detailed information can be retrieved from
    `Understanding the difficulty of training deep feedforward neural networks` -
    Glorot, X. & Bengio, Y. (2010).

    :param tensor: An n-dimensional tensor to be initialized
    :param gain: Scaling factor for :math:`a`.
    """
    fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)
    fan_sum = float(fan_in + fan_out)
    std = gain * math.sqrt(2.0 / fan_sum)
    # Scaling std by sqrt(3) gives the uniform bound of the docstring formula.
    limit = math.sqrt(3.0) * std
    uniform_(tensor, -limit, limit)
def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
    r"""Overwrite ``tensor`` with samples drawn from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}

    Also known as Glorot initialization. Detailed information can be retrieved from
    `Understanding the difficulty of training deep feedforward neural networks` -
    Glorot, X. & Bengio, Y. (2010).

    :param tensor: An n-dimensional tensor to be initialized
    :param gain: Scaling factor for :math:`std`.
    """
    fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)
    fan_sum = float(fan_in + fan_out)
    sigma = gain * math.sqrt(2.0 / fan_sum)
    normal_(tensor, 0.0, sigma)
def msra_uniform_(
    tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu"
) -> None:
    r"""Overwrite ``tensor`` with samples drawn from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}

    Detailed information can be retrieved from
    `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
    classification`

    :param tensor: An n-dimensional tensor to be initialized
    :param a: Optional parameter for calculating gain for leaky_relu. See
        :func:`calculate_gain` for details.
    :param mode: ``'fan_in'`` or ``'fan_out'``, selects which fan value scales
        :math:`bound`. See :func:`calculate_fan_in_and_fan_out` for details.
    :param nonlinearity: Name of the non-linear function used to calculate :math:`gain`.
        See :func:`calculate_gain` for details.
    """
    fan_value = calculate_correct_fan(tensor, mode)
    scale = calculate_gain(nonlinearity, a)
    sigma = scale / math.sqrt(fan_value)
    # Scaling sigma by sqrt(3) gives the uniform bound of the docstring formula.
    limit = math.sqrt(3.0) * sigma
    uniform_(tensor, -limit, limit)
def msra_normal_(
    tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu"
) -> None:
    r"""Overwrite ``tensor`` with samples drawn from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}

    Detailed information can be retrieved from
    `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
    classification`

    :param tensor: An n-dimensional tensor to be initialized
    :param a: Optional parameter for calculating gain for leaky_relu. See
        :func:`calculate_gain` for details.
    :param mode: ``'fan_in'`` or ``'fan_out'``, selects which fan value scales
        :math:`std`. See :func:`calculate_fan_in_and_fan_out` for details.
    :param nonlinearity: Name of the non-linear function used to calculate :math:`gain`.
        See :func:`calculate_gain` for details.
    """
    fan_value = calculate_correct_fan(tensor, mode)
    scale = calculate_gain(nonlinearity, a)
    sigma = scale / math.sqrt(fan_value)
    normal_(tensor, 0, sigma)
| @@ -0,0 +1,61 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..functional import linear | |||
| from ..tensor_nn import Parameter | |||
| from . import init | |||
| from .module import Module | |||
class Linear(Module):
    r"""Applies a linear transformation to the input. For instance, if input
    is x, then output y is:

    .. math::

        y = xW^T + b

    where :math:`y_i= \sum_j W_{ij} x_j + b_i`

    :param in_features: size of each input sample.
    :param out_features: size of each output sample.
    :param bias: If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``
    """

    def __init__(
        self, in_features: int, out_features: int, bias: bool = True, **kwargs
    ):
        super().__init__(**kwargs)
        self.out_features = out_features
        self.in_features = in_features
        # Weight is stored as (out_features, in_features); it is allocated as
        # zeros here and filled in by ``reset_parameters`` below.
        weight_store = np.zeros((out_features, in_features), dtype=np.float32)
        self.weight = Parameter(weight_store)
        if bias:
            bias_store = np.zeros((out_features,), dtype=np.float32)
            self.bias = Parameter(bias_store)
        else:
            self.bias = None
        self.reset_parameters()

    def _get_fanin(self):
        """Return the fan-in used to scale the weight initialization."""
        return self.in_features

    def reset_parameters(self) -> None:
        """Draw the weight from N(0, 1 / fanin) and zero the bias, if any."""
        sigma = np.sqrt(1 / self._get_fanin())
        init.normal_(self.weight, 0.0, sigma)
        if self.bias is None:
            return
        init.zeros_(self.bias)

    def _calc_linear(self, x, weight, bias):
        """Apply :func:`linear` to ``x`` with the given ``weight`` and ``bias``."""
        return linear(x, weight, bias)

    def forward(self, x):
        """Apply the linear transformation with this module's own parameters."""
        return self._calc_linear(x, self.weight, self.bias)
| @@ -0,0 +1,508 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABCMeta, abstractmethod | |||
| from collections import OrderedDict | |||
| from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
| import numpy as np | |||
| from ..core.tensor.dtype import is_quantize | |||
| from ..logger import get_logger | |||
| from ..tensor import Tensor | |||
| from ..tensor_nn import Buffer, Parameter | |||
| from ..utils.hook import HookHandler | |||
| logger = get_logger(__name__) | |||
def _expand_structure(key, obj):
    """Flatten a nested list/tuple/dict into ``(dotted_key, leaf)`` pairs.

    A leaf is a :class:`Tensor` or :class:`Module`. Dict entries are visited in
    sorted key order; list and tuple entries use their stringified index as key.
    Any other object contributes nothing.

    :param key: name prefix under which ``obj`` is reachable.
    :param obj: object to expand.
    :return: list of ``(key, leaf)`` tuples.
    :raises AssertionError: if a dict key leading to a leaf is not a ``str``.
    """
    if isinstance(obj, (Tensor, Module)):
        return [(key, obj)]
    if not isinstance(obj, (list, tuple, dict)):
        return []

    if isinstance(obj, dict):
        children = ((name, obj[name]) for name in sorted(obj))
    else:
        children = ((str(idx), item) for idx, item in enumerate(obj))

    flattened = []
    for child_key, child in children:
        expanded = _expand_structure(child_key, child)
        if expanded and not isinstance(child_key, str):
            raise AssertionError(
                "keys for Tensor and Module must be str, error key: {}".format(
                    child_key
                )
            )
        # Prefix every leaf found below this child with the current key.
        for sub_key, leaf in expanded:
            flattened.append((key + "." + sub_key, leaf))
    return flattened
| def _is_parameter(obj): | |||
| return isinstance(obj, Parameter) | |||
| def _is_buffer(obj): | |||
| return isinstance(obj, Buffer) | |||
| def _is_module(obj): | |||
| return isinstance(obj, Module) | |||
| class Module(metaclass=ABCMeta): | |||
| """Base Module class. | |||
| """ | |||
| def __init__(self): | |||
| # runtime attributes | |||
| self.training = True | |||
| self.quantize_disabled = False | |||
| # hooks | |||
| self._forward_pre_hooks = OrderedDict() | |||
| self._forward_hooks = OrderedDict() | |||
| @abstractmethod | |||
| def forward(self, inputs): | |||
| pass | |||
| def register_forward_pre_hook(self, hook: Callable) -> HookHandler: | |||
| """Register a hook to handle forward inputs. `hook` should be a function | |||
| Note that keyword inputs are not passed to `hook`; it receives only the positional `inputs`. | |||
| :param hook: a function that receives `module` and `inputs`, then returns | |||
| a modified `inputs` or `None`. | |||
| :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook. | |||
| """ | |||
| return HookHandler(self._forward_pre_hooks, hook) | |||
| def register_forward_hook(self, hook: Callable) -> HookHandler: | |||
| """Register a hook to handle forward results. `hook` should be a function that | |||
| receives `module`, `inputs` and `outputs`, then returns a modified `outputs` or `None`. | |||
| This method returns a handler with :meth:`~.HookHandler.remove` interface to delete the hook. | |||
| """ | |||
| return HookHandler(self._forward_hooks, hook) | |||
| def __call__(self, *inputs, **kwargs): | |||
| for hook in self._forward_pre_hooks.values(): | |||
| modified_inputs = hook(self, inputs) | |||
| if modified_inputs is not None: | |||
| if not isinstance(modified_inputs, tuple): | |||
| modified_inputs = (modified_inputs,) | |||
| inputs = modified_inputs | |||
| outputs = self.forward(*inputs, **kwargs) | |||
| for hook in self._forward_hooks.values(): | |||
| modified_outputs = hook(self, inputs, outputs) | |||
| if modified_outputs is not None: | |||
| outputs = modified_outputs | |||
| return outputs | |||
| def _flatten( | |||
| self, | |||
| *, | |||
| recursive: bool = True, | |||
| with_key: bool = False, | |||
| with_parent: bool = False, | |||
| prefix: Optional[str] = None, | |||
| predicate: Callable[[Any], bool] = lambda _: True, | |||
| seen: Optional[Set[int]] = None | |||
| ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: | |||
| """Scans the module object and returns an iterable for the :class:`~.Tensor` | |||
| and :class:`~.Module` attributes that agree with the ``predicate``. For multiple | |||
| calls of this function with same arguments, the order of objects within the | |||
| returned iterable is guaranteed to be identical, as long as all the involved | |||
| module objects' ``__dict__`` does not change throughout those calls. | |||
| :param recursive: Whether to recursively scan all the submodules. | |||
| :param with_key: Whether to yield keys along with yielded objects. | |||
| :param with_parent: Whether to yield ``self`` along with yielded objects. | |||
| :param prefix: The prefix prepended to the yielded keys. | |||
| :param predicate: The predicate function applied to scanned objects. | |||
| :param seen: A set of object ids recording which tensors and modules have already been traversed. | |||
| """ | |||
| if seen is None: | |||
| seen = set([id(self)]) | |||
| module_dict = vars(self) | |||
| _prefix = "" if prefix is None else prefix + "." | |||
| for key in sorted(module_dict): | |||
| for expanded_key, leaf in _expand_structure(key, module_dict[key]): | |||
| leaf_id = id(leaf) | |||
| if leaf_id in seen: | |||
| continue | |||
| seen.add(leaf_id) | |||
| if predicate(leaf): | |||
| if with_key and with_parent: | |||
| yield _prefix + expanded_key, leaf, self | |||
| elif with_key: | |||
| yield _prefix + expanded_key, leaf | |||
| elif with_parent: | |||
| yield leaf, self | |||
| else: | |||
| yield leaf | |||
| if recursive and isinstance(leaf, Module): | |||
| yield from leaf._flatten( | |||
| recursive=recursive, | |||
| with_key=with_key, | |||
| with_parent=with_parent, | |||
| prefix=_prefix + expanded_key if with_key else None, | |||
| predicate=predicate, | |||
| seen=seen, | |||
| ) | |||
| def parameters( | |||
| self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs | |||
| ) -> Iterable[Parameter]: | |||
| r"""Returns an iterable for the :class:`~.Parameter` of the module. | |||
| :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||
| attribute of returned :class:`.Parameter`. ``None`` for no limitation. | |||
| :param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||
| module, else only returns :class:`~.Parameter` that are direct attributes | |||
| of this module. | |||
| """ | |||
| def predicate(obj) -> bool: | |||
| return _is_parameter(obj) and ( | |||
| requires_grad is None or obj.requires_grad == requires_grad | |||
| ) | |||
| yield from self._flatten( | |||
| with_key=False, predicate=predicate, recursive=recursive, **kwargs | |||
| ) | |||
| def named_parameters( | |||
| self, | |||
| requires_grad: Optional[bool] = None, | |||
| prefix: Optional[str] = None, | |||
| recursive: bool = True, | |||
| **kwargs | |||
| ) -> Iterable[Tuple[str, Parameter]]: | |||
| """Returns an iterable for key :class:`~.Parameter` pairs of the module, where | |||
| ``key`` is the dotted path from this module to the :class:`~.Parameter` . | |||
| :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||
| attribute of returned :class:`~.Parameter` . ``None`` for no limitation. | |||
| :param prefix: The prefix prepended to the keys. | |||
| :param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||
| module, else only returns :class:`~.Parameter` that are direct attributes | |||
| of this module. | |||
| """ | |||
| def predicate(obj) -> bool: | |||
| return _is_parameter(obj) and ( | |||
| requires_grad is None or obj.requires_grad == requires_grad | |||
| ) | |||
| yield from self._flatten( | |||
| with_key=True, | |||
| prefix=prefix, | |||
| predicate=predicate, | |||
| recursive=recursive, | |||
| **kwargs, | |||
| ) | |||
| def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]: | |||
| """Returns an iterable for the :class:`~.Buffer` of the module. | |||
| :param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||
| module, else only returns :class:`~.Buffer` that are direct attributes | |||
| of this module. | |||
| """ | |||
| yield from self._flatten( | |||
| with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs | |||
| ) | |||
| def named_buffers( | |||
| self, prefix: Optional[str] = None, recursive: bool = True, **kwargs | |||
| ) -> Iterable[Tuple[str, Buffer]]: | |||
| """Returns an iterable for key :class:`~.Buffer` pairs of the module, where | |||
| ``key`` is the dotted path from this module to the :class:`~.Buffer` . | |||
| :param prefix: The prefix prepended to the keys. | |||
| :param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||
| module, else only returns :class:`~.Buffer` that are direct attributes | |||
| of this module. | |||
| """ | |||
| yield from self._flatten( | |||
| with_key=True, | |||
| prefix=prefix, | |||
| predicate=_is_buffer, | |||
| recursive=recursive, | |||
| **kwargs, | |||
| ) | |||
| def children(self, **kwargs) -> "Iterable[Module]": | |||
| """Returns an iterable for all the submodules that are direct attributes of this | |||
| module. | |||
| """ | |||
| yield from self._flatten( | |||
| with_key=False, predicate=_is_module, recursive=False, **kwargs | |||
| ) | |||
| def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]": | |||
| """Returns an iterable of key-submodule pairs for all the submodules that are | |||
| direct attributes of this module, where 'key' is the attribute name of | |||
| submodules. | |||
| """ | |||
| yield from self._flatten( | |||
| with_key=True, predicate=_is_module, recursive=False, **kwargs | |||
| ) | |||
| def modules(self, **kwargs) -> "Iterable[Module]": | |||
| """Returns an iterable for all the modules within this module, including itself. | |||
| """ | |||
| if "with_parent" in kwargs and kwargs["with_parent"]: | |||
| yield self, None | |||
| else: | |||
| yield self | |||
| yield from self._flatten(with_key=False, predicate=_is_module, **kwargs) | |||
| def named_modules( | |||
| self, prefix: Optional[str] = None, **kwargs | |||
| ) -> "Iterable[Tuple[str, Module]]": | |||
| """Returns an iterable of key-module pairs for all the modules within this | |||
| module, including itself, where 'key' is the dotted path from this module to the | |||
| submodules. | |||
| :param prefix: The prefix prepended to the path. | |||
| """ | |||
| if "with_parent" in kwargs and kwargs["with_parent"]: | |||
| yield ("" if prefix is None else prefix), self, None | |||
| else: | |||
| yield ("" if prefix is None else prefix), self | |||
| yield from self._flatten( | |||
| with_key=True, prefix=prefix, predicate=_is_module, **kwargs | |||
| ) | |||
| def apply(self, fn: "Callable[[Module], Any]") -> None: | |||
| """Apply function ``fn`` to all the modules within this module, including | |||
| itself. | |||
| :param fn: The function to be applied on modules. | |||
| """ | |||
| for it in self.modules(): | |||
| fn(it) | |||
| def zero_grad(self) -> None: | |||
| """Set all parameters' grads to zero | |||
| """ | |||
| for param in self.parameters(): | |||
| if param.grad is not None: | |||
| param.grad.reset_zero() | |||
| def train(self, mode: bool = True, recursive: bool = True) -> None: | |||
| """Set training mode of all the modules within this module (including itself) to | |||
| ``mode``. This effectively sets the ``training`` attributes of those modules | |||
| to ``mode``, but only has effect on certain modules (e.g. | |||
| :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`) | |||
| :param mode: the training mode to be set on modules. | |||
| :param recursive: whether to recursively call submodules' ``train()``. | |||
| """ | |||
| if not recursive: | |||
| self.training = mode | |||
| return | |||
| def fn(module: Module) -> None: | |||
| module.train(mode, recursive=False) | |||
| self.apply(fn) | |||
| def eval(self) -> None: | |||
| """Set training mode of all the modules within this module (including itself) to | |||
| ``False``. See :meth:`~.Module.train` for details. | |||
| """ | |||
| self.train(False) | |||
| def disable_quantize(self, value=True): | |||
| r""" | |||
| Set ``module``'s ``quantize_disabled`` attribute and return ``module``. | |||
| Could be used as a decorator. | |||
| """ | |||
| def fn(module: Module) -> None: | |||
| module.quantize_disabled = value | |||
| self.apply(fn) | |||
| def replace_param( | |||
| self, params: dict, start_pos: int, seen: Optional[Set[int]] = None | |||
| ): | |||
| """Replace module's parameters with `params`, used by :class:`~.ParamPack` to | |||
| speedup multimachine training. | |||
| """ | |||
| offset = 0 | |||
| if seen is None: | |||
| seen = set([id(self)]) | |||
| module_dict = vars(self) | |||
| for key in sorted(module_dict): | |||
| hash_id = id(module_dict[key]) | |||
| if hash_id in seen: | |||
| continue | |||
| seen.add(hash_id) | |||
| if isinstance(module_dict[key], Parameter): | |||
| if start_pos + offset in params: | |||
| assert module_dict[key].shape == params[start_pos + offset].shape | |||
| module_dict[key] = params[start_pos + offset] | |||
| offset += 1 | |||
| if isinstance(module_dict[key], Module): | |||
| offset += module_dict[key].replace_param( | |||
| params, start_pos + offset, seen | |||
| ) | |||
| return offset | |||
| def state_dict(self, rst=None, prefix="", keep_var=False): | |||
| r"""Returns a dictionary containing whole states of the module. | |||
| """ | |||
| def is_state(obj): | |||
| return _is_parameter(obj) or _is_buffer(obj) | |||
| if rst is None: | |||
| rst = OrderedDict() | |||
| for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state): | |||
| assert prefix + k not in rst, "duplicated state: {}".format(k) | |||
| if keep_var: | |||
| rst[prefix + k] = v | |||
| else: | |||
| rst[prefix + k] = v.numpy() | |||
| for k, submodule in self._flatten( | |||
| recursive=False, | |||
| with_key=True, | |||
| predicate=lambda obj: isinstance(obj, Module), | |||
| ): | |||
| submodule.state_dict(rst, prefix + k + ".", keep_var) | |||
| return rst | |||
| def load_state_dict( | |||
| self, | |||
| state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]], | |||
| strict=True, | |||
| ): | |||
| r"""Load a given dictionary created by :func:`state_dict` into this module. | |||
| If ``strict`` is ``True``, the keys of the given ``state_dict`` must exactly match the keys | |||
| returned by :func:`state_dict`. | |||
| Users can also pass a closure: `Function[key: str, var: Tensor] -> Optional[np.ndarray]` | |||
| as a `state_dict`, in order to handle complex situations. For example, load everything | |||
| except for the final linear classifier: | |||
| .. code-block:: | |||
| state_dict = {...} # Dict[str, np.ndarray] | |||
| model.load_state_dict({ | |||
| k: None if k.startswith('fc') else v | |||
| for k, v in state_dict.items() | |||
| }, strict=False) | |||
| Here returning `None` means skipping parameter `k`. | |||
| To prevent shape mismatch (e.g. load PyTorch weights), we can reshape before loading: | |||
| .. code-block:: | |||
| state_dict = {...} | |||
| def reshape_accordingly(k, v): | |||
| return state_dict[k].reshape(v.shape) | |||
| model.load_state_dict(reshape_accordingly) | |||
| We can also perform inplace re-initialization or pruning: | |||
| .. code-block:: | |||
| def reinit_and_pruning(k, v): | |||
| if 'bias' in k: | |||
| M.init.zeros_(v) | |||
| if 'conv' in k: | |||
| return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32") | |||
| model.load_state_dict(reinit_and_pruning, strict=False) | |||
| """ | |||
| unused = [] | |||
| if isinstance(state_dict, dict): | |||
| unused = state_dict.keys() | |||
| def closure(k, _): # var unused | |||
| return state_dict[k] if k in state_dict else None | |||
| elif callable(state_dict): | |||
| closure = state_dict | |||
| else: | |||
| raise ValueError( | |||
| "`state_dict` must load a dict or callable, got {}".format( | |||
| type(state_dict) | |||
| ) | |||
| ) | |||
| loaded, skipped = self._load_state_dict_with_closure(closure) | |||
| unused = set(unused) - loaded | |||
| if len(unused) != 0: | |||
| if strict: | |||
| raise KeyError( | |||
| "Unused params violate `strict=True`, unused={}".format(unused) | |||
| ) | |||
| else: | |||
| logger.warning( | |||
| "Unused params in `strict=False` mode, unused={}".format(unused) | |||
| ) | |||
| if len(skipped) != 0: | |||
| if strict: | |||
| raise KeyError( | |||
| "Missing params violate `strict=True`, missing={}".format(skipped) | |||
| ) | |||
| else: | |||
| logger.warning( | |||
| "Missing params in `strict=False` mode, missing={}".format(skipped) | |||
| ) | |||
| def _load_state_dict_with_closure(self, closure): | |||
| """Advance state_dict load through callable `closure` whose signature is | |||
| `closure(key: str, var: Tensor) -> Union[np.ndarray, None]` | |||
| """ | |||
| assert callable(closure), "closure must be a function" | |||
| loaded = [] | |||
| skipped = [] | |||
| local_state_dict = self.state_dict(keep_var=True) | |||
| for k, var in local_state_dict.items(): | |||
| to_be_load = closure(k, var) | |||
| if to_be_load is None: | |||
| skipped.append(k) | |||
| continue | |||
| assert isinstance( | |||
| to_be_load, np.ndarray | |||
| ), "closure should return a `np.ndarray`, now `{}` get {}".format( | |||
| k, to_be_load | |||
| ) | |||
| assert ( | |||
| var.shape == to_be_load.shape | |||
| ), "param `{}` shape mismatch, should be {}, get {}".format( | |||
| k, var.shape, to_be_load.shape | |||
| ) | |||
| # For quantized dtype, the initialized dtype | |||
| # scale/zero_points maybe invalid, use pretrained dtype instead. | |||
| if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||
| var = var.astype(to_be_load.dtype) | |||
| var.set_value(to_be_load) | |||
| loaded.append(k) | |||
| return set(loaded), set(skipped) | |||
| @@ -0,0 +1,156 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from typing import Callable, Iterable, Optional, Tuple | |||
| import numpy as np | |||
| from ..tensor_nn import Parameter, Tensor | |||
| from .module import Module | |||
| class ParamPack(Module): | |||
| r"""Pack module's parameters by gathering their memory to continuous address. | |||
| Using (device, dtype, requires_grad) as key, for example ('gpu0', float32, True), | |||
| parameters with same key will be packed together. | |||
| It helps a lot for multimachine training by speeding up allreduce gradients. | |||
| :param model: the module you want to pack parameters. | |||
| :param nr_ignore_first: how many parameters will be unpacked at first. | |||
| :param max_size_per_group: upper bound of packed parameters' size in MB. | |||
| :param max_nr_params_per_group: upper bound of the number of parameters of each group. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| model: Module, | |||
| nr_ignore_first: int = 8, | |||
| max_size_per_group: int = 10, | |||
| max_nr_params_per_group: int = 100, | |||
| group_func: Callable = lambda name, param: 0, | |||
| ): | |||
| super().__init__() | |||
| self._model = model | |||
| self._nr_ignore_first = nr_ignore_first | |||
| self._max_size_per_group = max_size_per_group | |||
| self._max_nr_params_per_group = max_nr_params_per_group | |||
| self._group_func = group_func | |||
| self._grouped_params = [] | |||
| self._packed_params = [] | |||
| params = model.named_parameters() | |||
| self._pack_params(params) | |||
| def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]: | |||
| for param in self._packed_params: | |||
| if requires_grad is None or param.requires_grad == requires_grad: | |||
| yield param | |||
| def named_parameters( | |||
| self, requires_grad: Optional[bool] = None | |||
| ) -> Iterable[Tuple[str, Parameter]]: | |||
| for idx, param in enumerate(self._packed_params): | |||
| if requires_grad is None or param.requires_grad == requires_grad: | |||
| yield "packed_param_" + str(idx), param | |||
| def _pack_params(self, params: Iterable[Tuple[str, Parameter]]): | |||
| groups = collections.defaultdict(list) | |||
| ignored = 0 | |||
| param_id = 0 | |||
| for name, param in params: | |||
| if self._nr_ignore_first > ignored: | |||
| ignored += 1 | |||
| self._grouped_params.append([{"shape": param.shape, "id": param_id}]) | |||
| param.pack_group_key = self._group_func(name, param) | |||
| self._packed_params.append(param) | |||
| else: | |||
| key = ( | |||
| param.dtype, | |||
| param.device, | |||
| param.requires_grad, | |||
| self._group_func(name, param), | |||
| ) | |||
| groups[key].append({"tensor": param, "id": param_id}) | |||
| param_id += 1 | |||
| for (dtype, device, requires_grad, group_key) in groups.keys(): | |||
| dtype_sz = np.dtype(dtype).itemsize | |||
| align = device.mem_align | |||
| if align < dtype_sz: | |||
| align = 1 | |||
| else: | |||
| assert align % dtype_sz == 0 | |||
| align //= dtype_sz | |||
| group = groups[(dtype, device, requires_grad, group_key)] | |||
| while group: | |||
| aligned_pos = [] | |||
| offset = 0 | |||
| params = [] | |||
| idx = 0 | |||
| while idx < len(group): | |||
| param = group[idx] | |||
| assert param["tensor"].device == device | |||
| padding = (align - (offset & (align - 1))) & (align - 1) | |||
| offset += padding | |||
| aligned_pos.append(offset) | |||
| params.append(param) | |||
| offset += int(np.prod(param["tensor"].shape)) | |||
| idx += 1 | |||
| if ( | |||
| offset * dtype_sz >= self._max_size_per_group * 1024 * 1024 | |||
| or idx >= self._max_nr_params_per_group | |||
| ): | |||
| break | |||
| group = group[idx:] | |||
| if idx == 1: | |||
| # ignore param packs with only one item | |||
| params[0]["tensor"].pack_group_key = group_key | |||
| self._packed_params.append(params[0]["tensor"]) | |||
| self._grouped_params.append( | |||
| [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}] | |||
| ) | |||
| continue | |||
| packed_value = np.zeros((offset,), dtype=dtype) | |||
| for param, pos in zip(params, aligned_pos): | |||
| val = param["tensor"].numpy() | |||
| packed_value[pos : pos + val.size] = val.flatten() | |||
| new_param = Parameter( | |||
| value=packed_value, | |||
| device=device, | |||
| dtype=dtype, | |||
| requires_grad=requires_grad, | |||
| ) | |||
| new_param.pack_group_key = group_key | |||
| self._packed_params.append(new_param) | |||
| self._grouped_params.append( | |||
| [{"shape": i["tensor"].shape, "id": i["id"]} for i in params] | |||
| ) | |||
| def forward(self, *args, **kwargs): | |||
| replace_param = dict() | |||
| for i in range(len(self._packed_params)): | |||
| packed_param = self._packed_params[i] | |||
| grouped_params = self._grouped_params[i] | |||
| if len(grouped_params) == 1: | |||
| continue | |||
| split = param_pack_split( | |||
| packed_param._symvar, [i["shape"] for i in grouped_params] | |||
| ) | |||
| split = [ | |||
| Parameter(Tensor(i, requires_grad=packed_param.requires_grad)) | |||
| for i in split | |||
| ] | |||
| for j in range(len(split)): | |||
| replace_param[grouped_params[j]["id"]] = split[j] | |||
| self._model.replace_param(replace_param, 0) | |||
| return self._model.forward(*args, **kwargs) | |||
| @@ -0,0 +1,80 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import abstractmethod | |||
| from typing import Tuple, Union | |||
| from ..functional import avg_pool2d, max_pool2d | |||
| from .module import Module | |||
| class _PoolNd(Module): | |||
| def __init__( | |||
| self, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = None, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| ): | |||
| super(_PoolNd, self).__init__() | |||
| self.kernel_size = kernel_size | |||
| self.stride = stride or kernel_size | |||
| self.padding = padding | |||
| @abstractmethod | |||
| def forward(self, inp): | |||
| pass | |||
| class MaxPool2d(_PoolNd): | |||
| r"""Applies a 2D max pooling over an input. | |||
| For instance, given an input of the size :math:`(N, C, H, W)` and | |||
| :attr:`kernel_size` :math:`(kH, kW)`, this layer generates the output of | |||
| the size :math:`(N, C, H_{out}, W_{out})` through a process described as: | |||
| .. math:: | |||
| \begin{aligned} | |||
| out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} | |||
| \text{input}(N_i, C_j, \text{stride[0]} \times h + m, | |||
| \text{stride[1]} \times w + n) | |||
| \end{aligned} | |||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on | |||
| both sides for :attr:`padding` number of points. | |||
| :param kernel_size: the size of the window to take a max over. | |||
| :param stride: the stride of the window. Default value is ``kernel_size``. | |||
| :param padding: implicit zero padding to be added on both sides. | |||
| """ | |||
| def forward(self, inp): | |||
| return max_pool2d(inp, self.kernel_size, self.stride, self.padding) | |||
| class AvgPool2d(_PoolNd): | |||
| r"""Applies a 2D average pooling over an input. | |||
| For instance, given an input of the size :math:`(N, C, H, W)` and | |||
| :attr:`kernel_size` :math:`(kH, kW)`, this layer generates the output of | |||
| the size :math:`(N, C, H_{out}, W_{out})` through a process described as: | |||
| .. math:: | |||
| out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} | |||
| input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) | |||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on | |||
| both sides for :attr:`padding` number of points. | |||
| :param kernel_size: the size of the window. | |||
| :param stride: the stride of the window. Default value is ``kernel_size``. | |||
| :param padding: implicit zero padding to be added on both sides. | |||
| """ | |||
| def forward(self, inp): | |||
| return avg_pool2d(inp, self.kernel_size, self.stride, self.padding) | |||
| @@ -0,0 +1,14 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvRelu2d | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| from .elemwise import Elemwise | |||
| from .linear import Linear | |||
| from .module import QATModule | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| @@ -0,0 +1,30 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable | |||
| from ...tensor import Tensor | |||
| from .. import concat as Float | |||
| from .module import QATModule | |||
| class Concat(Float.Concat, QATModule): | |||
| r""" | |||
| A :class:`~.QATModule` to do functional concat with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
| """ | |||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
| return self.apply_quant_activation(super().forward(inps, axis)) | |||
| @classmethod | |||
| def from_float_module(cls, float_module): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| return cls() | |||
| @@ -0,0 +1,59 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ... import functional as F | |||
| from ...quantization.utils import fake_quant_bias | |||
| from .. import conv as Float | |||
| from .module import QATModule | |||
| class Conv2d(Float.Conv2d, QATModule): | |||
| r""" | |||
| A :class:`~.QATModule` Conv2d with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
| """ | |||
| def calc_conv_qat(self, inp): | |||
| w_qat = self.apply_quant_weight(self.weight) | |||
| b_qat = fake_quant_bias(self.bias, inp, w_qat) | |||
| conv = self.calc_conv(inp, w_qat, b_qat) | |||
| return conv | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float.Conv2d): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| qat_module = cls( | |||
| float_module.in_channels, | |||
| float_module.out_channels, | |||
| float_module.kernel_size, | |||
| float_module.stride, | |||
| float_module.padding, | |||
| float_module.dilation, | |||
| float_module.groups, | |||
| float_module.bias is not None, | |||
| float_module.conv_mode.name, | |||
| float_module.compute_mode.name, | |||
| ) | |||
| qat_module.weight = float_module.weight | |||
| qat_module.bias = float_module.bias | |||
| return qat_module | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(self.calc_conv_qat(inp)) | |||
| class ConvRelu2d(Conv2d): | |||
| r""" | |||
| A :class:`~.QATModule` that includes Conv2d and Relu with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp))) | |||
| @@ -0,0 +1,193 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ...functional import add_update, ones, relu, sqrt, sum, zeros | |||
| from ...quantization.utils import fake_quant_bias | |||
| from .. import conv_bn as Float | |||
| from .module import QATModule | |||
class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
    r"""
    Shared QAT logic for fused conv + batch-norm modules.

    The batch-norm statistics are folded into the convolution weight and bias
    (``w_fold = gamma / bn_std * W``,
    ``b_fold = gamma * (b - bn_mean) / bn_std + beta``) and the folded weight is
    passed through :meth:`apply_quant_weight`.
    """

    def get_batch_mean_var(self, inp):
        r"""
        Return ``(batch_mean, batch_var)`` of ``inp`` reduced over axes 0, 2 and 3.

        The reduced dimensions are kept. The variance is the biased estimate
        ``E[x^2] - E[x]^2``; the unbiased correction is applied later in
        :meth:`update_running_mean_and_running_var`.
        """

        def _sum_channel(inp, axis=0, keepdims=True):
            # Reduce one axis at a time. Accept a single int or a tuple of ints.
            if isinstance(axis, int):
                axis = (axis,)
            elif not isinstance(axis, tuple):
                # Previously fell through and failed with UnboundLocalError.
                raise TypeError(
                    "axis must be an int or a tuple of ints, got {}".format(
                        type(axis).__name__
                    )
                )
            out = inp
            for elem in axis:
                out = sum(out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        # number of elements contributing to each channel statistic
        reduce_size = inp.size / inp.shape[1]
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def _bn_affine_params(self):
        r"""
        Return batch-norm ``(gamma, beta)`` reshaped to ``(1, -1, 1, 1)``.

        A missing ``bn.weight`` is replaced by ones and a missing ``bn.bias`` by
        zeros, both of length ``bn.num_features``.
        """
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        return gamma, beta

    def _conv_bias_or_zeros(self):
        r"""
        Return ``conv.bias``, or zeros of the inferred bias shape when it is None.
        """
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
        return conv_bias

    def _scale_conv_weight(self, scale_factor):
        r"""
        Return ``conv.weight`` multiplied by the per-channel ``scale_factor``.
        """
        if self.conv.groups == 1:
            return self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        # NOTE(review): the 5-d reshape presumably matches a grouped weight
        # layout of (groups, out_per_group, in_per_group, kh, kw) -- confirm.
        return self.conv.weight * scale_factor.reshape(self.conv.groups, -1, 1, 1, 1)

    def fold_weight_bias(self, bn_mean, bn_var):
        r"""
        Fold batch-norm statistics into the conv parameters.

        :param bn_mean: mean to fold; ``None`` is treated as zeros of shape
            ``(1, num_features, 1, 1)``.
        :param bn_var: variance to fold; ``None`` is treated as ones of shape
            ``(1, num_features, 1, 1)``.
        :return: ``(w_fold, b_fold)``, where ``w_fold`` has already been passed
            through :meth:`apply_quant_weight`.
        """
        # get fold bn conv param
        gamma, beta = self._bn_affine_params()

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_bias = self._conv_bias_or_zeros()

        # bn_istd = 1 / bn_std
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        w_fold = self._scale_conv_weight(scale_factor)
        w_fold = self.apply_quant_weight(w_fold)

        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        r"""
        Update ``bn.running_mean`` and ``bn.running_var`` in place from batch stats.

        Both statistics are detached first, so no gradient flows through the
        update. The variance is rescaled by ``n / (n - 1)`` to make it unbiased.

        :param bn_mean: batch mean.
        :param bn_var: biased batch variance.
        :param num_elements_per_channel: ``n`` used for the unbiased correction.
        """
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = bn_mean.detach()
        bn_var = (
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        # NOTE(review): presumably add_update computes
        # dest = alpha * dest + beta * delta -- confirm against its documentation.
        for running_stat, batch_stat in (
            (self.bn.running_mean, bn_mean),
            (self.bn.running_var, bn_var),
        ):
            add_update(
                running_stat,
                delta=batch_stat,
                alpha=1 - exponential_average_factor,
                beta=exponential_average_factor,
            )

    def calc_conv_bn_qat(self, inp, approx=True):
        r"""
        Compute the fused conv + batch-norm output with a fake-quantized weight.

        * Training with ``approx=False``: an extra plain convolution supplies
          the batch statistics. These update the running statistics and are
          folded into the weight and bias.
        * Training with ``approx=True``: the running statistics are folded into
          the weight only. The conv output is then rescaled back and normalized
          with the batch statistics computed from it.
        * Not training: the running statistics are folded into weight and bias.

        :param inp: input tensor of the convolution.
        :param approx: selects between the two training paths described above.
        :return: the conv + batch-norm output, before any activation.
        """
        approx_training = self.training and approx

        if self.training and not approx:
            conv = self.conv(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv)
            num_elements_per_channel = conv.size / conv.shape[1]
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma, beta = self._bn_affine_params()
        # conv_bias
        conv_bias = self._conv_bias_or_zeros()

        # bn_istd = 1 / bn_std
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        w_fold = self._scale_conv_weight(scale_factor)

        b_fold = None
        if not approx_training:
            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = fake_quant_bias(b_fold, inp, w_qat)
        conv = self.conv.calc_conv(inp, w_qat, b_qat)
        if not approx_training:
            return conv

        # rescale conv to get original conv output
        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
        if self.conv.bias is not None:
            orig_conv = orig_conv + self.conv.bias
        # calculate batch norm
        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
        num_elements_per_channel = conv.size / conv.shape[1]
        self.update_running_mean_and_running_var(
            bn_mean, bn_var, num_elements_per_channel
        )
        return conv

    @classmethod
    def from_float_module(cls, float_module: Float._ConvBnActivation2d):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.

        The conv weight, the conv bias and the whole ``bn`` submodule are taken
        from ``float_module`` by reference, not copied.
        """
        qat_module = cls(
            float_module.conv.in_channels,
            float_module.conv.out_channels,
            float_module.conv.kernel_size,
            float_module.conv.stride,
            float_module.conv.padding,
            float_module.conv.dilation,
            float_module.conv.groups,
            float_module.conv.bias is not None,
            float_module.conv.conv_mode.name,
            float_module.conv.compute_mode.name,
        )
        qat_module.conv.weight = float_module.conv.weight
        qat_module.conv.bias = float_module.conv.bias
        qat_module.bn = float_module.bn
        return qat_module
class ConvBn2d(_ConvBnActivation2d):
    r"""
    Fused Conv2d + BatchNorm2d :class:`~.QATModule` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        # Folded conv + batch-norm output, then quantization of the activation.
        folded = self.calc_conv_bn_qat(inp)
        return self.apply_quant_activation(folded)
class ConvBnRelu2d(_ConvBnActivation2d):
    r"""
    Fused Conv2d + BatchNorm2d + relu :class:`~.QATModule` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        # relu is applied to the folded conv + batch-norm output, and the
        # result is then passed through the activation quantization.
        folded = self.calc_conv_bn_qat(inp)
        activated = relu(folded)
        return self.apply_quant_activation(activated)