GitOrigin-RevId: 11411b6964
tags/v1.0.0-rc1
| @@ -247,10 +247,6 @@ if(MGE_BUILD_IMPERATIVE_RT) | |||||
| set(CMAKE_CXX_STANDARD 17) | set(CMAKE_CXX_STANDARD 17) | ||||
| endif() | endif() | ||||
| if(MGE_BUILD_IMPERATIVE_RT) | |||||
| set(MGE_BUILD_SDK OFF) | |||||
| endif() | |||||
| if(NOT MGE_WITH_CUDA) | if(NOT MGE_WITH_CUDA) | ||||
| message("-- Disable distributed support, as CUDA is not enabled.") | message("-- Disable distributed support, as CUDA is not enabled.") | ||||
| set(MGE_WITH_DISTRIBUTED OFF) | set(MGE_WITH_DISTRIBUTED OFF) | ||||
| @@ -697,9 +693,7 @@ if(MGE_WITH_PYTHON_MODULE) | |||||
| endif() | endif() | ||||
| if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | ||||
| if(NOT MGE_BUILD_IMPERATIVE_RT) | |||||
| add_subdirectory(test) | |||||
| endif() | |||||
| add_subdirectory(test) | |||||
| endif() | endif() | ||||
| if(TARGET mgb) | if(TARGET mgb) | ||||
| @@ -66,9 +66,7 @@ if(MGE_WITH_CUDA) | |||||
| endif() | endif() | ||||
| if(MGE_WITH_TEST) | if(MGE_WITH_TEST) | ||||
| if(NOT MGE_BUILD_IMPERATIVE_RT) | |||||
| add_subdirectory(test) | |||||
| endif() | |||||
| add_subdirectory(test) | |||||
| endif() | endif() | ||||
| add_subdirectory(src) | add_subdirectory(src) | ||||
| @@ -0,0 +1,5 @@ | |||||
| Makefile | |||||
| /test/imperative_test | |||||
| *.so | |||||
| /python/megengine/core/ops/_internal/generated_ops.py | |||||
| /python/megengine/core/ops/_internal/param_defs.py | |||||
| @@ -0,0 +1,110 @@ | |||||
| find_package(NumPy REQUIRED) | |||||
| set(PACKAGE_NAME megengine) | |||||
| set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||||
| set(MODULE_NAME _imperative_rt) | |||||
| set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | |||||
| file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | |||||
| file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") | |||||
| file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py) | |||||
| list(REMOVE_ITEM PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/generated_ops.py ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/param_defs.py) | |||||
| file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h | |||||
| ${PROJECT_SOURCE_DIR}/src/core/include/* | |||||
| ${PROJECT_SOURCE_DIR}/src/opr/include/* | |||||
| ${PROJECT_SOURCE_DIR}/src/serialization/include/* | |||||
| ${PROJECT_SOURCE_DIR}/src/plugin/include/* | |||||
| ${PROJECT_SOURCE_DIR}/dnn/include/*) | |||||
| set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/) | |||||
| set(GEN_OPS_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal) | |||||
| file(MAKE_DIRECTORY ${GEN_OPS_DIR}) | |||||
| set(GEN_OPS_FILE ${GEN_OPS_DIR}/generated_ops.py) | |||||
| set(GEN_OP_PARAMS_FILE ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal/param_defs.py) | |||||
| set(GEN_OP_PARAMS_TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/ops.tpl.py) | |||||
| ##################### generate python opr_param_defs.py ############## | |||||
| file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) | |||||
| file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS) | |||||
| file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS}) | |||||
| add_custom_command( | |||||
| OUTPUT ${GEN_OPS_FILE} | |||||
| COMMAND ${CMAKE_COMMAND} -E touch ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} | |||||
| COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/${PACKAGE_NAME} ${MEGENGINE_DIR}/${PACKAGE_NAME} | |||||
| COMMAND ${CMAKE_COMMAND} -E remove -f ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} | |||||
| COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py ${OPR_DECL_SRCS} -o ${GEN_OPS_FILE} | |||||
| COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${MEGENGINE_DIR}/${PACKAGE_NAME}/test | |||||
| COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py --imperative ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${GEN_OP_PARAMS_FILE} | |||||
| DEPENDS ${OPR_DECL_SRCS} ${PYTHON_SRCS} ${ALL_HEADERS} ${GEN_OP_PARAMS_TEMPLATE} | |||||
| VERBATIM | |||||
| ) | |||||
| add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE}) | |||||
| ##################### generate opdef c header and python binding ############## | |||||
| set(OP_DEF_HEADER_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/src/include) | |||||
| file(MAKE_DIRECTORY ${OP_DEF_HEADER_OUT_DIR}/megbrain/imperative/opdef) | |||||
| set(OP_DEF_HEADER ${OP_DEF_HEADER_OUT_DIR}/megbrain/imperative/opdef/all.h) | |||||
| set(OP_DEF_PYTHON_BINDING_OUT_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/src) | |||||
| file(MAKE_DIRECTORY ${OP_DEF_PYTHON_BINDING_OUT_DIR}) | |||||
| set(OP_DEF_PYTHON_BINDING ${OP_DEF_PYTHON_BINDING_OUT_DIR}/opdef.inl) | |||||
| set(OP_PARAM_DEF ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py) | |||||
| set(GEN_OP_DEF_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_op_defs.py) | |||||
| add_custom_command( | |||||
| OUTPUT ${OP_DEF_HEADER} ${OP_DEF_PYTHON_BINDING} | |||||
| COMMAND ${PYTHON_EXECUTABLE} ${GEN_OP_DEF_SCRIPT} ${OP_PARAM_DEF} ${OP_DEF_HEADER} | |||||
| COMMAND ${PYTHON_EXECUTABLE} ${GEN_OP_DEF_SCRIPT} -t py ${OP_PARAM_DEF} ${OP_DEF_PYTHON_BINDING} | |||||
| DEPENDS ${GEN_OP_DEF_SCRIPT} ${OP_PARAM_DEF} | |||||
| VERBATIM | |||||
| ) | |||||
| add_custom_target(gen_op_def_internal DEPENDS ${OP_DEF_HEADER} ${OP_DEF_PYTHON_BINDING}) | |||||
| add_library(gen_op_def INTERFACE) | |||||
| target_include_directories(gen_op_def INTERFACE ${OP_DEF_HEADER_OUT_DIR} ${OP_DEF_PYTHON_BINDING_OUT_DIR}) | |||||
| add_dependencies(gen_op_def gen_op_def_internal) | |||||
| ##################### end of opdef generation ######################### | |||||
| set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) | |||||
| add_custom_target(_version_ld SOURCES ${VERSION_SCRIPT}) | |||||
| add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/pybind11 ${PROJECT_BINARY_DIR}/third_party/pybind11) | |||||
| pybind11_add_module(${MODULE_NAME} NO_EXTRAS ${SRCS}) | |||||
| target_link_libraries(${MODULE_NAME} PRIVATE gen_op_def megbrain megdnn -Wl,--version-script=${VERSION_SCRIPT}) | |||||
| if (MGE_WITH_DISTRIBUTED) | |||||
| message("Imperative configured to link megray") | |||||
| target_link_libraries(${MODULE_NAME} PRIVATE megray) | |||||
| endif() | |||||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||||
| target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | |||||
| target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | |||||
| if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||||
| target_compile_options(${MODULE_NAME} PRIVATE "-Wno-class-memaccess") | |||||
| endif() | |||||
| set_target_properties(${MODULE_NAME} PROPERTIES | |||||
| SUFFIX ${CMAKE_SHARED_LIBRARY_SUFFIX} | |||||
| LIBRARY_OUTPUT_DIRECTORY ${MEGENGINE_DIR}/${PACKAGE_NAME}/core | |||||
| ) | |||||
| add_dependencies(${MODULE_NAME} gen_opr_py _version_ld) | |||||
| if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | |||||
| add_subdirectory(test) | |||||
| endif() | |||||
| add_custom_command( | |||||
| TARGET ${MODULE_NAME} POST_BUILD | |||||
| COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/LICENSE ${PROJECT_SOURCE_DIR}/ACKNOWLEDGMENTS ${PROJECT_BINARY_DIR} | |||||
| COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine ${CMAKE_CURRENT_BINARY_DIR}/python/megengine | |||||
| COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${CMAKE_CURRENT_BINARY_DIR}/python/test | |||||
| COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/setup.py ${CMAKE_CURRENT_BINARY_DIR}/python/setup.py | |||||
| COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires.txt | |||||
| COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires-style.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires-style.txt | |||||
| COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires-test.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires-test.txt | |||||
| ) | |||||
| @@ -0,0 +1,25 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import os | |||||
| import sys | |||||
| from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func | |||||
| from .device import * | |||||
| from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||||
| from .serialization import load, save | |||||
| from .tensor import Tensor, tensor | |||||
| from .tensor_nn import Buffer, Parameter | |||||
| from .version import __version__ | |||||
| _set_fork_exec_path_for_timed_func( | |||||
| sys.executable, | |||||
| os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | |||||
| ) | |||||
| del _set_fork_exec_path_for_timed_func | |||||
| @@ -0,0 +1,12 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import os | |||||
| import sys | |||||
| from .tensor import Tensor | |||||
| @@ -0,0 +1,46 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| from ._imperative_rt import CompNode | |||||
| class Device: | |||||
| def __init__(self, device=None): | |||||
| if device is None: | |||||
| self._cn = CompNode() | |||||
| elif isinstance(device, Device): | |||||
| self._cn = device._cn | |||||
| elif isinstance(device, CompNode): | |||||
| self._cn = device | |||||
| else: | |||||
| self._cn = CompNode(device) | |||||
| def to_c(self): | |||||
| return self._cn | |||||
| def __repr__(self): | |||||
| return "{}({})".format(type(self).__qualname__, self) | |||||
| def __str__(self): | |||||
| return str(self._cn) | |||||
| def __hash__(self): | |||||
| return hash(str(self._cn)) | |||||
| def __eq__(self, rhs): | |||||
| if not isinstance(rhs, Device): | |||||
| rhs = Device(rhs) | |||||
| return str(self._cn) == str(rhs._cn) | |||||
| def device(obj): | |||||
| if isinstance(obj, Device): | |||||
| return obj | |||||
| return Device(obj) | |||||
| @@ -0,0 +1,8 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| @@ -0,0 +1,134 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| import itertools | |||||
| import numpy as np | |||||
| from .._imperative_rt import TensorAttr, imperative | |||||
| from ..ops.builtin import Elemwise, GetVarShape, OpDef, OprAttr, Reduce, Reshape | |||||
| from ..tensor.core import apply | |||||
| from ..tensor.function import Function | |||||
| @functools.singledispatch | |||||
| def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): | |||||
| assert 0 | |||||
| _elemwise_add_param = Elemwise(mode="add").to_c().param | |||||
| @builtin_op_get_backward_fn.register(OpDef) | |||||
| def _(op: OpDef, inputs, outputs, input_requires_grad): | |||||
| if ( | |||||
| isinstance(op, OprAttr) | |||||
| and op.type == "Elemwise" | |||||
| and op.param == _elemwise_add_param | |||||
| ): | |||||
| grad_fn = elemwise_grad_fn | |||||
| elif isinstance(op, OprAttr) and op.type == Reshape.name: | |||||
| grad_fn = reshape_grad_fn | |||||
| else: | |||||
| grad_fn = default_grad_fn | |||||
| return grad_fn(op, inputs, outputs, input_requires_grad) | |||||
| @builtin_op_get_backward_fn.register(Function) | |||||
| def _(op: Function, inputs, outputs, input_requires_grad): | |||||
| return op.get_backward_fn(), [True,] * len(outputs) | |||||
| def default_grad_fn(op, inputs, outputs, input_requires_grad): | |||||
| def get_tensor_attr(x): | |||||
| attr = TensorAttr() | |||||
| attr.dtype = x.dtype | |||||
| attr.comp_node = x.device.to_c() | |||||
| return attr | |||||
| output_has_grads = [True,] * len(outputs) | |||||
| result = imperative.make_backward_graph( | |||||
| op, list(map(get_tensor_attr, inputs)), input_requires_grad, output_has_grads | |||||
| ) | |||||
| if result is None: | |||||
| nr_inputs = len(inputs) | |||||
| nr_outputs = len(outputs) | |||||
| def backward(*args): | |||||
| return nr_inputs * [ | |||||
| None, | |||||
| ] | |||||
| return backward, nr_outputs * [False,] | |||||
| backward_graph, save_for_backward_mask, input_has_grad = result | |||||
| intput_output_mask = save_for_backward_mask[: len(inputs + outputs) :] | |||||
| output_grad_mask = save_for_backward_mask[len(inputs + outputs) :] | |||||
| save_for_backward = tuple( | |||||
| val for val, mask in zip(inputs + outputs, intput_output_mask) if mask | |||||
| ) | |||||
| del inputs | |||||
| del outputs | |||||
| def backward(*args): | |||||
| output_grads = tuple(val for val, mask in zip(args, output_grad_mask) if mask) | |||||
| assert None not in output_grads | |||||
| ret = iter(apply(backward_graph, *(save_for_backward + output_grads))) | |||||
| return tuple(next(ret) if mask else None for mask in input_has_grad) | |||||
| return backward, output_grad_mask | |||||
| # override for elemwise | |||||
| def elemwise_grad_fn(op, inputs, outputs, input_requires_grad): | |||||
| assert len(inputs) == len(input_requires_grad) == 2 | |||||
| def get_shape(x): | |||||
| (s,) = apply(GetVarShape(), x) | |||||
| return s | |||||
| input_shapes = [ | |||||
| get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs) | |||||
| ] | |||||
| def reduce_to(x, s): | |||||
| (y,) = apply(Reduce(), x, s) | |||||
| return y | |||||
| def backward(dy): | |||||
| return tuple( | |||||
| reduce_to(dy, s) if i else None | |||||
| for i, s in zip(input_requires_grad, input_shapes) | |||||
| ) | |||||
| return backward, [True] | |||||
| def reshape_grad_fn(op, inputs, outputs, input_requires_grad): | |||||
| assert len(inputs) == len(input_requires_grad) == 2 | |||||
| def get_shape(x): | |||||
| (s,) = apply(GetVarShape(), x) | |||||
| return s | |||||
| input_shapes = [ | |||||
| get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs) | |||||
| ] | |||||
| def reshape_to(dy, s): | |||||
| (dx,) = apply(Reshape(), dy, s) | |||||
| return dx | |||||
| def backward(dy): | |||||
| return tuple( | |||||
| reshape_to(dy, s) if i else None | |||||
| for i, s in zip(input_requires_grad, input_shapes) | |||||
| ) | |||||
| return backward, [True] | |||||
| @@ -0,0 +1,390 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| import heapq | |||||
| import itertools | |||||
| import typing | |||||
| import weakref | |||||
| import numpy as np | |||||
| from ..ops.builtin import Elemwise, OpDef | |||||
| from ..ops.special import Const | |||||
| from ..tensor.core import TensorBase, TensorWrapperBase, apply | |||||
| from ..tensor.function import Function | |||||
| from ..tensor.tensor import Tensor, get_context | |||||
| from . import builtin_op_utils | |||||
| """ Some notes: | |||||
| 1. Initialize the optimizer: | |||||
| for each trainable parameter: | |||||
| call wrt(param, callback) | |||||
| Each parameter tensor will be assciated with a Tracer object saved in Tensor._extra_data | |||||
| 2. Tracer has one member: node, which is a VariableNode | |||||
| 3. VariableNode has a OpNode member: opnode | |||||
| 4. OpNode has four members: | |||||
| a. id | |||||
| b. inputs, which is made of VariableNode | |||||
| c. outputs, which are weakref's to VariableNode | |||||
| d. backward: call back function | |||||
| e. has_grad_fn: call has_grad_fn(opnode, reached) to check grad exist | |||||
| f. backward_allow_noinput: whether backward allow noinput | |||||
| """ | |||||
| _grad_count = 0 | |||||
| _grad_manager_dict = weakref.WeakValueDictionary() | |||||
| def get_grad_managers(): | |||||
| return [_grad_manager_dict[key] for key in _grad_manager_dict] | |||||
| def add(a, b): | |||||
| (c,) = apply(Elemwise(mode="add"), a, b) | |||||
| return c | |||||
| def get_tensor(x): | |||||
| # use recursion to avoid infinite loop | |||||
| if isinstance(x, Tensor): | |||||
| return x | |||||
| try: | |||||
| x = x.__wrapped__ | |||||
| except AttributeError: | |||||
| raise TypeError(type(x)) | |||||
| return get_tensor(x) | |||||
| class Grad: | |||||
| def __init__(self, name=None): | |||||
| if name is None: | |||||
| global _grad_count | |||||
| self._name = "grad_" + str(_grad_count) | |||||
| _grad_count += 1 | |||||
| else: | |||||
| self._name = name | |||||
| assert self._name not in _grad_manager_dict, "grad manager name duplicated" | |||||
| _grad_manager_dict[self._name] = self | |||||
| # list of all x in partial(y) / partial(x) | |||||
| self.xs = [] | |||||
| # constains weak reference of all OpNode during forward | |||||
| # OpNode contains inputs, outputs and its backward | |||||
| # ops forms the computational graph | |||||
| self.ops = [] | |||||
| self._enabled = True | |||||
| @property | |||||
| def name(self): | |||||
| return self._name | |||||
| def wrt(self, *args: Tensor, callback=None): | |||||
| """ Indicates the loss is a function of the input tensors (usually the net trainable parameters), | |||||
| i.e., d (loss) / d (Tensor) != 0 | |||||
| callback is used to perform additional operations after gradient is obtained in backward. | |||||
| e.g., copy the grad to a particular place | |||||
| A VariableNode will be created and saved in the tensor/s _extra_data slot. | |||||
| """ | |||||
| for x in map(get_tensor, args): | |||||
| v = self._new_variable(x, callback=callback) | |||||
| assert self not in x._extra_data | |||||
| x._extra_data[self] = Tracer(v) | |||||
| self.xs.append(v) | |||||
| return self | |||||
| def _new_variable(self, owner, opnode=None, callback=None): | |||||
| return VariableNode(self, owner, opnode=opnode, callback=callback) | |||||
| def _new_opnode(self, inputs, outputs): | |||||
| inputs = tuple(inputs) | |||||
| for i in inputs: | |||||
| assert i is None or isinstance(i, VariableNode) | |||||
| o = OpNode() | |||||
| o.inputs = inputs | |||||
| o.outputs = [] | |||||
| tracers = [] | |||||
| for i in outputs: | |||||
| assert isinstance(i, Tensor) | |||||
| v = self._new_variable(i, o) | |||||
| o.outputs.append(weakref.ref(v)) | |||||
| tracers.append(Tracer(v)) | |||||
| self.ops.append(weakref.ref(o)) | |||||
| return o, tracers | |||||
| def copy(self): | |||||
| raise NotImplementedError | |||||
| def __enter__(self): | |||||
| return self | |||||
| def __exit__(self, *_): | |||||
| """clear all resources""" | |||||
| self._enabled = False | |||||
| for o in self.ops: | |||||
| o = o() | |||||
| if o: | |||||
| o.clear() | |||||
| def __call__(self, ys, dys): | |||||
| """ Defines Grad(). | |||||
| :param ys: outputs of forward operators, e.g., the loss tensor | |||||
| :type ys: list of Tensor or TensorWrapperBase | |||||
| :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss, | |||||
| e.g., one for the loss itself | |||||
| :type dys: list of Tensor or TensorWrapperBase | |||||
| """ | |||||
| assert self._enabled | |||||
| self._enabled = False | |||||
| def check_wrapper(): | |||||
| if isinstance(dys, TensorWrapperBase): | |||||
| return type(dys) | |||||
| if isinstance(dys, TensorBase): | |||||
| return | |||||
| assert isinstance(dys, (tuple, list)) | |||||
| for i in dys: | |||||
| if isinstance(i, TensorWrapperBase): | |||||
| return type(i) | |||||
| Wrapper = check_wrapper() | |||||
| def aslist(x): | |||||
| if isinstance(x, (Tensor, TensorWrapperBase)): | |||||
| x = [x] | |||||
| else: | |||||
| x = list(x) | |||||
| x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x] | |||||
| for i in x: | |||||
| assert isinstance(i, Tensor) | |||||
| return x | |||||
| ys = aslist(ys) | |||||
| dys = aslist(dys) | |||||
| assert len(ys) == len(dys) | |||||
| # ys is changed to a list of VariableNode which contains more information | |||||
| # such as OpNode, callback, etc. | |||||
| ys = [i._extra_data[self].node for i in ys] | |||||
| # NOTE: callback is called only if grad is not None | |||||
| # the OpNode sequence in backward | |||||
| op_seq = [] | |||||
| # VariableNode -> (i, j), where i is time stamp in backward, j means jth input | |||||
| last_written_to = {} | |||||
| def schedule(): | |||||
| reached = set(ys) | |||||
| # i is the time stamp in backward | |||||
| i = 0 | |||||
| for o in self.ops[::-1]: | |||||
| o = o() | |||||
| if o is None: | |||||
| continue | |||||
| if not o.has_grad_fn(o, reached): | |||||
| continue | |||||
| op_seq.append(o) | |||||
| for j, v in enumerate(o.inputs): | |||||
| reached.add(v) | |||||
| last_written_to[v] = i, j | |||||
| i += 1 | |||||
| schedule() | |||||
| # VariableNode -> Tensor | |||||
| cache = {} | |||||
| def initialize(): | |||||
| for y, dy in zip(ys, dys): | |||||
| cache[y] = dy | |||||
| if y not in last_written_to and y.callback: | |||||
| y.callback(y.owner(), dy) | |||||
| initialize() | |||||
| # NOTE: None is used to mark a node has been consumed | |||||
| for seqno, opnode in enumerate(op_seq): | |||||
| input_nodes = opnode.inputs | |||||
| output_nodes = [i() for i in opnode.outputs] | |||||
| backward = opnode.backward | |||||
| backward_allow_noinput = opnode.backward_allow_noinput | |||||
| opnode.clear() | |||||
| output_grads = [] | |||||
| for i in output_nodes: | |||||
| if i is not None: | |||||
| if i in cache: | |||||
| assert cache[i] is not None | |||||
| output_grads.append(cache[i]) | |||||
| else: | |||||
| output_grads.append(None) | |||||
| # read by backward, mark consumed | |||||
| cache[i] = None | |||||
| else: | |||||
| output_grads.append(None) | |||||
| if ( | |||||
| any([grad is not None for grad in output_grads]) | |||||
| or backward_allow_noinput | |||||
| ): | |||||
| input_grads = backward(*output_grads) | |||||
| else: | |||||
| input_grads = [None] * len(input_nodes) | |||||
| assert len(input_nodes) == len(input_grads) | |||||
| for i, (v, g) in enumerate(zip(input_nodes, input_grads)): | |||||
| if v is None: | |||||
| continue | |||||
| if v in cache: | |||||
| assert cache[v] | |||||
| if g is not None: | |||||
| cache[v] = add(cache[v], g) | |||||
| elif g is not None: | |||||
| cache[v] = g | |||||
| if last_written_to[v] == (seqno, i): | |||||
| if v.callback: | |||||
| v.callback( | |||||
| v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | |||||
| ) | |||||
| if v.opnode is None: | |||||
| # won't read by backward, mark consumed | |||||
| cache[v] = None | |||||
| for v in cache.values(): | |||||
| assert v is None | |||||
| class clearable: | |||||
| __cleared = False | |||||
| def __bool__(self): | |||||
| return not self.__cleared | |||||
| def clear(self): | |||||
| self.__dict__.clear() | |||||
| self.__cleared = True | |||||
| class OpNode(clearable): | |||||
| """ OpNode saves all the information to form the computational graph. | |||||
| """ | |||||
| def __init__(self): | |||||
| self.id = None | |||||
| self.inputs = None # Could be VariableNode | |||||
| self.outputs = None # Could be VariableNode | |||||
| self.backward = None | |||||
| self.has_grad_fn = None | |||||
| self.backward_allow_noinput = False | |||||
| class VariableNode(clearable): | |||||
| """ VariableNode saves OpNode and callback. | |||||
| FIXME!!! Explain manager and owner | |||||
| """ | |||||
| def __init__(self, manager, owner, opnode=None, callback=None): | |||||
| # manager is Grad type | |||||
| self.manager = weakref.ref(manager) | |||||
| # owner is Tensor type | |||||
| self.owner = weakref.ref(owner) | |||||
| self.opnode = opnode | |||||
| self.callback = callback | |||||
| class Tracer(clearable, TensorBase): | |||||
| def __init__(self, node=None): | |||||
| """ type(node) is VariableNode | |||||
| """ | |||||
| self.node = node | |||||
| @functools.singledispatch | |||||
| def check_backward_allow_noinput(op: OpDef): | |||||
| return False | |||||
| @functools.singledispatch | |||||
| def get_op_has_grad_fn(op: OpDef): | |||||
| assert 0 | |||||
| @get_op_has_grad_fn.register(OpDef) | |||||
| def _(op: OpDef): | |||||
| return default_has_grad_fn | |||||
| @get_op_has_grad_fn.register(Function) | |||||
| def _(op: Function): | |||||
| return default_has_grad_fn | |||||
| def default_has_grad_fn(opnode, reached): | |||||
| for v in opnode.outputs: | |||||
| if v() in reached: | |||||
| return True | |||||
| return False | |||||
| @apply.add | |||||
| def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | |||||
| args = tuple(i if isinstance(i, Tracer) else None for i in args) | |||||
| input_requires_grad = list(map(bool, args)) | |||||
| if not any(input_requires_grad): | |||||
| return | |||||
| ctx = get_context() | |||||
| manager = None | |||||
| assert len(ctx.inputs) == len(args) | |||||
| for i, j in zip(ctx.inputs, args): | |||||
| if j: | |||||
| j = j.node | |||||
| assert i is j.owner() | |||||
| if manager is None: | |||||
| manager = j.manager() | |||||
| assert manager | |||||
| else: | |||||
| assert manager is j.manager() | |||||
| if not manager._enabled: | |||||
| return | |||||
| opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs) | |||||
| # register backward method | |||||
| # tuple of backward functions corresponding to dy / dx_i | |||||
| # None means y is not a function of x_i | |||||
| opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn( | |||||
| op, ctx.inputs, ctx.outputs, input_requires_grad | |||||
| ) | |||||
| assert len(outputs) == len(output_need_grad) | |||||
| outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)] | |||||
| opnode.backward_allow_noinput = check_backward_allow_noinput(op) | |||||
| opnode.has_grad_fn = get_op_has_grad_fn(op) | |||||
| return tuple(outputs) | |||||
| @apply.add | |||||
| def _(op: Const, *_: typing.Optional[Tracer]): | |||||
| return None | |||||
| @@ -0,0 +1,8 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| @@ -0,0 +1,8 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| @@ -0,0 +1,10 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .generated_ops import * | |||||
| from .misc_ops import * | |||||
| @@ -0,0 +1,929 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import sys | |||||
| from functools import reduce | |||||
| from operator import or_ as _or_ | |||||
| from types import DynamicClassAttribute, MappingProxyType | |||||
| # try _collections first to reduce startup cost | |||||
| try: | |||||
| from _collections import OrderedDict | |||||
| except ImportError: | |||||
| from collections import OrderedDict | |||||
| __all__ = [ | |||||
| "EnumMeta", | |||||
| "Enum", | |||||
| "IntEnum", | |||||
| "Flag", | |||||
| "IntFlag", | |||||
| "auto", | |||||
| "unique", | |||||
| ] | |||||
| def _is_descriptor(obj): | |||||
| """Returns True if obj is a descriptor, False otherwise.""" | |||||
| return ( | |||||
| hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") | |||||
| ) | |||||
| def _is_dunder(name): | |||||
| """Returns True if a __dunder__ name, False otherwise.""" | |||||
| return ( | |||||
| name[:2] == name[-2:] == "__" | |||||
| and name[2:3] != "_" | |||||
| and name[-3:-2] != "_" | |||||
| and len(name) > 4 | |||||
| ) | |||||
| def _is_sunder(name): | |||||
| """Returns True if a _sunder_ name, False otherwise.""" | |||||
| return ( | |||||
| name[0] == name[-1] == "_" | |||||
| and name[1:2] != "_" | |||||
| and name[-2:-1] != "_" | |||||
| and len(name) > 2 | |||||
| ) | |||||
| def _make_class_unpicklable(cls): | |||||
| """Make the given class un-picklable.""" | |||||
| def _break_on_call_reduce(self, proto): | |||||
| raise TypeError("%r cannot be pickled" % self) | |||||
| cls.__reduce_ex__ = _break_on_call_reduce | |||||
| cls.__module__ = "<unknown>" | |||||
| _auto_null = object() | |||||
| class auto: | |||||
| """ | |||||
| Instances are replaced with an appropriate value in Enum class suites. | |||||
| """ | |||||
| value = _auto_null | |||||
| class _EnumDict(dict): | |||||
| """Track enum member order and ensure member names are not reused. | |||||
| EnumMeta will use the names found in self._member_names as the | |||||
| enumeration member names. | |||||
| """ | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self._member_names = [] | |||||
| self._last_values = [] | |||||
| def __setitem__(self, key, value): | |||||
| """Changes anything not dundered or not a descriptor. | |||||
| If an enum member name is used twice, an error is raised; duplicate | |||||
| values are not checked for. | |||||
| Single underscore (sunder) names are reserved. | |||||
| """ | |||||
| if _is_sunder(key): | |||||
| if key not in ( | |||||
| "_order_", | |||||
| "_create_pseudo_member_", | |||||
| "_generate_next_value_", | |||||
| "_missing_", | |||||
| ): | |||||
| raise ValueError("_names_ are reserved for future Enum use") | |||||
| if key == "_generate_next_value_": | |||||
| setattr(self, "_generate_next_value", value) | |||||
| elif _is_dunder(key): | |||||
| if key == "__order__": | |||||
| key = "_order_" | |||||
| elif key in self._member_names: | |||||
| # descriptor overwriting an enum? | |||||
| raise TypeError("Attempted to reuse key: %r" % key) | |||||
| elif not _is_descriptor(value): | |||||
| if key in self: | |||||
| # enum overwriting a descriptor? | |||||
| raise TypeError("%r already defined as: %r" % (key, self[key])) | |||||
| if isinstance(value, auto): | |||||
| if value.value == _auto_null: | |||||
| value.value = self._generate_next_value( | |||||
| key, 1, len(self._member_names), self._last_values[:] | |||||
| ) | |||||
| value = value.value | |||||
| self._member_names.append(key) | |||||
| self._last_values.append(value) | |||||
| super().__setitem__(key, value) | |||||
| # Dummy value for Enum as EnumMeta explicitly checks for it, but of course | |||||
| # until EnumMeta finishes running the first time the Enum class doesn't exist. | |||||
| # This is also why there are checks in EnumMeta like `if Enum is not None` | |||||
| Enum = None | |||||
| class EnumMeta(type): | |||||
| """Metaclass for Enum""" | |||||
| @classmethod | |||||
| def __prepare__(metacls, cls, bases): | |||||
| # create the namespace dict | |||||
| enum_dict = _EnumDict() | |||||
| # inherit previous flags and _generate_next_value_ function | |||||
| member_type, first_enum = metacls._get_mixins_(bases) | |||||
| if first_enum is not None: | |||||
| enum_dict["_generate_next_value_"] = getattr( | |||||
| first_enum, "_generate_next_value_", None | |||||
| ) | |||||
| return enum_dict | |||||
| def __new__(metacls, cls, bases, classdict): | |||||
| # an Enum class is final once enumeration items have been defined; it | |||||
| # cannot be mixed with other types (int, float, etc.) if it has an | |||||
| # inherited __new__ unless a new __new__ is defined (or the resulting | |||||
| # class will fail). | |||||
| member_type, first_enum = metacls._get_mixins_(bases) | |||||
| __new__, save_new, use_args = metacls._find_new_( | |||||
| classdict, member_type, first_enum | |||||
| ) | |||||
| # save enum items into separate mapping so they don't get baked into | |||||
| # the new class | |||||
| enum_members = {k: classdict[k] for k in classdict._member_names} | |||||
| for name in classdict._member_names: | |||||
| del classdict[name] | |||||
| # adjust the sunders | |||||
| _order_ = classdict.pop("_order_", None) | |||||
| # check for illegal enum names (any others?) | |||||
| invalid_names = set(enum_members) & { | |||||
| "mro", | |||||
| } | |||||
| if invalid_names: | |||||
| raise ValueError( | |||||
| "Invalid enum member name: {0}".format(",".join(invalid_names)) | |||||
| ) | |||||
| # create a default docstring if one has not been provided | |||||
| if "__doc__" not in classdict: | |||||
| classdict["__doc__"] = "An enumeration." | |||||
| # create our new Enum type | |||||
| enum_class = super().__new__(metacls, cls, bases, classdict) | |||||
| enum_class._member_names_ = [] # names in definition order | |||||
| enum_class._member_map_ = OrderedDict() # name->value map | |||||
| enum_class._member_type_ = member_type | |||||
| # save attributes from super classes so we know if we can take | |||||
| # the shortcut of storing members in the class dict | |||||
| base_attributes = {a for b in enum_class.mro() for a in b.__dict__} | |||||
| # Reverse value->name map for hashable values. | |||||
| enum_class._value2member_map_ = {} | |||||
| # If a custom type is mixed into the Enum, and it does not know how | |||||
| # to pickle itself, pickle.dumps will succeed but pickle.loads will | |||||
| # fail. Rather than have the error show up later and possibly far | |||||
| # from the source, sabotage the pickle protocol for this class so | |||||
| # that pickle.dumps also fails. | |||||
| # | |||||
| # However, if the new class implements its own __reduce_ex__, do not | |||||
| # sabotage -- it's on them to make sure it works correctly. We use | |||||
| # __reduce_ex__ instead of any of the others as it is preferred by | |||||
| # pickle over __reduce__, and it handles all pickle protocols. | |||||
| if "__reduce_ex__" not in classdict: | |||||
| if member_type is not object: | |||||
| methods = ( | |||||
| "__getnewargs_ex__", | |||||
| "__getnewargs__", | |||||
| "__reduce_ex__", | |||||
| "__reduce__", | |||||
| ) | |||||
| if not any(m in member_type.__dict__ for m in methods): | |||||
| _make_class_unpicklable(enum_class) | |||||
| # instantiate them, checking for duplicates as we go | |||||
| # we instantiate first instead of checking for duplicates first in case | |||||
| # a custom __new__ is doing something funky with the values -- such as | |||||
| # auto-numbering ;) | |||||
| for member_name in classdict._member_names: | |||||
| value = enum_members[member_name] | |||||
| if not isinstance(value, tuple): | |||||
| args = (value,) | |||||
| else: | |||||
| args = value | |||||
| if member_type is tuple: # special case for tuple enums | |||||
| args = (args,) # wrap it one more time | |||||
| if not use_args: | |||||
| enum_member = __new__(enum_class) | |||||
| if not hasattr(enum_member, "_value_"): | |||||
| enum_member._value_ = value | |||||
| else: | |||||
| enum_member = __new__(enum_class, *args) | |||||
| if not hasattr(enum_member, "_value_"): | |||||
| if member_type is object: | |||||
| enum_member._value_ = value | |||||
| else: | |||||
| enum_member._value_ = member_type(*args) | |||||
| value = enum_member._value_ | |||||
| enum_member._name_ = member_name | |||||
| enum_member.__objclass__ = enum_class | |||||
| enum_member.__init__(*args) | |||||
| # If another member with the same value was already defined, the | |||||
| # new member becomes an alias to the existing one. | |||||
| for name, canonical_member in enum_class._member_map_.items(): | |||||
| if canonical_member._value_ == enum_member._value_: | |||||
| enum_member = canonical_member | |||||
| break | |||||
| else: | |||||
| # Aliases don't appear in member names (only in __members__). | |||||
| enum_class._member_names_.append(member_name) | |||||
| # performance boost for any member that would not shadow | |||||
| # a DynamicClassAttribute | |||||
| if member_name not in base_attributes: | |||||
| setattr(enum_class, member_name, enum_member) | |||||
| # now add to _member_map_ | |||||
| enum_class._member_map_[member_name] = enum_member | |||||
| try: | |||||
| # This may fail if value is not hashable. We can't add the value | |||||
| # to the map, and by-value lookups for this value will be | |||||
| # linear. | |||||
| enum_class._value2member_map_[value] = enum_member | |||||
| except TypeError: | |||||
| pass | |||||
| # double check that repr and friends are not the mixin's or various | |||||
| # things break (such as pickle) | |||||
| for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"): | |||||
| class_method = getattr(enum_class, name) | |||||
| obj_method = getattr(member_type, name, None) | |||||
| enum_method = getattr(first_enum, name, None) | |||||
| if obj_method is not None and obj_method is class_method: | |||||
| setattr(enum_class, name, enum_method) | |||||
| # replace any other __new__ with our own (as long as Enum is not None, | |||||
| # anyway) -- again, this is to support pickle | |||||
| if Enum is not None: | |||||
| # if the user defined their own __new__, save it before it gets | |||||
| # clobbered in case they subclass later | |||||
| if save_new: | |||||
| enum_class.__new_member__ = __new__ | |||||
| enum_class.__new__ = Enum.__new__ | |||||
| # py3 support for definition order (helps keep py2/py3 code in sync) | |||||
| if _order_ is not None: | |||||
| if isinstance(_order_, str): | |||||
| _order_ = _order_.replace(",", " ").split() | |||||
| if _order_ != enum_class._member_names_: | |||||
| raise TypeError("member order does not match _order_") | |||||
| return enum_class | |||||
| def __bool__(self): | |||||
| """ | |||||
| classes/types should always be True. | |||||
| """ | |||||
| return True | |||||
| def __call__( | |||||
| cls, value, names=None, *, module=None, qualname=None, type=None, start=1 | |||||
| ): | |||||
| """Either returns an existing member, or creates a new enum class. | |||||
| This method is used both when an enum class is given a value to match | |||||
| to an enumeration member (i.e. Color(3)) and for the functional API | |||||
| (i.e. Color = Enum('Color', names='RED GREEN BLUE')). | |||||
| When used for the functional API: | |||||
| `value` will be the name of the new class. | |||||
| `names` should be either a string of white-space/comma delimited names | |||||
| (values will start at `start`), or an iterator/mapping of name, value pairs. | |||||
| `module` should be set to the module this class is being created in; | |||||
| if it is not set, an attempt to find that module will be made, but if | |||||
| it fails the class will not be picklable. | |||||
| `qualname` should be set to the actual location this class can be found | |||||
| at in its module; by default it is set to the global scope. If this is | |||||
| not correct, unpickling will fail in some circumstances. | |||||
| `type`, if set, will be mixed in as the first base class. | |||||
| """ | |||||
| if names is None: # simple value lookup | |||||
| return cls.__new__(cls, value) | |||||
| # otherwise, functional API: we're creating a new Enum type | |||||
| return cls._create_( | |||||
| value, names, module=module, qualname=qualname, type=type, start=start | |||||
| ) | |||||
| def __contains__(cls, member): | |||||
| return isinstance(member, cls) and member._name_ in cls._member_map_ | |||||
| def __delattr__(cls, attr): | |||||
| # nicer error message when someone tries to delete an attribute | |||||
| # (see issue19025). | |||||
| if attr in cls._member_map_: | |||||
| raise AttributeError("%s: cannot delete Enum member." % cls.__name__) | |||||
| super().__delattr__(attr) | |||||
| def __dir__(self): | |||||
| return [ | |||||
| "__class__", | |||||
| "__doc__", | |||||
| "__members__", | |||||
| "__module__", | |||||
| ] + self._member_names_ | |||||
| def __getattr__(cls, name): | |||||
| """Return the enum member matching `name` | |||||
| We use __getattr__ instead of descriptors or inserting into the enum | |||||
| class' __dict__ in order to support `name` and `value` being both | |||||
| properties for enum members (which live in the class' __dict__) and | |||||
| enum members themselves. | |||||
| """ | |||||
| if _is_dunder(name): | |||||
| raise AttributeError(name) | |||||
| try: | |||||
| return cls._member_map_[name] | |||||
| except KeyError: | |||||
| raise AttributeError(name) from None | |||||
| def __getitem__(cls, name): | |||||
| return cls._member_map_[name] | |||||
| def __iter__(cls): | |||||
| return (cls._member_map_[name] for name in cls._member_names_) | |||||
| def __len__(cls): | |||||
| return len(cls._member_names_) | |||||
| @property | |||||
| def __members__(cls): | |||||
| """Returns a mapping of member name->value. | |||||
| This mapping lists all enum members, including aliases. Note that this | |||||
| is a read-only view of the internal mapping. | |||||
| """ | |||||
| return MappingProxyType(cls._member_map_) | |||||
| def __repr__(cls): | |||||
| return "<enum %r>" % cls.__name__ | |||||
| def __reversed__(cls): | |||||
| return (cls._member_map_[name] for name in reversed(cls._member_names_)) | |||||
| def __setattr__(cls, name, value): | |||||
| """Block attempts to reassign Enum members. | |||||
| A simple assignment to the class namespace only changes one of the | |||||
| several possible ways to get an Enum member from the Enum class, | |||||
| resulting in an inconsistent Enumeration. | |||||
| """ | |||||
| member_map = cls.__dict__.get("_member_map_", {}) | |||||
| if name in member_map: | |||||
| raise AttributeError("Cannot reassign members.") | |||||
| super().__setattr__(name, value) | |||||
| def _create_( | |||||
| cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1 | |||||
| ): | |||||
| """Convenience method to create a new Enum class. | |||||
| `names` can be: | |||||
| * A string containing member names, separated either with spaces or | |||||
| commas. Values are incremented by 1 from `start`. | |||||
| * An iterable of member names. Values are incremented by 1 from `start`. | |||||
| * An iterable of (member name, value) pairs. | |||||
| * A mapping of member name -> value pairs. | |||||
| """ | |||||
| metacls = cls.__class__ | |||||
| bases = (cls,) if type is None else (type, cls) | |||||
| _, first_enum = cls._get_mixins_(bases) | |||||
| classdict = metacls.__prepare__(class_name, bases) | |||||
| # special processing needed for names? | |||||
| if isinstance(names, str): | |||||
| names = names.replace(",", " ").split() | |||||
| if isinstance(names, (tuple, list)) and names and isinstance(names[0], str): | |||||
| original_names, names = names, [] | |||||
| last_values = [] | |||||
| for count, name in enumerate(original_names): | |||||
| value = first_enum._generate_next_value_( | |||||
| name, start, count, last_values[:] | |||||
| ) | |||||
| last_values.append(value) | |||||
| names.append((name, value)) | |||||
| # Here, names is either an iterable of (name, value) or a mapping. | |||||
| for item in names: | |||||
| if isinstance(item, str): | |||||
| member_name, member_value = item, names[item] | |||||
| else: | |||||
| member_name, member_value = item | |||||
| classdict[member_name] = member_value | |||||
| enum_class = metacls.__new__(metacls, class_name, bases, classdict) | |||||
| # TODO: replace the frame hack if a blessed way to know the calling | |||||
| # module is ever developed | |||||
| if module is None: | |||||
| try: | |||||
| module = sys._getframe(2).f_globals["__name__"] | |||||
| except (AttributeError, ValueError) as exc: | |||||
| pass | |||||
| if module is None: | |||||
| _make_class_unpicklable(enum_class) | |||||
| else: | |||||
| enum_class.__module__ = module | |||||
| if qualname is not None: | |||||
| enum_class.__qualname__ = qualname | |||||
| return enum_class | |||||
| @staticmethod | |||||
| def _get_mixins_(bases): | |||||
| """Returns the type for creating enum members, and the first inherited | |||||
| enum class. | |||||
| bases: the tuple of bases that was given to __new__ | |||||
| """ | |||||
| if not bases: | |||||
| return object, Enum | |||||
| # double check that we are not subclassing a class with existing | |||||
| # enumeration members; while we're at it, see if any other data | |||||
| # type has been mixed in so we can use the correct __new__ | |||||
| member_type = first_enum = None | |||||
| for base in bases: | |||||
| if base is not Enum and issubclass(base, Enum) and base._member_names_: | |||||
| raise TypeError("Cannot extend enumerations") | |||||
| # base is now the last base in bases | |||||
| if not issubclass(base, Enum): | |||||
| raise TypeError( | |||||
| "new enumerations must be created as " | |||||
| "`ClassName([mixin_type,] enum_type)`" | |||||
| ) | |||||
| # get correct mix-in type (either mix-in type of Enum subclass, or | |||||
| # first base if last base is Enum) | |||||
| if not issubclass(bases[0], Enum): | |||||
| member_type = bases[0] # first data type | |||||
| first_enum = bases[-1] # enum type | |||||
| else: | |||||
| for base in bases[0].__mro__: | |||||
| # most common: (IntEnum, int, Enum, object) | |||||
| # possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>, | |||||
| # <class 'int'>, <Enum 'Enum'>, | |||||
| # <class 'object'>) | |||||
| if issubclass(base, Enum): | |||||
| if first_enum is None: | |||||
| first_enum = base | |||||
| else: | |||||
| if member_type is None: | |||||
| member_type = base | |||||
| return member_type, first_enum | |||||
| @staticmethod | |||||
| def _find_new_(classdict, member_type, first_enum): | |||||
| """Returns the __new__ to be used for creating the enum members. | |||||
| classdict: the class dictionary given to __new__ | |||||
| member_type: the data type whose __new__ will be used by default | |||||
| first_enum: enumeration to check for an overriding __new__ | |||||
| """ | |||||
| # now find the correct __new__, checking to see of one was defined | |||||
| # by the user; also check earlier enum classes in case a __new__ was | |||||
| # saved as __new_member__ | |||||
| __new__ = classdict.get("__new__", None) | |||||
| # should __new__ be saved as __new_member__ later? | |||||
| save_new = __new__ is not None | |||||
| if __new__ is None: | |||||
| # check all possibles for __new_member__ before falling back to | |||||
| # __new__ | |||||
| for method in ("__new_member__", "__new__"): | |||||
| for possible in (member_type, first_enum): | |||||
| target = getattr(possible, method, None) | |||||
| if target not in { | |||||
| None, | |||||
| None.__new__, | |||||
| object.__new__, | |||||
| Enum.__new__, | |||||
| }: | |||||
| __new__ = target | |||||
| break | |||||
| if __new__ is not None: | |||||
| break | |||||
| else: | |||||
| __new__ = object.__new__ | |||||
| # if a non-object.__new__ is used then whatever value/tuple was | |||||
| # assigned to the enum member name will be passed to __new__ and to the | |||||
| # new enum member's __init__ | |||||
| if __new__ is object.__new__: | |||||
| use_args = False | |||||
| else: | |||||
| use_args = True | |||||
| return __new__, save_new, use_args | |||||
| class Enum(metaclass=EnumMeta): | |||||
| """Generic enumeration. | |||||
| Derive from this class to define new enumerations. | |||||
| """ | |||||
| def __new__(cls, value): | |||||
| # all enum instances are actually created during class construction | |||||
| # without calling this method; this method is called by the metaclass' | |||||
| # __call__ (i.e. Color(3) ), and by pickle | |||||
| if type(value) is cls: | |||||
| # For lookups like Color(Color.RED) | |||||
| return value | |||||
| # by-value search for a matching enum member | |||||
| # see if it's in the reverse mapping (for hashable values) | |||||
| try: | |||||
| if value in cls._value2member_map_: | |||||
| return cls._value2member_map_[value] | |||||
| except TypeError: | |||||
| # not there, now do long search -- O(n) behavior | |||||
| for member in cls._member_map_.values(): | |||||
| if member._value_ == value: | |||||
| return member | |||||
| # still not found -- try _missing_ hook | |||||
| return cls._missing_(value) | |||||
| def _generate_next_value_(name, start, count, last_values): | |||||
| for last_value in reversed(last_values): | |||||
| try: | |||||
| return last_value + 1 | |||||
| except TypeError: | |||||
| pass | |||||
| else: | |||||
| return start | |||||
| @classmethod | |||||
| def _missing_(cls, value): | |||||
| raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||||
| def __repr__(self): | |||||
| return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_) | |||||
| def __str__(self): | |||||
| return "%s.%s" % (self.__class__.__name__, self._name_) | |||||
| def __dir__(self): | |||||
| added_behavior = [ | |||||
| m | |||||
| for cls in self.__class__.mro() | |||||
| for m in cls.__dict__ | |||||
| if m[0] != "_" and m not in self._member_map_ | |||||
| ] | |||||
| return ["__class__", "__doc__", "__module__"] + added_behavior | |||||
| def __format__(self, format_spec): | |||||
| # mixed-in Enums should use the mixed-in type's __format__, otherwise | |||||
| # we can get strange results with the Enum name showing up instead of | |||||
| # the value | |||||
| # pure Enum branch | |||||
| if self._member_type_ is object: | |||||
| cls = str | |||||
| val = str(self) | |||||
| # mix-in branch | |||||
| else: | |||||
| cls = self._member_type_ | |||||
| val = self._value_ | |||||
| return cls.__format__(val, format_spec) | |||||
| def __hash__(self): | |||||
| return hash(self._name_) | |||||
| def __reduce_ex__(self, proto): | |||||
| return self.__class__, (self._value_,) | |||||
| # DynamicClassAttribute is used to provide access to the `name` and | |||||
| # `value` properties of enum members while keeping some measure of | |||||
| # protection from modification, while still allowing for an enumeration | |||||
| # to have members named `name` and `value`. This works because enumeration | |||||
| # members are not set directly on the enum class -- __getattr__ is | |||||
| # used to look them up. | |||||
| @DynamicClassAttribute | |||||
| def name(self): | |||||
| """The name of the Enum member.""" | |||||
| return self._name_ | |||||
| @DynamicClassAttribute | |||||
| def value(self): | |||||
| """The value of the Enum member.""" | |||||
| return self._value_ | |||||
| @classmethod | |||||
| def _convert(cls, name, module, filter, source=None): | |||||
| """ | |||||
| Create a new Enum subclass that replaces a collection of global constants | |||||
| """ | |||||
| # convert all constants from source (or module) that pass filter() to | |||||
| # a new Enum called name, and export the enum and its members back to | |||||
| # module; | |||||
| # also, replace the __reduce_ex__ method so unpickling works in | |||||
| # previous Python versions | |||||
| module_globals = vars(sys.modules[module]) | |||||
| if source: | |||||
| source = vars(source) | |||||
| else: | |||||
| source = module_globals | |||||
| # We use an OrderedDict of sorted source keys so that the | |||||
| # _value2member_map is populated in the same order every time | |||||
| # for a consistent reverse mapping of number to name when there | |||||
| # are multiple names for the same number rather than varying | |||||
| # between runs due to hash randomization of the module dictionary. | |||||
| members = [(name, source[name]) for name in source.keys() if filter(name)] | |||||
| try: | |||||
| # sort by value | |||||
| members.sort(key=lambda t: (t[1], t[0])) | |||||
| except TypeError: | |||||
| # unless some values aren't comparable, in which case sort by name | |||||
| members.sort(key=lambda t: t[0]) | |||||
| cls = cls(name, members, module=module) | |||||
| cls.__reduce_ex__ = _reduce_ex_by_name | |||||
| module_globals.update(cls.__members__) | |||||
| module_globals[name] = cls | |||||
| return cls | |||||
| class IntEnum(int, Enum): | |||||
| """Enum where members are also (and must be) ints""" | |||||
| def _reduce_ex_by_name(self, proto): | |||||
| return self.name | |||||
| class Flag(Enum): | |||||
| """Support for flags""" | |||||
| def _generate_next_value_(name, start, count, last_values): | |||||
| """ | |||||
| Generate the next value when not given. | |||||
| name: the name of the member | |||||
| start: the initital start value or None | |||||
| count: the number of existing members | |||||
| last_value: the last value assigned or None | |||||
| """ | |||||
| if not count: | |||||
| return start if start is not None else 1 | |||||
| for last_value in reversed(last_values): | |||||
| try: | |||||
| high_bit = _high_bit(last_value) | |||||
| break | |||||
| except Exception: | |||||
| raise TypeError("Invalid Flag value: %r" % last_value) from None | |||||
| return 2 ** (high_bit + 1) | |||||
| @classmethod | |||||
| def _missing_(cls, value): | |||||
| original_value = value | |||||
| if value < 0: | |||||
| value = ~value | |||||
| possible_member = cls._create_pseudo_member_(value) | |||||
| if original_value < 0: | |||||
| possible_member = ~possible_member | |||||
| return possible_member | |||||
| @classmethod | |||||
| def _create_pseudo_member_(cls, value): | |||||
| """ | |||||
| Create a composite member iff value contains only members. | |||||
| """ | |||||
| pseudo_member = cls._value2member_map_.get(value, None) | |||||
| if pseudo_member is None: | |||||
| # verify all bits are accounted for | |||||
| _, extra_flags = _decompose(cls, value) | |||||
| if extra_flags: | |||||
| raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||||
| # construct a singleton enum pseudo-member | |||||
| pseudo_member = object.__new__(cls) | |||||
| pseudo_member._name_ = None | |||||
| pseudo_member._value_ = value | |||||
| # use setdefault in case another thread already created a composite | |||||
| # with this value | |||||
| pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) | |||||
| return pseudo_member | |||||
| def __contains__(self, other): | |||||
| if not isinstance(other, self.__class__): | |||||
| return NotImplemented | |||||
| return other._value_ & self._value_ == other._value_ | |||||
| def __repr__(self): | |||||
| cls = self.__class__ | |||||
| if self._name_ is not None: | |||||
| return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_) | |||||
| members, uncovered = _decompose(cls, self._value_) | |||||
| return "<%s.%s: %r>" % ( | |||||
| cls.__name__, | |||||
| "|".join([str(m._name_ or m._value_) for m in members]), | |||||
| self._value_, | |||||
| ) | |||||
| def __str__(self): | |||||
| cls = self.__class__ | |||||
| if self._name_ is not None: | |||||
| return "%s.%s" % (cls.__name__, self._name_) | |||||
| members, uncovered = _decompose(cls, self._value_) | |||||
| if len(members) == 1 and members[0]._name_ is None: | |||||
| return "%s.%r" % (cls.__name__, members[0]._value_) | |||||
| else: | |||||
| return "%s.%s" % ( | |||||
| cls.__name__, | |||||
| "|".join([str(m._name_ or m._value_) for m in members]), | |||||
| ) | |||||
| def __bool__(self): | |||||
| return bool(self._value_) | |||||
| def __or__(self, other): | |||||
| if not isinstance(other, self.__class__): | |||||
| return NotImplemented | |||||
| return self.__class__(self._value_ | other._value_) | |||||
| def __and__(self, other): | |||||
| if not isinstance(other, self.__class__): | |||||
| return NotImplemented | |||||
| return self.__class__(self._value_ & other._value_) | |||||
| def __xor__(self, other): | |||||
| if not isinstance(other, self.__class__): | |||||
| return NotImplemented | |||||
| return self.__class__(self._value_ ^ other._value_) | |||||
| def __invert__(self): | |||||
| members, uncovered = _decompose(self.__class__, self._value_) | |||||
| inverted_members = [ | |||||
| m | |||||
| for m in self.__class__ | |||||
| if m not in members and not m._value_ & self._value_ | |||||
| ] | |||||
| inverted = reduce(_or_, inverted_members, self.__class__(0)) | |||||
| return self.__class__(inverted) | |||||
| class IntFlag(int, Flag): | |||||
| """Support for integer-based Flags""" | |||||
| @classmethod | |||||
| def _missing_(cls, value): | |||||
| if not isinstance(value, int): | |||||
| raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||||
| new_member = cls._create_pseudo_member_(value) | |||||
| return new_member | |||||
| @classmethod | |||||
| def _create_pseudo_member_(cls, value): | |||||
| pseudo_member = cls._value2member_map_.get(value, None) | |||||
| if pseudo_member is None: | |||||
| need_to_create = [value] | |||||
| # get unaccounted for bits | |||||
| _, extra_flags = _decompose(cls, value) | |||||
| # timer = 10 | |||||
| while extra_flags: | |||||
| # timer -= 1 | |||||
| bit = _high_bit(extra_flags) | |||||
| flag_value = 2 ** bit | |||||
| if ( | |||||
| flag_value not in cls._value2member_map_ | |||||
| and flag_value not in need_to_create | |||||
| ): | |||||
| need_to_create.append(flag_value) | |||||
| if extra_flags == -flag_value: | |||||
| extra_flags = 0 | |||||
| else: | |||||
| extra_flags ^= flag_value | |||||
| for value in reversed(need_to_create): | |||||
| # construct singleton pseudo-members | |||||
| pseudo_member = int.__new__(cls, value) | |||||
| pseudo_member._name_ = None | |||||
| pseudo_member._value_ = value | |||||
| # use setdefault in case another thread already created a composite | |||||
| # with this value | |||||
| pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) | |||||
| return pseudo_member | |||||
| def __or__(self, other): | |||||
| if not isinstance(other, (self.__class__, int)): | |||||
| return NotImplemented | |||||
| result = self.__class__(self._value_ | self.__class__(other)._value_) | |||||
| return result | |||||
| def __and__(self, other): | |||||
| if not isinstance(other, (self.__class__, int)): | |||||
| return NotImplemented | |||||
| return self.__class__(self._value_ & self.__class__(other)._value_) | |||||
| def __xor__(self, other): | |||||
| if not isinstance(other, (self.__class__, int)): | |||||
| return NotImplemented | |||||
| return self.__class__(self._value_ ^ self.__class__(other)._value_) | |||||
| __ror__ = __or__ | |||||
| __rand__ = __and__ | |||||
| __rxor__ = __xor__ | |||||
| def __invert__(self): | |||||
| result = self.__class__(~self._value_) | |||||
| return result | |||||
| def _high_bit(value): | |||||
| """returns index of highest bit, or -1 if value is zero or negative""" | |||||
| return value.bit_length() - 1 | |||||
| def unique(enumeration): | |||||
| """Class decorator for enumerations ensuring unique member values.""" | |||||
| duplicates = [] | |||||
| for name, member in enumeration.__members__.items(): | |||||
| if name != member.name: | |||||
| duplicates.append((name, member.name)) | |||||
| if duplicates: | |||||
| alias_details = ", ".join( | |||||
| ["%s -> %s" % (alias, name) for (alias, name) in duplicates] | |||||
| ) | |||||
| raise ValueError( | |||||
| "duplicate values found in %r: %s" % (enumeration, alias_details) | |||||
| ) | |||||
| return enumeration | |||||
| def _decompose(flag, value): | |||||
| """Extract all members from the value.""" | |||||
| # _decompose is only called if the value is not named | |||||
| not_covered = value | |||||
| negative = value < 0 | |||||
| # issue29167: wrap accesses to _value2member_map_ in a list to avoid race | |||||
| # conditions between iterating over it and having more psuedo- | |||||
| # members added to it | |||||
| if negative: | |||||
| # only check for named flags | |||||
| flags_to_check = [ | |||||
| (m, v) | |||||
| for v, m in list(flag._value2member_map_.items()) | |||||
| if m.name is not None | |||||
| ] | |||||
| else: | |||||
| # check for named flags and powers-of-two flags | |||||
| flags_to_check = [ | |||||
| (m, v) | |||||
| for v, m in list(flag._value2member_map_.items()) | |||||
| if m.name is not None or _power_of_two(v) | |||||
| ] | |||||
| members = [] | |||||
| for member, member_value in flags_to_check: | |||||
| if member_value and member_value & value == member_value: | |||||
| members.append(member) | |||||
| not_covered &= ~member_value | |||||
| if not members and value in flag._value2member_map_: | |||||
| members.append(flag._value2member_map_[value]) | |||||
| members.sort(key=lambda m: m._value_, reverse=True) | |||||
| if len(members) > 1 and members[0].value == value: | |||||
| # we have the breakdown, don't need the value member itself | |||||
| members.pop(0) | |||||
| return members, not_covered | |||||
| def _power_of_two(value): | |||||
| if value < 1: | |||||
| return False | |||||
| return value == 2 ** _high_bit(value) | |||||
| @@ -0,0 +1,94 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import warnings | |||||
| from ..._imperative_rt.ops import OprAttr | |||||
| from . import param_defs | |||||
| def make_param(param, ptype, kwargs): | |||||
| if param is not None: | |||||
| if isinstance(param, ptype): | |||||
| return param | |||||
| param = [param] | |||||
| assert len(param) == len( | |||||
| ptype.__slots__ | |||||
| ), "{} needs {} params, but {} are provided".format( | |||||
| ptype, len(ptype.__slots__), len(param) | |||||
| ) | |||||
| return ptype(*param) | |||||
| ckw = {} | |||||
| for i in ptype.__slots__: | |||||
| val = kwargs.pop(i, ckw) | |||||
| if val is not ckw: | |||||
| ckw[i] = val | |||||
| return ptype(**ckw) | |||||
| class PodOpVisitor: | |||||
| __name2subclass = {} | |||||
| __c = None | |||||
| name = None | |||||
| param_names = [] | |||||
| config = None | |||||
| def __init__(self, config, **params): | |||||
| self.config = config | |||||
| assert set(params) == set(self.param_names) | |||||
| self.__dict__.update(params) | |||||
| def __init_subclass__(cls, **kwargs): | |||||
| super().__init_subclass__(**kwargs) # python 3.5 does not have this | |||||
| name = cls.name | |||||
| if name in cls.__name2subclass: | |||||
| if not issubclass(cls, cls.__name2subclass[name]): | |||||
| warnings.warn("Multiple subclasses for bultin op: %s" % name) | |||||
| cls.__name2subclass[name] = cls | |||||
| def to_c(self): | |||||
| if self.__c: | |||||
| return self.__c | |||||
| op = OprAttr() | |||||
| op.type = self.name | |||||
| if self.config is not None: | |||||
| op.config = self.config | |||||
| # first 4 bytes is TAG, has to remove them currently | |||||
| op.param = b"".join(self.__dict__[k].serialize()[4:] for k in self.param_names) | |||||
| self.__c = op | |||||
| return op | |||||
| def __eq__(self, rhs): | |||||
| return self.to_c() == rhs.to_c() | |||||
| def __repr__(self): | |||||
| name = self.__class__.__name__ | |||||
| if self.__c: | |||||
| return "{}(<binary data>)".format(name) | |||||
| kwargs = {} | |||||
| for i in self.param_names: | |||||
| p = self.__dict__[i] | |||||
| if isinstance(p, param_defs._ParamDefBase): | |||||
| for k in p.__slots__: | |||||
| v = getattr(p, k) | |||||
| if isinstance(v, param_defs._EnumBase): | |||||
| v = v.name | |||||
| kwargs[k] = repr(v) | |||||
| else: | |||||
| kwargs[i] = repr(p) | |||||
| if self.config: | |||||
| if len(self.config.comp_node_arr) == 1: | |||||
| kwargs["device"] = "'%s'" % self.config.comp_node | |||||
| return "{}({})".format( | |||||
| name, ", ".join("{}={}".format(k, v) for k, v in kwargs.items()) | |||||
| ) | |||||
| @@ -0,0 +1,194 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import ctypes | |||||
| from ..._imperative_rt import OperatorNodeConfig as Config | |||||
| from . import param_defs | |||||
| from .helper import PodOpVisitor, make_param | |||||
| __all__ = ["ConvolutionBackwardData", "Dimshuffle", "Reshape", "AxisAddRemove"] | |||||
| class TensorShape: | |||||
| MAX_NDIM = 7 | |||||
| class ConvolutionBackwardData(PodOpVisitor): | |||||
| param_names = ( | |||||
| "param", | |||||
| "execution_polity", | |||||
| ) | |||||
| name = "ConvolutionBackwardDataV1" | |||||
| def __init__( | |||||
| self, | |||||
| *, | |||||
| param=None, | |||||
| execution_polity=None, | |||||
| name=None, | |||||
| comp_node=None, | |||||
| config=None, | |||||
| dtype=None, | |||||
| **kwargs | |||||
| ): | |||||
| config = config or Config() | |||||
| if name: | |||||
| config.name = name | |||||
| if comp_node: | |||||
| config.comp_node = comp_node | |||||
| if dtype: | |||||
| config.dtype = dtype | |||||
| self.config = config | |||||
| self.param = make_param(param, param_defs.Convolution, kwargs) | |||||
| self.execution_polity = make_param( | |||||
| execution_polity, param_defs.ExecutionPolicy, kwargs | |||||
| ) | |||||
| assert not kwargs, "extra kwargs: {}".format(kwargs) | |||||
| class Dimshuffle(PodOpVisitor): | |||||
| name = "Dimshuffle" | |||||
| param_names = ("pattern",) | |||||
| class Pattern(ctypes.Structure): | |||||
| Pattern_Array = ctypes.c_int32 * TensorShape.MAX_NDIM | |||||
| _fields_ = [ | |||||
| ("length", ctypes.c_uint32), | |||||
| ("pattern", Pattern_Array), | |||||
| ("ndim", ctypes.c_uint32), | |||||
| ] | |||||
| def serialize(self): | |||||
| return bytes(ctypes.c_uint32(0)) + bytes(self) | |||||
| def __init__(self, pattern, ndim=0): | |||||
| assert isinstance(pattern, collections.Iterable) | |||||
| assert len(pattern) <= TensorShape.MAX_NDIM | |||||
| pattern_array = Dimshuffle.Pattern.Pattern_Array() | |||||
| for idx, v in enumerate(pattern): | |||||
| pattern_array[idx] = ctypes.c_int32(-1 if v == "x" else int(v)) | |||||
| self.pattern = Dimshuffle.Pattern(len(pattern), pattern_array, ndim) | |||||
| class Reshape(PodOpVisitor): | |||||
| name = "ReshapeV1" | |||||
| param_names = ("unspec_axis",) | |||||
| def __init__(self, unspec_axis=None): | |||||
| if unspec_axis is None: | |||||
| self.unspec_axis = param_defs.OptionalAxisV1() | |||||
| else: | |||||
| self.unspec_axis = param_defs.OptionalAxisV1(unspec_axis) | |||||
| class AxisNum(ctypes.Structure): | |||||
| _fields_ = [ | |||||
| ("m_num", ctypes.c_int), | |||||
| ] | |||||
| class AxisDesc(ctypes.Structure): | |||||
| class Method(ctypes.c_int): | |||||
| ADD_1 = 0 | |||||
| REMOVE = 1 | |||||
| _fields_ = [ | |||||
| ("method", Method), | |||||
| ("axis", AxisNum), | |||||
| ] | |||||
| @classmethod | |||||
| def make_add(cls, axis): | |||||
| return cls(cls.Method.ADD_1, AxisNum(axis)) | |||||
| @classmethod | |||||
| def make_remove(cls, axis): | |||||
| return cls(cls.Method.REMOVE, AxisNum(axis)) | |||||
| class AxisAddRemove(PodOpVisitor): | |||||
| name = "AxisAddRemove" | |||||
| param_names = ("param",) | |||||
| AxisDesc = AxisDesc | |||||
| class Param(ctypes.Structure): | |||||
| MAX_DESC_SIZE = TensorShape.MAX_NDIM * 2 | |||||
| _fields_ = [("nr_desc", ctypes.c_uint32), ("desc", AxisDesc * MAX_DESC_SIZE)] | |||||
| def __init__(self, *args): | |||||
| super().__init__() | |||||
| self.nr_desc = len(args) | |||||
| for i, a in enumerate(args): | |||||
| self.desc[i] = a | |||||
| def serialize(self): | |||||
| return bytes(ctypes.c_uint32(0)) + bytes(self) | |||||
| def __init__(self, param): | |||||
| assert isinstance(param, self.Param) | |||||
| self.param = param | |||||
| del AxisDesc | |||||
| class IndexingOpBase(PodOpVisitor): | |||||
| param_names = ("index_desc",) | |||||
| class IndexDescMaskDump(ctypes.Structure): | |||||
| class Item(ctypes.Structure): | |||||
| _fields_ = [ | |||||
| ("axis", ctypes.c_int8), | |||||
| ("begin", ctypes.c_bool), | |||||
| ("end", ctypes.c_bool), | |||||
| ("step", ctypes.c_bool), | |||||
| ("idx", ctypes.c_bool), | |||||
| ] | |||||
| Item_Array = Item * TensorShape.MAX_NDIM | |||||
| _fields_ = [("nr_item", ctypes.c_uint8), ("items", Item_Array)] | |||||
| def serialize(self): | |||||
| return bytes(ctypes.c_uint32(0)) + bytes(self) | |||||
| def __init__(self, items): | |||||
| nr_item = len(items) | |||||
| assert nr_item <= TensorShape.MAX_NDIM | |||||
| item_array = IndexingOpBase.IndexDescMaskDump.Item_Array() | |||||
| for idx, item in enumerate(items): | |||||
| assert isinstance(item, (tuple, list)) and len(item) == 5 | |||||
| item_array[idx] = IndexingOpBase.IndexDescMaskDump.Item(*item) | |||||
| self.index_desc = IndexingOpBase.IndexDescMaskDump(nr_item, item_array) | |||||
| def _gen_indexing_defs(*names): | |||||
| for name in names: | |||||
| globals()[name] = type(name, (IndexingOpBase,), dict(name=name)) | |||||
| __all__.append(name) | |||||
| _gen_indexing_defs( | |||||
| "Subtensor", | |||||
| "SetSubtensor", | |||||
| "IncrSubtensor", | |||||
| "IndexingMultiAxisVec", | |||||
| "IndexingSetMultiAxisVec", | |||||
| "IndexingIncrMultiAxisVec", | |||||
| "MeshIndexing", | |||||
| "IncrMeshIndexing", | |||||
| "SetMeshIndexing", | |||||
| "BatchedMeshIndexing", | |||||
| "BatchedIncrMeshIndexing", | |||||
| "BatchedSetMeshIndexing", | |||||
| ) | |||||
| @@ -0,0 +1,37 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import warnings | |||||
| from typing import Union | |||||
| from ..._imperative_rt import OpDef, ops | |||||
| from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||||
| from .._internal import all_ops | |||||
| from .._internal.helper import PodOpVisitor | |||||
| # register OpDef as a "virtual subclass" of OpBase, so any of registered | |||||
| # apply(OpBase, ...) rules could work well on OpDef | |||||
| OpBase.register(OpDef) | |||||
| # forward to apply(OpDef, ...) | |||||
| @apply.add | |||||
| def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): | |||||
| return apply(op.to_c(), *args) | |||||
| __all__ = ["OpDef", "PodOpVisitor"] | |||||
| for k, v in all_ops.__dict__.items(): | |||||
| if isinstance(v, type) and issubclass(v, PodOpVisitor): | |||||
| globals()[k] = v | |||||
| __all__.append(k) | |||||
| for k, v in ops.__dict__.items(): | |||||
| if isinstance(v, type) and issubclass(v, OpDef): | |||||
| globals()[k] = v | |||||
| __all__.append(k) | |||||
| @@ -0,0 +1,16 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ..tensor.core import OpBase, TensorBase, apply | |||||
| class Const(OpBase): | |||||
| def __init__(self, value=None, *, dtype=None, device=None): | |||||
| self.value = value | |||||
| self.dtype = dtype | |||||
| self.device = device | |||||
| @@ -0,0 +1,9 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .tensor_wrapper import TensorWrapper as Tensor | |||||
| @@ -0,0 +1,115 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import functools | |||||
| import inspect | |||||
| import sys | |||||
| import typing | |||||
| from abc import ABC | |||||
| import multipledispatch | |||||
| class OpBase(ABC): | |||||
| def __call__(self, *args): | |||||
| return apply(self, *args) | |||||
| class TensorBase: | |||||
| pass | |||||
| class TensorWrapperBase: | |||||
| pass | |||||
| class Dispatcher(multipledispatch.Dispatcher): | |||||
| def add(self, f, g=None): | |||||
| if g is None: | |||||
| super().add(get_signature(f), f) | |||||
| else: | |||||
| super().add(f, g) | |||||
| return f | |||||
| def __get__(self, instance, owner=None): | |||||
| if instance is not None: | |||||
| return self | |||||
| return functools.partial(self, instance) | |||||
| if sys.version_info < (3, 6): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing.UnionMeta: | |||||
| return | |||||
| return ann.__union_params__ | |||||
| elif sys.version_info < (3, 7): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing._Union: | |||||
| return | |||||
| return ann.__args__ | |||||
| elif sys.version_info < (3, 8): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing._GenericAlias: | |||||
| if type(ann) is not typing.Union: | |||||
| return | |||||
| else: | |||||
| if ann.__origin__ is not typing.Union: | |||||
| return | |||||
| return ann.__args__ | |||||
| else: | |||||
| def parse_union(ann): | |||||
| if typing.get_origin(ann) is not typing.Union: | |||||
| return | |||||
| return typing.get_args(ann) | |||||
| def get_signature(function, op_type=None): | |||||
| sig = inspect.signature(function) | |||||
| types = [] | |||||
| for p in sig.parameters.values(): | |||||
| ann = p.annotation | |||||
| ann = parse_union(ann) or ann | |||||
| if p.kind in ( | |||||
| inspect.Parameter.POSITIONAL_ONLY, | |||||
| inspect.Parameter.POSITIONAL_OR_KEYWORD, | |||||
| ): | |||||
| types.append(ann) | |||||
| if p.kind == inspect.Parameter.VAR_POSITIONAL: | |||||
| types.append([ann]) | |||||
| return tuple(types) | |||||
| apply = Dispatcher("apply") | |||||
| OpBase.apply = apply | |||||
| @apply.add | |||||
| def _(op: OpBase, *args: TensorBase): | |||||
| raise NotImplementedError | |||||
| @apply.add | |||||
| def _(op: OpBase, *args: TensorWrapperBase): | |||||
| assert args | |||||
| Wrapper = type(args[0]) | |||||
| outputs = apply(op, *(i.__wrapped__ for i in args)) | |||||
| assert isinstance(outputs, tuple) | |||||
| return tuple(map(Wrapper, outputs)) | |||||
| @@ -0,0 +1,289 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| from typing import Union | |||||
| import numpy as np | |||||
| # normal dtype related | |||||
| from .._imperative_rt import bfloat16, intb1, intb2, intb4 | |||||
| def is_lowbit(dtype): | |||||
| return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) | |||||
| def is_bfloat16(dtype): | |||||
| return dtype is bfloat16 | |||||
| # quantization dtype related | |||||
| _QuantDtypeMetadata = collections.namedtuple( | |||||
| "QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] | |||||
| ) | |||||
| _metadata_dict = { | |||||
| "quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), | |||||
| "qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), | |||||
| "quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), | |||||
| "qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), | |||||
| "qint32": _QuantDtypeMetadata( | |||||
| "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, | |||||
| ), | |||||
| # NOTE: int2 is not supported for model dump yet | |||||
| "quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3), | |||||
| "qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1), | |||||
| } | |||||
| def is_quantize(dtype): | |||||
| return ( | |||||
| hasattr(dtype, "metadata") | |||||
| and dtype.metadata is not None | |||||
| and "mgb_dtype" in dtype.metadata | |||||
| ) | |||||
| def get_scale(dtype): | |||||
| assert is_quantize(dtype) | |||||
| return dtype.metadata["mgb_dtype"]["scale"] | |||||
| def get_zero_point(dtype): | |||||
| assert is_quantize(dtype) | |||||
| metadata = dtype.metadata["mgb_dtype"] | |||||
| assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") | |||||
| return metadata["zero_point"] | |||||
| def _check_zero_point(zp: int, dtype_str: str): | |||||
| qmin = _metadata_dict[dtype_str].qmin | |||||
| qmax = _metadata_dict[dtype_str].qmax | |||||
| if zp < qmin or zp > qmax: | |||||
| raise ValueError( | |||||
| "zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) | |||||
| ) | |||||
| def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): | |||||
| r""" | |||||
| Get quantized dtype with metadata attribute according to _metadata_dict. | |||||
| Note that unsigned dtype must have ``zero_point`` and signed dtype must | |||||
| not have ``zero_point``, to be consitent with tensor generated by calling | |||||
| compiled function from `CompGraph.compile(inputs, outspec)`. | |||||
| :param dtype: a string indicating which dtype to return | |||||
| :param scale: a number for scale to store in dtype's metadata | |||||
| :param zp: a number for zero_point to store in dtype's metadata | |||||
| """ | |||||
| metadata = _metadata_dict[dtype_str] | |||||
| np_dtype_str = metadata.np_dtype_str | |||||
| is_unsigned = metadata.is_unsigned | |||||
| if is_unsigned: | |||||
| if zp is None or int(zp) != zp: | |||||
| raise ValueError("zero_point should be an integer") | |||||
| zp = int(zp) | |||||
| _check_zero_point(zp, dtype_str) | |||||
| return np.dtype( | |||||
| np_dtype_str, | |||||
| metadata={ | |||||
| "mgb_dtype": { | |||||
| "name": metadata.name, | |||||
| "scale": float(scale), | |||||
| "zero_point": zp, | |||||
| } | |||||
| }, | |||||
| ) | |||||
| else: | |||||
| return np.dtype( | |||||
| np_dtype_str, | |||||
| metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}}, | |||||
| ) | |||||
| def quint8(scale, zero_point): | |||||
| """ | |||||
| Consturct a quantized unsigned int8 data type with ``scale`` (float) and | |||||
| ``zero_point`` (uint8). The real value represented by a quint8 data type is | |||||
| float_val = scale * (uint8_val - zero_point) | |||||
| """ | |||||
| return get_quantized_dtype("quint8", scale, zero_point) | |||||
| def qint8(scale): | |||||
| """ | |||||
| Construct a quantized int8 data type with ``scale`` (float). The real value | |||||
| represented by a qint8 data type is float_val = scale * int8_val | |||||
| """ | |||||
| return get_quantized_dtype("qint8", scale, None) | |||||
| def qint32(scale): | |||||
| """ | |||||
| Construct a quantized int32 data type with ``scale`` (float). The real value | |||||
| represented by a qint32 data type is float_val = scale * int32_val | |||||
| """ | |||||
| return get_quantized_dtype("qint32", scale, None) | |||||
| def quint4(scale, zero_point): | |||||
| """ | |||||
| Consturct a quantized unsigned int4 data type with ``scale`` (float) and | |||||
| ``zero_point`` (uint8). The real value represented by a quint4 data type is | |||||
| float_val = scale * (uint4_val - zero_point) | |||||
| """ | |||||
| return get_quantized_dtype("quint4", scale, zero_point) | |||||
| def qint4(scale): | |||||
| """ | |||||
| Construct a quantized int4 data type with ``scale`` (float). The real value | |||||
| represented by a qint4 data type is float_val = scale * int4_val | |||||
| """ | |||||
| return get_quantized_dtype("qint4", scale, None) | |||||
| def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): | |||||
| metadata = _metadata_dict[dtype_str] | |||||
| arr_metadata = dtype.metadata["mgb_dtype"] | |||||
| if not isinstance(arr, np.ndarray): | |||||
| raise ValueError("arr parameter should be instance of np.ndarray") | |||||
| if not is_quantize(dtype) or arr_metadata["name"] != metadata.name: | |||||
| raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) | |||||
| is_unsigned = metadata.is_unsigned | |||||
| if is_unsigned: | |||||
| scale, zp = ( | |||||
| arr_metadata["scale"], | |||||
| arr_metadata["zero_point"], | |||||
| ) | |||||
| return ( | |||||
| (np.round(arr / scale) + zp) | |||||
| .clip(metadata.qmin, metadata.qmax) | |||||
| .astype(dtype) | |||||
| ) | |||||
| else: | |||||
| # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` | |||||
| scale = arr_metadata["scale"] | |||||
| return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype) | |||||
| def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str): | |||||
| metadata = _metadata_dict[dtype_str] | |||||
| arr_metadata = arr.dtype.metadata["mgb_dtype"] | |||||
| if not isinstance(arr, np.ndarray): | |||||
| raise ValueError("arr parameter should be instance of np.ndarray") | |||||
| if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name: | |||||
| raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) | |||||
| is_unsigned = metadata.is_unsigned | |||||
| if is_unsigned: | |||||
| scale, zp = ( | |||||
| arr_metadata["scale"], | |||||
| arr_metadata["zero_point"], | |||||
| ) | |||||
| return (arr.astype(np.float32) - zp) * scale | |||||
| else: | |||||
| # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` | |||||
| scale = arr_metadata["scale"] | |||||
| return (arr.astype(np.float32)) * scale | |||||
| def convert_to_quint8(arr: np.ndarray, q: np.dtype): | |||||
| """ | |||||
| Quantize a float NumPy ndarray into a quint8 one with specified params. | |||||
| :param arr: Input ndarray. | |||||
| :param q: Target data type, should be a quint8. | |||||
| """ | |||||
| return _convert_to_quantized_dtype(arr, q, "quint8") | |||||
| def convert_from_quint8(arr: np.ndarray): | |||||
| """ | |||||
| Dequantize a quint8 NumPy ndarray into a float one. | |||||
| :param arr: Input ndarray. | |||||
| """ | |||||
| return _convert_from_quantized_dtype(arr, "quint8") | |||||
| def convert_to_qint8(arr: np.ndarray, q: np.dtype): | |||||
| """ | |||||
| Quantize a float NumPy ndarray into a qint8 one with specified params. | |||||
| :param arr: Input ndarray. | |||||
| :param q: Target data type, should be a qint8. | |||||
| """ | |||||
| return _convert_to_quantized_dtype(arr, q, "qint8") | |||||
| def convert_from_qint8(arr: np.ndarray): | |||||
| """ | |||||
| Dequantize a qint8 NumPy ndarray into a float one. | |||||
| :param arr: Input ndarray. | |||||
| """ | |||||
| return _convert_from_quantized_dtype(arr, "qint8") | |||||
| def convert_to_qint32(arr: np.ndarray, q: np.dtype): | |||||
| """ | |||||
| Quantize a float NumPy ndarray into a qint32 one with specified params. | |||||
| :param arr: Input ndarray. | |||||
| :param q: Target data type, should be a qint8. | |||||
| """ | |||||
| return _convert_to_quantized_dtype(arr, q, "qint32") | |||||
| def convert_from_qint32(arr): | |||||
| """ | |||||
| Dequantize a qint32 NumPy ndarray into a float one. | |||||
| :param arr: Input ndarray. | |||||
| """ | |||||
| return _convert_from_quantized_dtype(arr, "qint32") | |||||
| def convert_to_quint4(arr: np.ndarray, q: np.dtype): | |||||
| """ | |||||
| Quantize a float NumPy ndarray into a quint4 one with specified params. | |||||
| :param arr: Input ndarray. | |||||
| :param q: Target data type, should be a quint4. | |||||
| """ | |||||
| return _convert_to_quantized_dtype(arr, q, "quint4") | |||||
| def convert_from_quint4(arr: np.ndarray): | |||||
| """ | |||||
| Dequantize a quint4 NumPy ndarray into a float one. | |||||
| :param arr: Input ndarray. | |||||
| """ | |||||
| return _convert_from_quantized_dtype(arr, "quint4") | |||||
| def convert_to_qint4(arr: np.ndarray, q: np.dtype): | |||||
| """ | |||||
| Quantize a float NumPy ndarray into a qint4 one with specified params. | |||||
| :param arr: Input ndarray. | |||||
| :param q: Target data type, should be a qint4. | |||||
| """ | |||||
| return _convert_to_quantized_dtype(arr, q, "qint4") | |||||
| def convert_from_qint4(arr: np.ndarray): | |||||
| """ | |||||
| Dequantize a qint4 NumPy ndarray into a float one. | |||||
| :param arr: Input ndarray. | |||||
| """ | |||||
| return _convert_from_quantized_dtype(arr, "qint4") | |||||
| @@ -0,0 +1,158 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ..ops.builtin import OpDef | |||||
| from .core import TensorBase, TensorWrapperBase, apply | |||||
| from .raw_tensor import RawTensor | |||||
| from .tensor import Tensor, push_context | |||||
| from .tensor_wrapper import TensorWrapper | |||||
| class Function: | |||||
| """ | |||||
| Defines a block of operations with customizable differentiation. | |||||
| The computation should be defined in ``forward`` method, with gradient | |||||
| computation defined in ``backward`` method. | |||||
| Each instance of ``Function`` should be used only once during forwardding. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| class Sigmoid(Function): | |||||
| def forward(self, x): | |||||
| y = 1 / (1 + F.exp(-x)) | |||||
| self.y = y | |||||
| return y | |||||
| def backward(self. output_grads): | |||||
| y = self.y | |||||
| return output_grads * y * (1-y) | |||||
| """ | |||||
| def __init__(self, *args, **kwargs): | |||||
| pass | |||||
| def __call__(self, *args): | |||||
| ret = apply(self, *args) | |||||
| if type(ret) == tuple and len(ret) == 1: | |||||
| return ret[0] | |||||
| return ret | |||||
| def forward(self, *args, **kwargs): | |||||
| """ | |||||
| Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses. | |||||
| :param input: Input tensors. | |||||
| :return: A tuple of Tensor or a single Tensor. | |||||
| .. note:: | |||||
| This method should return a tuple of Tensor or a single Tensor representing the output | |||||
| of the function. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def backward(self, *output_grads): | |||||
| """ | |||||
| Compute the gradient of the forward function. It must be overriden by all subclasses. | |||||
| :param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward` | |||||
| .. note:: | |||||
| In case when some tensors of outputs are not related to loss function, the corresponding | |||||
| values in ``output_grads`` would be ``None``. | |||||
| .. note:: | |||||
| This method should return a tuple which containing the gradients of all inputs, in the same order | |||||
| as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned | |||||
| instead if there is only one input. If users want to stop the propagation of some gradients, | |||||
| the corresponding returned values should be set ``None`` . | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def get_backward_fn(self): | |||||
| if self.backward is None: | |||||
| return None | |||||
| def _backward(*output_grads): | |||||
| if type(output_grads) is tuple: | |||||
| _output_grads = map(TensorWrapper, output_grads) | |||||
| else: | |||||
| _output_grads = (TensorWrapper(output_grads),) | |||||
| ret = self.backward(*_output_grads) | |||||
| if type(ret) is not tuple: | |||||
| ret = (ret,) | |||||
| ret = tuple([i.__wrapped__ for i in ret]) | |||||
| return ret | |||||
| return _backward | |||||
| Function.apply = Function.__call__ | |||||
| @apply.add | |||||
| def _(op: Function, *args: TensorWrapperBase): | |||||
| assert args | |||||
| Wrapper = type(args[0]) | |||||
| # compute the value for self define function | |||||
| extra_data_dic = {} | |||||
| for arg in args: | |||||
| extra_data_dic[arg.__wrapped__] = arg.__wrapped__._extra_data | |||||
| arg.__wrapped__._extra_data = {} | |||||
| rets = op.forward(*args) | |||||
| for arg in args: | |||||
| arg.__wrapped__._extra_data = extra_data_dic[arg.__wrapped__] | |||||
| # update the gradient information for self define function | |||||
| inputs = tuple(map(lambda i: i.__wrapped__, args)) | |||||
| outputs = ( | |||||
| tuple(map(lambda i: i.__wrapped__, rets)) | |||||
| if type(rets) is tuple | |||||
| else (rets.__wrapped__,) | |||||
| ) | |||||
| for output in outputs: | |||||
| output._extra_data = {} | |||||
| with push_context() as ctx: | |||||
| ctx.inputs = inputs | |||||
| ctx.outputs = outputs | |||||
| for k in set().union(*(i._extra_data for i in inputs if isinstance(i, Tensor))): | |||||
| ctx.key = k | |||||
| data = tuple( | |||||
| i._extra_data.get(k) if isinstance(i, Tensor) else i for i in inputs | |||||
| ) | |||||
| # data are instances of Tracer | |||||
| # dispatched to apply.add@grad.py | |||||
| rets = apply(op, *data) | |||||
| if rets is not None: | |||||
| assert len(outputs) == len(rets) | |||||
| for t, i in zip(outputs, rets): | |||||
| t._extra_data[k] = i | |||||
| return tuple(map(Wrapper, outputs)) | |||||
| @apply.add | |||||
| def _(op: Function, *args: Tensor): | |||||
| raise NotImplementedError | |||||
| @apply.add | |||||
| def _(op: Function, *args: RawTensor): | |||||
| raise NotImplementedError | |||||
| @@ -0,0 +1,251 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| from ..ops import builtin | |||||
| from ..ops.special import Const | |||||
| from .core import TensorBase, TensorWrapperBase, apply | |||||
| def remove_ellipsis(tensor, tuple_val): | |||||
| ndim_sum = tensor.ndim | |||||
| cur_sum = 0 | |||||
| pos = -1 | |||||
| for i_idx, i in enumerate(tuple_val): | |||||
| if i is Ellipsis: | |||||
| for j in tuple_val[:i_idx:-1]: | |||||
| if j is Ellipsis: | |||||
| raise IndexError("only one ellipsis is allowed") | |||||
| pos = i_idx | |||||
| else: | |||||
| cur_sum += i.ndim if hasattr(i, "ndim") else 1 | |||||
| if pos == -1: | |||||
| return tuple_val | |||||
| else: | |||||
| return ( | |||||
| tuple_val[:pos] | |||||
| + (slice(None, None, None),) * (ndim_sum - cur_sum) | |||||
| + tuple_val[pos + 1 :] | |||||
| ) | |||||
| def check_bool_index(tensor, tuple_val): | |||||
| cur_shape = tensor.shape | |||||
| new_tuple_val = [] | |||||
| offset = 0 | |||||
| tdim = 0 | |||||
| for idx, i in enumerate(tuple_val): | |||||
| if hasattr(i, "dtype") and i.dtype == np.bool_: | |||||
| if i.ndim > 1: | |||||
| tot = i.ndim | |||||
| for j in range(i.ndim): | |||||
| if cur_shape[tdim + j - offset] != i.shape[j]: | |||||
| raise IndexError( | |||||
| "boolean index did not match tensor along dimension {}; dimension is {} but corresponding boolean dimension is {}".format( | |||||
| tdim + j, cur_shape[tdim + j - offset], i.shape[j] | |||||
| ) | |||||
| ) | |||||
| i = i.reshape(-1) | |||||
| cur_shape = ( | |||||
| cur_shape[:idx] + (i.shape[0],) + cur_shape[tdim + tot - offset :] | |||||
| ) | |||||
| offset += 1 | |||||
| tensor = tensor.reshape(cur_shape) | |||||
| tdim += tot | |||||
| new_tuple_val.append(i) | |||||
| else: | |||||
| new_tuple_val.append(i) | |||||
| tdim += 1 | |||||
| return tensor, new_tuple_val | |||||
| def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
| if not isinstance(tuple_val, tuple): | |||||
| tuple_val = (tuple_val,) | |||||
| ndim_indexed = 0 | |||||
| for i in tuple_val: | |||||
| if not i is Ellipsis: | |||||
| ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim | |||||
| if ndim_indexed > inp.ndim: | |||||
| raise IndexError( | |||||
| "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( | |||||
| inp.ndim, ndim_indexed | |||||
| ) | |||||
| ) | |||||
| tuple_val = remove_ellipsis(inp, tuple_val) | |||||
| use_subtensor = True | |||||
| inp, tuple_val = check_bool_index(inp, tuple_val) | |||||
| def is_scalar(d): | |||||
| if isinstance(i, int): | |||||
| return True | |||||
| if type(d).__module__ == np.__name__: | |||||
| return np.isscalar(d) | |||||
| # if isinstance(d, (TensorBase, TensorWrapperBase)): | |||||
| # return d.shape == (1,) | |||||
| return False | |||||
| new_axes = [] | |||||
| tensors = [] | |||||
| items = [] | |||||
| cur_axis = -1 | |||||
| for i_idx, i in enumerate(tuple_val): | |||||
| cur_axis += 1 | |||||
| if i is np.newaxis: | |||||
| if cur_axis >= 0: | |||||
| new_axes.append(cur_axis) | |||||
| continue | |||||
| if i is Ellipsis: | |||||
| cur_axis = -1 | |||||
| for j in tuple_val[:i_idx:-1]: | |||||
| if j is Ellipsis: | |||||
| raise IndexError("only one ellipsis is allowed") | |||||
| if j is np.newaxis: | |||||
| new_axes.append(cur_axis) | |||||
| cur_axis -= 1 | |||||
| continue | |||||
| if ( | |||||
| not is_scalar(i) | |||||
| and not i is np.newaxis | |||||
| and not i is Ellipsis | |||||
| and not isinstance(i, slice) | |||||
| ): | |||||
| use_subtensor = False | |||||
| item = [ | |||||
| cur_axis, | |||||
| ] | |||||
| def is_bool_list(x): | |||||
| if not isinstance(x, list): | |||||
| return False | |||||
| for i in x: | |||||
| if not isinstance(i, bool): | |||||
| return False | |||||
| return True | |||||
| def get_index(i): | |||||
| if not isinstance(i, (TensorBase, TensorWrapperBase)): | |||||
| if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | |||||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||||
| else: | |||||
| (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||||
| return i | |||||
| assert isinstance(i, (TensorBase, TensorWrapperBase)) | |||||
| if i.dtype != np.bool_: | |||||
| return i | |||||
| _, ind = apply(builtin.CondTake(), i, i) | |||||
| return ind | |||||
| def push(v, item, tensors): | |||||
| if v is None: | |||||
| item.append(False) | |||||
| else: | |||||
| item.append(True) | |||||
| v = get_index(v) | |||||
| assert np.issubdtype(v.dtype, np.integer) or np.issubdtype( | |||||
| v.dtype, np.bool | |||||
| ), "var type in the subscript must be int or bool" | |||||
| tensors.append(v) | |||||
| if isinstance(i, slice): | |||||
| if i.start is None and i.stop is None and i.step is None: | |||||
| continue | |||||
| push(i.start, item, tensors) | |||||
| push(i.stop, item, tensors) | |||||
| push(i.step, item, tensors) | |||||
| item.append(False) # idx | |||||
| else: | |||||
| item += [False,] * 3 # begin, end, stop | |||||
| push(i, item, tensors) | |||||
| assert len(item) == 5 | |||||
| items.append(item) | |||||
| if new_axes: | |||||
| raise IndexError("newaxis is not allowed here") | |||||
| return inp, tensors, items, use_subtensor | |||||
| def try_condtake(tensor, index): | |||||
| if not hasattr(index, "dtype") or not hasattr(index, "shape"): | |||||
| return [] | |||||
| if index.dtype != np.bool_ or index.shape != tensor.shape: | |||||
| return [] | |||||
| if isinstance(index, np.ndarray): | |||||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||||
| assert isinstance(index, (TensorBase, TensorWrapperBase)) | |||||
| if not isinstance(tensor, (TensorWrapperBase, TensorBase)): | |||||
| raise TypeError("input must be a tensor") | |||||
| if tensor.device != index.device: | |||||
| raise ValueError( | |||||
| "ambiguous device: {} vs {}".format(tensor.device, index.device) | |||||
| ) | |||||
| return apply(builtin.CondTake(), tensor, index) | |||||
| def getitem(tensor, index): | |||||
| try_result = try_condtake(tensor, index) | |||||
| if len(try_result) == 2: | |||||
| return try_result[0] | |||||
| tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) | |||||
| for v in tensors: | |||||
| if v.shape[0] == 0: | |||||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||||
| tensor | |||||
| ) | |||||
| return empty_tensor | |||||
| if use_subtensor: | |||||
| op = builtin.Subtensor(items=items) | |||||
| else: | |||||
| op = builtin.IndexingMultiAxisVec(items=items) | |||||
| (result,) = apply(op, tensor, *tensors) | |||||
| return result | |||||
| def setitem(tensor, index, value): | |||||
| org_shape = tensor.shape | |||||
| try_result = try_condtake(tensor, index) | |||||
| if len(try_result) == 2: | |||||
| index = try_result[1] | |||||
| if index.shape[0] == 0: | |||||
| return tensor | |||||
| tensor = tensor.reshape(-1) | |||||
| if not isinstance(value, (TensorBase, TensorWrapperBase)): | |||||
| op = Const(value, dtype=tensor.dtype, device=tensor.device) | |||||
| (value,) = op(tensor) | |||||
| tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) | |||||
| for v in tensors: | |||||
| if v.shape[0] == 0: | |||||
| return tensor | |||||
| if use_subtensor: | |||||
| op = builtin.Subtensor(items=items) | |||||
| else: | |||||
| op = builtin.IndexingMultiAxisVec(items=items) | |||||
| (tmp_result,) = apply(op, tensor, *tensors) | |||||
| if value.shape != tmp_result.shape: | |||||
| for i in range(min(len(value.shape), len(tmp_result.shape))): | |||||
| if ( | |||||
| value.shape[-i - 1] != 1 | |||||
| and value.shape[-i - 1] != tmp_result.shape[-i - 1] | |||||
| ): | |||||
| raise ValueError( | |||||
| "cannot copy tensor with shape {} to subtensor with shape {}".format( | |||||
| value.shape, tmp_result.shape | |||||
| ) | |||||
| ) | |||||
| value = value.broadcast(tmp_result.shape) | |||||
| if use_subtensor: | |||||
| op = builtin.SetSubtensor(items=items) | |||||
| else: | |||||
| op = builtin.IndexingSetMultiAxisVec(items=items) | |||||
| (result,) = apply(op, tensor, value, *tensors) | |||||
| result = result.reshape(org_shape) | |||||
| return result | |||||
| @@ -0,0 +1,196 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import threading | |||||
| import weakref | |||||
| from concurrent.futures import Future, ThreadPoolExecutor | |||||
| from .. import _imperative_rt | |||||
| from .._wrap import device as as_device | |||||
| from ..ops.builtin import OpDef | |||||
| from .core import OpBase, TensorBase, apply | |||||
| class CompiledFunction: | |||||
| def __init__(self, graph, function): | |||||
| self._graph = graph | |||||
| self._function = function | |||||
| self._future = None | |||||
| def execute(self, *args): | |||||
| assert self._future is None | |||||
| self._future = self._graph._executor.submit(self._function.execute, *args) | |||||
| def wait(self): | |||||
| assert self._future is not None | |||||
| self._future.exception() | |||||
| self._function.wait() | |||||
| try: | |||||
| return self._future.result() | |||||
| finally: | |||||
| self._future = None | |||||
| def __call__(self, *args): | |||||
| self.execute(*args) | |||||
| return self.wait() | |||||
| class Graph(_imperative_rt.ComputingGraph): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self._var_cache = weakref.WeakKeyDictionary() | |||||
| self._op_cache = weakref.WeakKeyDictionary() | |||||
| self._executor = ThreadPoolExecutor(1) | |||||
| def _wrap(self, obj): | |||||
| if type(obj) is _imperative_rt.VarNode: | |||||
| wrapper, cache = VarNode, self._var_cache | |||||
| elif type(obj) is _imperative_rt.OperatorNode: | |||||
| wrapper, cache = OpNode, self._op_cache | |||||
| if obj not in cache: | |||||
| cache[obj] = wrapper(obj) | |||||
| return cache[obj] | |||||
| def compile(self, *args): | |||||
| return CompiledFunction(self, super().compile(_unwrap(args))) | |||||
| class VarNode(TensorBase): | |||||
| def __init__(self, node: _imperative_rt.VarNode): | |||||
| self._node = node | |||||
| @property | |||||
| def graph(self) -> Graph: | |||||
| return self._node.graph | |||||
| @property | |||||
| def op(self): | |||||
| return self.graph._wrap(self._node.owner) | |||||
| @property | |||||
| def dtype(self): | |||||
| return self._node.dtype | |||||
| @property | |||||
| def device(self): | |||||
| return as_device(self._node.comp_node) | |||||
| class OpNode: | |||||
| def __init__(self, node: _imperative_rt.OperatorNode): | |||||
| self._node = node | |||||
| @property | |||||
| def graph(self) -> Graph: | |||||
| return self._node.graph | |||||
| @property | |||||
| def inputs(self): | |||||
| return tuple(map(self.graph._wrap, self._node.inputs)) | |||||
| @property | |||||
| def outputs(self): | |||||
| return tuple(map(self.graph._wrap, self._node.outputs)) | |||||
| def _wrap(x): | |||||
| if isinstance(x, collections.Sequence): | |||||
| return type(x)(map(_wrap, x)) | |||||
| return x.graph._wrap(x) | |||||
| def _unwrap(x): | |||||
| if isinstance(x, collections.Sequence): | |||||
| return type(x)(map(_unwrap, x)) | |||||
| return x._node | |||||
| @apply.add | |||||
| def _(op: OpDef, *args: VarNode): | |||||
| outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | |||||
| return _wrap(outputs) | |||||
| def input_callback(callback, *args, device=None, dtype=None, graph=None): | |||||
| outputs = _imperative_rt.input_callback( | |||||
| callback, as_device(device).to_c(), dtype, _unwrap(args), graph=graph | |||||
| ) | |||||
| value, dummy = _wrap(outputs) | |||||
| return value, dummy | |||||
| class InputNode(OpNode): | |||||
| def __init__(self, *args: VarNode, device=None, dtype=None, graph=None): | |||||
| r = _imperative_rt.DeviceTensorNDRendezvous() | |||||
| if device is not None: | |||||
| device = as_device(device).to_c() | |||||
| outputs = _imperative_rt.input_callback( | |||||
| r, device, dtype, _unwrap(args), graph=graph | |||||
| ) | |||||
| super().__init__(outputs[0].owner) | |||||
| self._rendezvous = r | |||||
| def set_value(self, value): | |||||
| assert isinstance(value, _imperative_rt.DeviceTensorND) | |||||
| self._rendezvous.set(value) | |||||
| def reset(self): | |||||
| self._rendezvous.reset() | |||||
| @property | |||||
| def device(self): | |||||
| return self.outputs[0].device | |||||
| @property | |||||
| def dtype(self): | |||||
| return self.outputs[0].dtype | |||||
| def output_callback(callback, var, *args): | |||||
| args = (var,) + args | |||||
| dummy = _imperative_rt.output_callback(callback, _unwrap(args)) | |||||
| return _wrap(dummy) | |||||
| class OutputNode(OpNode): | |||||
| def __init__(self, var, *args): | |||||
| args = (var,) + args | |||||
| r = _imperative_rt.DeviceTensorNDRendezvous() | |||||
| dummy = _imperative_rt.output_callback(r, _unwrap(args)) | |||||
| super().__init__(dummy.owner) | |||||
| self._rendezvous = r | |||||
| def get_value(self): | |||||
| return self._rendezvous.get() | |||||
| def reset(self): | |||||
| self._rendezvous.reset() | |||||
| class TensorAttr: | |||||
| def __init__(self, shape, dtype, device): | |||||
| self.shape = shape | |||||
| self.dtype = dtype | |||||
| self.device = device | |||||
| class AttrOutputNode(OpNode): | |||||
| def __init__(self, var, *args): | |||||
| args = (var,) + args | |||||
| r = _imperative_rt.TensorAttrRendezvous() | |||||
| dummy = _imperative_rt.attr_output_callback(r, _unwrap(args)) | |||||
| super().__init__(dummy.owner) | |||||
| self._rendezvous = r | |||||
| def get_value(self): | |||||
| attr = self._rendezvous.get() | |||||
| return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node)) | |||||
| def reset(self): | |||||
| self._rendezvous.reset() | |||||
| @@ -0,0 +1,108 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| import numpy as np | |||||
| from ..._imperative_rt import CompNode, DeviceTensorND | |||||
| from ..._imperative_rt.imperative import ( | |||||
| _get_dev_tensor, | |||||
| apply_op, | |||||
| delete, | |||||
| get_device, | |||||
| get_dtype, | |||||
| get_shape, | |||||
| get_value, | |||||
| put, | |||||
| ) | |||||
| from ..._wrap import device as as_device | |||||
| from ...ops.builtin import Copy, OpDef, TypeCvt | |||||
| from ...ops.special import Const | |||||
| from ..core import OpBase, TensorBase, apply | |||||
| class RawTensor(TensorBase): | |||||
| _init_cb = None | |||||
| _del_cb = None | |||||
| def __init__(self, handle): | |||||
| self._handle = handle | |||||
| if self._init_cb: | |||||
| self._init_cb() | |||||
| @property | |||||
| def dtype(self): | |||||
| return get_dtype(self._handle) | |||||
| @property | |||||
| def device(self): | |||||
| return as_device(get_device(self._handle)) | |||||
| @property | |||||
| def shape(self): | |||||
| return get_shape(self._handle) | |||||
| def numpy(self): | |||||
| return get_value(self._handle) | |||||
| def _dev_tensor(self): | |||||
| return _get_dev_tensor(self._handle) | |||||
| def __repr__(self): | |||||
| return "{}({}, device='{}')".format( | |||||
| type(self).__qualname__, repr(self.numpy()), self.device | |||||
| ) | |||||
| def __del__(self): | |||||
| if self._del_cb: | |||||
| self._del_cb() | |||||
| delete(self._handle) | |||||
| @apply.add | |||||
| def _(op: OpDef, *args: RawTensor): | |||||
| outputs = apply_op(op, tuple(i._handle for i in args)) | |||||
| return tuple(map(RawTensor, outputs)) | |||||
| @apply.add | |||||
| def _(op: Const, *args: RawTensor): | |||||
| dtype = op.dtype | |||||
| device = as_device(op.device).to_c() | |||||
| return (as_raw_tensor(op.value, dtype=dtype, device=device),) | |||||
| @functools.singledispatch | |||||
| def as_raw_tensor(obj, dtype=None, device=None): | |||||
| obj = np.asarray(obj, dtype=dtype) | |||||
| if obj.dtype == np.float64: | |||||
| obj = obj.astype(np.float32) | |||||
| if obj.dtype == np.int64: | |||||
| obj = obj.astype(np.int32) | |||||
| return as_raw_tensor(obj, device=device) | |||||
| @as_raw_tensor.register(np.ndarray) | |||||
| def _(array: np.ndarray, dtype=None, device=None): | |||||
| device = None if device is None else as_device(device).to_c() | |||||
| return RawTensor(put(array, dtype=dtype, device=device)) | |||||
| @as_raw_tensor.register(RawTensor) | |||||
| def _(tensor: RawTensor, dtype=None, device=None): | |||||
| if dtype is not None: | |||||
| dtype = np.dtype(dtype) | |||||
| if dtype != tensor.dtype: | |||||
| (tensor,) = apply(TypeCvt(dtype=dtype), tensor) | |||||
| if device is not None: | |||||
| device = as_device(device) | |||||
| if device != tensor.device: | |||||
| (tensor,) = apply(Copy(comp_node=device.to_c()), tensor) | |||||
| return tensor | |||||
| @@ -0,0 +1,251 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| import io | |||||
| import weakref | |||||
| class partial(functools.partial): | |||||
| def __get__(self, instance, owner=None): | |||||
| if instance is None: | |||||
| return self | |||||
| return functools.partial(self, instance) | |||||
| def hook(f): | |||||
| def decorator(impl): | |||||
| return functools.update_wrapper(partial(f, impl), impl) | |||||
| return decorator | |||||
| def on_input(impl, value): | |||||
| tensor = impl(value) | |||||
| trace = get_trace() | |||||
| if trace: | |||||
| var = trace.get_var(tensor) | |||||
| event = InputEvent(var) | |||||
| trace.append(event) | |||||
| return tensor | |||||
| def on_read_dtype(impl, self): | |||||
| trace = get_trace() | |||||
| if trace: | |||||
| var = trace.get_var(self) | |||||
| event = ReadDtypeEvent(var) | |||||
| trace.append(event) | |||||
| return impl(self) | |||||
| def on_read_device(impl, self): | |||||
| trace = get_trace() | |||||
| if trace: | |||||
| var = trace.get_var(self) | |||||
| event = ReadDeviceEvent(var) | |||||
| trace.append(event) | |||||
| return impl(self) | |||||
| def on_read_shape(impl, self): | |||||
| trace = get_trace() | |||||
| if trace: | |||||
| var = trace.get_var(self) | |||||
| event = ReadShapeEvent(var) | |||||
| trace.append(event) | |||||
| return impl(self) | |||||
| def on_read_value(impl, self): | |||||
| trace = get_trace() | |||||
| if trace: | |||||
| var = trace.get_var(self) | |||||
| event = ReadValueEvent(var) | |||||
| trace.append(event) | |||||
| return impl(self) | |||||
| def on_builtin_op(impl, op, *args): | |||||
| outputs = impl(op, *args) | |||||
| trace = get_trace() | |||||
| if trace: | |||||
| input_vars = tuple(map(trace.get_var, args)) | |||||
| output_vars = outputs and tuple(map(trace.get_var, outputs)) | |||||
| event = OpEvent(op, input_vars, output_vars) | |||||
| trace.append(event) | |||||
| return outputs | |||||
| def on_del(impl, self): | |||||
| trace = get_trace() | |||||
| if trace: | |||||
| var = trace.get_var(self) | |||||
| event = DelEvent(var) | |||||
| trace.append(event) | |||||
| return impl(self) | |||||
| class Trace(list): | |||||
| def __init__(self): | |||||
| self._var_id = 1 | |||||
| self._t2v = weakref.WeakKeyDictionary() | |||||
| self._v2t = weakref.WeakValueDictionary() | |||||
| def get_var(self, x): | |||||
| v = self._t2v.get(x) | |||||
| if v: | |||||
| return v | |||||
| v = self._var_id | |||||
| self._var_id += 1 | |||||
| self._t2v[x] = v | |||||
| self._v2t[v] = x | |||||
| return v | |||||
| def __bool__(self): | |||||
| return True | |||||
| def __enter__(self): | |||||
| global _current_trace | |||||
| if hasattr(self, "_prev_trace"): | |||||
| raise RuntimeError | |||||
| self._prev_trace = _current_trace | |||||
| _current_trace = self | |||||
| return self | |||||
| def __exit__(self, *_): | |||||
| global _current_trace | |||||
| if _current_trace is not self: | |||||
| raise RuntimeError | |||||
| _current_trace = self._prev_trace | |||||
| del self._prev_trace | |||||
| class Event: | |||||
| pass | |||||
| class InputEvent(Event): | |||||
| def __init__(self, var): | |||||
| self.var = var | |||||
| class ReadEvent(Event): | |||||
| def __init__(self, var): | |||||
| self.var = var | |||||
| class ReadDtypeEvent(ReadEvent): | |||||
| pass | |||||
| class ReadDeviceEvent(ReadEvent): | |||||
| pass | |||||
| class ReadShapeEvent(ReadEvent): | |||||
| pass | |||||
| class ReadValueEvent(ReadEvent): | |||||
| pass | |||||
| class OpEvent(Event): | |||||
| def __init__(self, op, inputs, outputs): | |||||
| self.op = op | |||||
| self.inputs = inputs | |||||
| self.outputs = outputs | |||||
| class DelEvent(Event): | |||||
| def __init__(self, var): | |||||
| self.var = var | |||||
| _current_trace = None | |||||
| def get_trace() -> Trace: | |||||
| global _current_trace | |||||
| return _current_trace | |||||
| def format_trace(trace): | |||||
| buf = io.StringIO() | |||||
| active_vars = set() | |||||
| def write(fmt, *args, **kwargs): | |||||
| print(fmt.format(*args, **kwargs), file=buf) | |||||
| def init_vars(*args): | |||||
| for i in args: | |||||
| if i in active_vars: | |||||
| continue | |||||
| active_vars.add(i) | |||||
| write("_{} = input()", i) | |||||
| for event in trace: | |||||
| if isinstance(event, InputEvent): | |||||
| init_vars(event.var) | |||||
| elif isinstance(event, ReadDtypeEvent): | |||||
| init_vars(event.var) | |||||
| write("output(_{}.dtype)", event.var) | |||||
| elif isinstance(event, ReadDeviceEvent): | |||||
| init_vars(event.var) | |||||
| write("output(_{}.device)", event.var) | |||||
| elif isinstance(event, ReadShapeEvent): | |||||
| init_vars(event.var) | |||||
| write("output(_{}.shape)", event.var) | |||||
| elif isinstance(event, ReadValueEvent): | |||||
| init_vars(event.var) | |||||
| write("output(_{}.dtype)", event.var) | |||||
| elif isinstance(event, ReadValueEvent): | |||||
| init_vars(event.var) | |||||
| write("output(_{}.value)", event.var) | |||||
| elif isinstance(event, OpEvent): | |||||
| init_vars(*event.inputs) | |||||
| active_vars.update(event.outputs) | |||||
| ovars = ", ".join(map("_{}".format, event.outputs)) | |||||
| ivars = ", ".join(map("_{}".format, event.inputs)) | |||||
| if ovars: | |||||
| write("{} = {}({})", ovars, repr(event.op), ivars) | |||||
| else: | |||||
| write("{}({})", repr(event.op), ivars) | |||||
| elif isinstance(event, DelEvent): | |||||
| init_vars(event.var) | |||||
| write("del _{}", event.var) | |||||
| else: | |||||
| raise TypeError(type(event)) | |||||
| return buf.getvalue() | |||||
| def compile_trace(trace): | |||||
| trace = list(trace) | |||||
| def static_function(f): | |||||
| trace = None | |||||
| @functools.wraps(f) | |||||
| def wrapper(*args, **kwargs): | |||||
| nonlocal trace | |||||
| if trace is None: | |||||
| with Trace() as trace: | |||||
| return f(*args, **kwargs) | |||||
| return f(*args, **kwargs) | |||||
| return wrapper | |||||
| @@ -0,0 +1,263 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| import weakref | |||||
| # Concepts | |||||
| # | |||||
| # * Internal tensor | |||||
| # Tensor produced by the static sequence | |||||
| # | |||||
| # * External tensor | |||||
| # Tensor not produced, but used as input, by the static sequence | |||||
| # | |||||
| # * Irrelevant tensor | |||||
| # Tensor not present in input/output of any op | |||||
| # | |||||
| # * Escape | |||||
| # An internal tensor is said to escape if it is still alive | |||||
| # at the end of the sequence | |||||
| # JIT-ed execution | |||||
| # | |||||
| # 1. read attr (dtype, device, shape) | |||||
| # a. internal tensor | |||||
| # read out as soon as tensor is produced | |||||
| # b. external or irrelevant tensor | |||||
| # fallback | |||||
| # | |||||
| # 2. apply op | |||||
| # bind external tensors in input | |||||
| # | |||||
| # 3. del | |||||
| class Action: | |||||
| pass | |||||
| class ReadAttrAction(Action): | |||||
| def __init__(self, var, name, getter): | |||||
| self.var = var | |||||
| self.name = name | |||||
| self.getter = getter | |||||
| class ReadValueAction(Action): | |||||
| def __init__(self, var, getter): | |||||
| self.var = var | |||||
| self.getter = getter | |||||
| class GetTensorAction(Action): | |||||
| def __init__(self, var, getter): | |||||
| self.var = var | |||||
| self.getter = getter | |||||
| class OpAction(Action): | |||||
| def __init__(self, op, inputs, outputs, input_receivers): | |||||
| self.op = op | |||||
| self.inputs = inputs | |||||
| self.outputs = outputs | |||||
| self.input_receivers = input_receivers | |||||
| class TensorAttr: | |||||
| def __init__(self): | |||||
| self.shape = None | |||||
| self.dtype = None | |||||
| self.device = None | |||||
| class Bailout(Exception): | |||||
| pass | |||||
| class Fallback(Exception): | |||||
| pass | |||||
| def handle_bailout_fallback_finalize(f): | |||||
| @functools.wraps(f) | |||||
| def wrapper(self, impl, *args, **kwargs): | |||||
| try: | |||||
| return f(*args, **kwargs) | |||||
| except Bailout: | |||||
| self.bailout() | |||||
| except Fallback: | |||||
| pass | |||||
| finally: | |||||
| if self.pc == len(self): | |||||
| self.finalize() | |||||
| return impl(*args, **kwargs) | |||||
| return wrapper | |||||
| class ExecTrajectory(list): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.reset() | |||||
| def __bool__(self): | |||||
| return True | |||||
| def __enter__(self): | |||||
| global _current_trajectory | |||||
| if hasattr(self, "_prev_trajectory"): | |||||
| raise RuntimeError | |||||
| self._prev_trajectory = _current_trajectory | |||||
| _current_trajectory = self | |||||
| self._exited = False | |||||
| return self | |||||
| def __exit__(self, *exc_info): | |||||
| # cleanup should be done at completion, | |||||
| # which is before exiting context manager | |||||
| assert self._exited == (exc_info == (None, None, None)) | |||||
| if not self._exited: | |||||
| assert self.pc < len(self) | |||||
| self.bailout() | |||||
| def _exit(self): | |||||
| # clean up self and global varaible | |||||
| assert not self._exited | |||||
| self.reset() | |||||
| global _current_trajectory | |||||
| if _current_trajectory is not self: | |||||
| raise RuntimeError | |||||
| _current_trajectory = self._prev_trajectory | |||||
| del self._prev_trajectory | |||||
| def reset(self): | |||||
| self._exited = True | |||||
| self.pc = 0 | |||||
| self.attr_cache = weakref.WeakKeyDictionary() | |||||
| ### Internal and External Tensor ### | |||||
| # internal tensors are those produced by us | |||||
| # external tensors are those received from outside | |||||
| # during JIT-ed execution, internal tensors are just placeholders. | |||||
| # var_to_tensor is the binding table for all tensors | |||||
| self.var_to_tensor = {} # var -> weakref[tensor] | |||||
| # tensor_to_var is the reverse binding table for internal tensors | |||||
| # note that external tensors could map to >1 vars. | |||||
| self.tensor_to_var = weakref.WeakKeyDictionary() | |||||
| # internal tensor will be materialized if its .data is accessed from outside | |||||
| # after being meterialized, an intern tensor is much like an external tensor | |||||
| def finalize(self): | |||||
| assert self.pc == len(self) | |||||
| self._exit() | |||||
| def bailout(self): | |||||
| self._exit() | |||||
| raise NotImplementedError | |||||
| def next_action(self): | |||||
| assert not self._exited | |||||
| assert self.pc < len(self) | |||||
| return self[self.pc] | |||||
| @handle_bailout_fallback_finalize | |||||
| def read_attr(self, tensor, name): | |||||
| attrs = self.attr_cache.setdefault(tensor, TensorAttr()) | |||||
| value = getattr(attrs, name, None) | |||||
| if value is None: | |||||
| action = self.next_action() | |||||
| if not isinstance(action, ReadAttrAction): | |||||
| raise Bailout | |||||
| if name != action.name: | |||||
| raise Bailout | |||||
| value = action.getter() | |||||
| setattr(attrs, name, value) | |||||
| return value | |||||
| @handle_bailout_fallback_finalize | |||||
| def read_value(self, impl, tensor): | |||||
| # possibilities: | |||||
| # 1. internal tensor | |||||
| # 2. external tensor | |||||
| # 3. irrelevant tensor (not an input / output of any op) | |||||
| if tensor not in self.tensor_to_var: | |||||
| raise Fallback | |||||
| assert tensor._data is None | |||||
| action = self.next_action() | |||||
| if not isinstance(action, ReadValueAction): | |||||
| raise Bailout | |||||
| return action.getter() | |||||
| @handle_bailout_fallback_finalize | |||||
| def apply_op(self, impl, op, *args): | |||||
| from . import RawTensor | |||||
| action = self.next_action() | |||||
| if not isinstance(action, OpAction): | |||||
| raise Bailout | |||||
| if len(args) != len(action.inputs): | |||||
| raise Bailout | |||||
| assert len(actions.inputs) == len(action.input_receivers) | |||||
| for v, t, r in zip(action.inputs, args, action.input_receivers): | |||||
| if v in self.var_to_tensor: | |||||
| assert r is None | |||||
| if t is not self.var_to_tensor[v](): | |||||
| raise Bailout | |||||
| else: | |||||
| # NOTE: not checking for aliasing (>=2 vars map to 1 tensor) | |||||
| # the static execution backend must handle this | |||||
| self.var_to_tensor[v] = weakref.ref(t) | |||||
| r(t) | |||||
| outputs = [] | |||||
| for v in action.outputs: | |||||
| assert v not in self.var_to_tensor | |||||
| t = RawTensor() | |||||
| t._data_getter = functools.partial(self.get_data, v) | |||||
| outputs.append(t) | |||||
| self.var_to_tensor[v] = weakref.ref(t) | |||||
| return tuple(outputs) | |||||
| def get_data(self, var): | |||||
| tensor = self.var_to_tensor[var]() | |||||
| assert tensor is not None | |||||
| assert tensor._data is None | |||||
| assert tensor in self.tensor_to_var | |||||
| action = self.next_action() | |||||
| if not isinstance(action, GetTensorAction): | |||||
| self.bailout() | |||||
| elif action.var != var: | |||||
| self.bailout() | |||||
| else: | |||||
| tensor._data = action.getter() | |||||
| del tensor._data_getter | |||||
| del self.tensor_to_var[tensor] | |||||
| assert "_data_getter" not in tensor.__dict__ | |||||
| return tensor._data_getter() | |||||
| _current_trajectory = None | |||||
| def get_trajectory(): | |||||
| return _current_trajectory | |||||
| def compile_trace(trace): | |||||
| from .jit import ReadDTypeEvent, ReadDeviceEvent, ReadShapeEvent, OpEvent, DelEvent | |||||
| traj = ExecutionTrajectory() | |||||
| active_vars = set() | |||||
| for event in trace: | |||||
| if isinstance(event, ReadDTypeEvent): | |||||
| traj.append(ReadAttrAction()) | |||||
| @@ -0,0 +1,106 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import contextlib | |||||
| import copy | |||||
| from .core import Dispatcher, OpBase, TensorBase, apply | |||||
| class Tensor(TensorBase): | |||||
| def __init__(self, data: TensorBase): | |||||
| self._data = data | |||||
| # _extra_data is set up in Grad.wrt | |||||
| self._extra_data = {} | |||||
| self._user_data = {} | |||||
| def __getattr__(self, name): | |||||
| if name in self._user_data: | |||||
| return self._user_data[name] | |||||
| raise AttributeError(name) | |||||
| def reset(self, other): | |||||
| assert isinstance(other, __class__) | |||||
| self.__dict__.clear() | |||||
| self._data = other.data | |||||
| self._extra_data = other._extra_data.copy() | |||||
| self._user_data = other._user_data.copy() | |||||
| def copy(self): | |||||
| other = object.__new__(type(self)) | |||||
| other.reset(self) | |||||
| return other | |||||
| # tensor interface | |||||
| @property | |||||
| def shape(self): | |||||
| return self._data.shape | |||||
| @property | |||||
| def dtype(self): | |||||
| return self._data.dtype | |||||
| @property | |||||
| def device(self): | |||||
| return self._data.device | |||||
| def numpy(self): | |||||
| return self._data.numpy() | |||||
| class ApplyContext: | |||||
| def __init__(self): | |||||
| self.inputs = None | |||||
| self.outputs = None | |||||
| self.key = None | |||||
| _context = None | |||||
| @contextlib.contextmanager | |||||
| def push_context(): | |||||
| global _context | |||||
| backup = _context | |||||
| try: | |||||
| _context = ApplyContext() | |||||
| yield _context | |||||
| finally: | |||||
| _context = backup | |||||
| def get_context(): | |||||
| return _context | |||||
| @apply.add | |||||
| def tensor_apply(op: OpBase, *args: Tensor): | |||||
| data = tuple(i._data if isinstance(i, Tensor) else i for i in args) | |||||
| # type(Tensor._data) is RawTensor | |||||
| # dispached to apply.add@RawTensor.py if passed Tensor args | |||||
| outputs = apply(op, *data) | |||||
| ret = tuple(map(Tensor, outputs)) | |||||
| with push_context() as ctx: | |||||
| ctx.inputs = args | |||||
| ctx.outputs = ret | |||||
| for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))): | |||||
| ctx.key = k | |||||
| data = tuple( | |||||
| i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args | |||||
| ) | |||||
| # data are instances of Tracer | |||||
| # dispatched to apply.add@grad.py | |||||
| outputs = apply(op, *data) | |||||
| if outputs is not None: | |||||
| assert len(outputs) == len(ret) | |||||
| for t, i in zip(ret, outputs): | |||||
| t._extra_data[k] = i | |||||
| return ret | |||||
| @@ -0,0 +1,367 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import abc | |||||
| import collections | |||||
| import numpy as np | |||||
| from ..ops import builtin | |||||
| from ..ops.special import Const | |||||
| from . import utils | |||||
| from .core import OpBase, TensorBase, TensorWrapperBase, apply | |||||
| from .indexing import getitem as _getitem | |||||
| from .indexing import setitem as _setitem | |||||
| from .raw_tensor import RawTensor, as_raw_tensor | |||||
| from .tensor import Tensor | |||||
| def _elwise(*args, mode): | |||||
| op = builtin.Elemwise(mode=mode) | |||||
| args = utils.convert_inputs(*args) | |||||
| (result,) = apply(op, *args) | |||||
| return result | |||||
| def _matmul(inp1, inp2): | |||||
| op = builtin.MatrixMul( | |||||
| transposeA=False, transposeB=False, compute_mode="DEFAULT", format="DEFAULT" | |||||
| ) | |||||
| inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||||
| (result,) = apply(op, inp1, inp2) | |||||
| return result | |||||
| def _transpose(data, axes): | |||||
| op = builtin.Dimshuffle(axes) | |||||
| (data,) = utils.convert_inputs(data) | |||||
| (result,) = apply(op, data) | |||||
| return result | |||||
| def _broadcast(inp, shape): | |||||
| shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | |||||
| (result,) = apply(builtin.Broadcast(), inp, shape) | |||||
| return result | |||||
| def _reshape(x, shape): | |||||
| if isinstance(shape, (TensorBase, TensorWrapperBase)): | |||||
| shape = shape.numpy() | |||||
| shape = tuple(map(int, shape)) | |||||
| unspec_axis = None | |||||
| for i, s in enumerate(shape): | |||||
| if s < 0: | |||||
| if s != -1: | |||||
| raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||||
| if unspec_axis is not None: | |||||
| raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | |||||
| unspec_axis = i | |||||
| # TODO: device should be None (cpu) | |||||
| (shape,) = Const(shape, dtype=np.int32, device=x.device)(x) | |||||
| if unspec_axis is None: | |||||
| op = builtin.Reshape() | |||||
| else: | |||||
| op = builtin.Reshape(unspec_axis=unspec_axis) | |||||
| (x,) = apply(op, x, shape) | |||||
| return x | |||||
| def _unary_elwise(mode): | |||||
| def f(self): | |||||
| return _elwise(self, mode=mode) | |||||
| return f | |||||
| def _binary_elwise(mode, rev=False): | |||||
| if not rev: | |||||
| def f(self, value): | |||||
| return _elwise(self, value, mode=mode) | |||||
| else: | |||||
| def f(self, value): | |||||
| return _elwise(value, self, mode=mode) | |||||
| return f | |||||
| def _logical_unary_elwise(mode, rev=False): | |||||
| def f(self): | |||||
| if self.dtype != np.bool_: | |||||
| raise TypeError("{} requires a bool tensor".format(mode)) | |||||
| return _elwise(self, mode=mode) | |||||
| return f | |||||
| def _logical_binary_elwise(mode, rev=False): | |||||
| if not rev: | |||||
| def f(self, value): | |||||
| if self.dtype != np.bool_ or value.dtype != np.bool_: | |||||
| raise TypeError("{} requires 2 bool tensors".format(mode)) | |||||
| return _elwise(self, value, mode=mode) | |||||
| else: | |||||
| def f(self, value): | |||||
| if self.dtype != np.bool_ or value.dtype != np.bool_: | |||||
| raise TypeError("{} requires 2 bool tensors".format(mode)) | |||||
| return _elwise(value, self, mode=mode) | |||||
| return f | |||||
| def _reduce(mode): | |||||
| def f(self, axis=None): | |||||
| inp = self | |||||
| if axis is None: | |||||
| inp = self.flatten() | |||||
| axis = 0 | |||||
| op = builtin.Reduce(mode=mode, axis=axis) | |||||
| (result,) = utils.convert_inputs(inp) | |||||
| (result,) = apply(op, result) | |||||
| return result | |||||
| return f | |||||
| def _inplace(f): | |||||
| def g(self, value): | |||||
| result = f(self, value) | |||||
| if result is NotImplemented: | |||||
| raise NotImplementedError | |||||
| self._reset(result) | |||||
| return self | |||||
| return g | |||||
| def _todo(*_): | |||||
| raise NotImplementedError | |||||
| class ArrayMethodMixin(abc.ABC): | |||||
| __array_priority__ = 233333 | |||||
| @abc.abstractmethod | |||||
| def _reset(self, other): | |||||
| pass | |||||
| @abc.abstractproperty | |||||
| def dtype(self) -> np.dtype: | |||||
| pass | |||||
| @abc.abstractproperty | |||||
| def shape(self) -> tuple: | |||||
| pass | |||||
| @abc.abstractmethod | |||||
| def numpy(self) -> np.ndarray: | |||||
| pass | |||||
| __hash__ = None # due to __eq__ diviates from python convention | |||||
| __lt__ = lambda self, value: _elwise(self, value, mode="LT").astype("bool") | |||||
| __le__ = lambda self, value: _elwise(self, value, mode="LEQ").astype("bool") | |||||
| __gt__ = lambda self, value: _elwise(value, self, mode="LT").astype("bool") | |||||
| __ge__ = lambda self, value: _elwise(value, self, mode="LEQ").astype("bool") | |||||
| __eq__ = lambda self, value: _elwise(self, value, mode="EQ").astype("bool") | |||||
| __ne__ = lambda self, value: _elwise( | |||||
| _elwise(self, value, mode="EQ").astype("bool"), mode="NOT" | |||||
| ) | |||||
| __neg__ = _unary_elwise("NEGATE") | |||||
| __pos__ = lambda self: self | |||||
| __abs__ = _unary_elwise("ABS") | |||||
| __invert__ = _logical_unary_elwise("NOT") | |||||
| __round__ = _unary_elwise("ROUND") | |||||
| __trunc__ = _todo | |||||
| __floor__ = _unary_elwise("FLOOR") | |||||
| __ceil__ = _unary_elwise("CEIL") | |||||
| __add__ = _binary_elwise("ADD") | |||||
| __sub__ = _binary_elwise("SUB") | |||||
| __mul__ = _binary_elwise("MUL") | |||||
| __matmul__ = lambda self, other: _matmul(self, other) | |||||
| __truediv__ = _binary_elwise("TRUE_DIV") | |||||
| __floordiv__ = _binary_elwise("FLOOR_DIV") | |||||
| __mod__ = _binary_elwise("MOD") | |||||
| # __divmode__ | |||||
| __pow__ = _binary_elwise("POW") | |||||
| __lshift__ = _binary_elwise("SHL") | |||||
| __rshift__ = _binary_elwise("SHR") | |||||
| __and__ = _logical_binary_elwise("AND") | |||||
| __or__ = _logical_binary_elwise("OR") | |||||
| __xor__ = _logical_binary_elwise("XOR") | |||||
| __radd__ = _binary_elwise("ADD", rev=1) | |||||
| __rsub__ = _binary_elwise("SUB", rev=1) | |||||
| __rmul__ = _binary_elwise("MUL", rev=1) | |||||
| __rmatmul__ = lambda self, other: _matmul(other, self) | |||||
| __rtruediv__ = _binary_elwise("TRUE_DIV", rev=1) | |||||
| __rfloordiv__ = _binary_elwise("FLOOR_DIV", rev=1) | |||||
| __rmod__ = _binary_elwise("MOD", rev=1) | |||||
| # __rdivmode__ | |||||
| __rpow__ = _binary_elwise("POW", rev=1) | |||||
| __rlshift__ = _binary_elwise("SHL", rev=1) | |||||
| __rrshift__ = _binary_elwise("SHR", rev=1) | |||||
| __rand__ = _logical_binary_elwise("AND", rev=1) | |||||
| __ror__ = _logical_binary_elwise("OR", rev=1) | |||||
| __rxor__ = _logical_binary_elwise("XOR", rev=1) | |||||
| __iadd__ = _inplace(__add__) | |||||
| __isub__ = _inplace(__sub__) | |||||
| __imul__ = _inplace(__mul__) | |||||
| __imatmul__ = _inplace(__matmul__) | |||||
| __itruediv__ = _inplace(__truediv__) | |||||
| __ifloordiv__ = _inplace(__floordiv__) | |||||
| __imod__ = _inplace(__mod__) | |||||
| __ipow__ = _inplace(__pow__) | |||||
| __ilshift__ = _inplace(__lshift__) | |||||
| __irshift__ = _inplace(__rshift__) | |||||
| __iand__ = _inplace(__and__) | |||||
| __ior__ = _inplace(__or__) | |||||
| __ixor__ = _inplace(__xor__) | |||||
| __index__ = lambda self: self.item().__index__() | |||||
| __bool__ = lambda self: bool(self.item()) | |||||
| __int__ = lambda self: int(self.item()) | |||||
| __float__ = lambda self: float(self.item()) | |||||
| __complex__ = lambda self: complex(self.item()) | |||||
| def __len__(self): | |||||
| shape = self.shape | |||||
| if shape: | |||||
| return int(shape[0]) | |||||
| raise TypeError("ndim is 0") | |||||
| def __iter__(self): | |||||
| for i in range(len(self)): | |||||
| yield self[i] | |||||
| def __getitem__(self, index): | |||||
| return _getitem(self, index) | |||||
| def __setitem__(self, index, value): | |||||
| if index is not Ellipsis: | |||||
| value = _setitem(self, index, value) | |||||
| self._reset(value) | |||||
| __contains__ = _todo | |||||
| @property | |||||
| def ndim(self): | |||||
| return len(self.shape) | |||||
| @property | |||||
| def size(self): | |||||
| return np.prod(self.shape).item() | |||||
| @property | |||||
| def T(self): | |||||
| return self.transpose() | |||||
| def item(self, *args): | |||||
| if not args: | |||||
| assert self.size == 1 | |||||
| return self.numpy().item() | |||||
| return self[args].item() | |||||
| def tolist(self): | |||||
| return self.numpy().tolist() | |||||
| def astype(self, dtype): | |||||
| return utils.astype(self, dtype) | |||||
| def reshape(self, *args): | |||||
| if len(args) == 1: | |||||
| if isinstance(args[0], collections.Sequence): | |||||
| args = args[0] | |||||
| return _reshape(self, args) | |||||
| def broadcast(self, *args): | |||||
| if len(args) == 1: | |||||
| if isinstance(args[0], collections.Sequence): | |||||
| args = args[0] | |||||
| return _broadcast(self, args) | |||||
| def transpose(self, *args): | |||||
| if not args: | |||||
| args = reversed(range(self.ndim)) | |||||
| elif len(args) == 1: | |||||
| if isinstance(args[0], collections.Sequence): | |||||
| args = args[0] | |||||
| return _transpose(self, args) | |||||
| def flatten(self): | |||||
| return self.reshape(-1) | |||||
| sum = _reduce("SUM") | |||||
| prod = _reduce("PRODUCT") | |||||
| min = _reduce("MIN") | |||||
| max = _reduce("MAX") | |||||
| mean = _reduce("MEAN") | |||||
| class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): | |||||
| def __init__(self, data): | |||||
| self.__wrapped__ = data | |||||
| def _reset(self, other): | |||||
| if not isinstance(other, __class__): | |||||
| raise TypeError(type(other)) | |||||
| self.__wrapped__ = other.__wrapped__ | |||||
| return self | |||||
| @property | |||||
| def dtype(self): | |||||
| return self.__wrapped__.dtype | |||||
| @property | |||||
| def shape(self): | |||||
| return self.__wrapped__.shape | |||||
| @property | |||||
| def device(self): | |||||
| return self.__wrapped__.device | |||||
| def numpy(self): | |||||
| return self.__wrapped__.numpy() | |||||
| class TensorWrapper(GenericTensorWrapper): | |||||
| def __init__(self, data, dtype=None, device=None): | |||||
| if isinstance(data, TensorWrapperBase): | |||||
| data = data.__wrapped__ | |||||
| elif not isinstance(data, TensorBase): | |||||
| assert data is not None, "Cannot init a tensor with data as None" | |||||
| data = Tensor(as_raw_tensor(data, dtype=dtype, device=device)) | |||||
| super().__init__(data) | |||||
| def _reset(self, other): | |||||
| if isinstance(other, TensorWrapperBase): | |||||
| self.__wrapped__ = other.__wrapped__ | |||||
| elif isinstance(other, TensorBase): | |||||
| self.__wrapped__ = other | |||||
| else: | |||||
| self._reset(type(self)(other, dtype=self.dtype, device=self.device)) | |||||
| def __repr__(self): | |||||
| piece = "Tensor(" | |||||
| with np.printoptions(precision=4, suppress=True): | |||||
| piece += "{}".format(str(self.numpy())) | |||||
| if self.dtype != np.float32: | |||||
| piece += ", dtype={}".format(np.dtype(self.dtype).name) | |||||
| piece += ", device={}".format(self.device) + ")" | |||||
| return piece | |||||
| @@ -0,0 +1,154 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| from typing import Iterable, Union | |||||
| import numpy as np | |||||
| from ..ops import builtin | |||||
| from ..ops.special import Const | |||||
| from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||||
| def dtype_promotion(raw_inputs): | |||||
| def add_dtype(i): | |||||
| if type(i) == int: | |||||
| return np.array(i, dtype=np.int32) | |||||
| if type(i) == float: | |||||
| return np.array(i, dtype=np.float32) | |||||
| if type(i) == bool: | |||||
| return np.array(i, dtype=np.bool_) | |||||
| return None | |||||
| scalar_inputs = [ | |||||
| add_dtype(i) for i in raw_inputs if not hasattr(i, "dtype") and add_dtype(i) | |||||
| ] | |||||
| inputs = [i for i in raw_inputs if hasattr(i, "dtype")] | |||||
| assert len(scalar_inputs + inputs) > 0 | |||||
| dtype = np.result_type(*inputs) | |||||
| dtype_all = np.result_type(*(inputs + scalar_inputs)) | |||||
| assert ( | |||||
| dtype != np.float64 and dtype != np.int64 | |||||
| ), "unsupport dtype {} by dtype_promotion, please use explict type convert".format( | |||||
| dtype | |||||
| ) | |||||
| if dtype_all == np.bool_: | |||||
| for i in raw_inputs: | |||||
| if not hasattr(i, "dtype") or i.dtype != np.bool_: | |||||
| raise TypeError( | |||||
| "bool dtype can not be operated with an element without bool dtype" | |||||
| ) | |||||
| if dtype_all == np.float64: | |||||
| dtype_all = np.float32 | |||||
| return dtype_all | |||||
| def get_device(inputs): | |||||
| device = None | |||||
| for i in inputs: | |||||
| if isinstance(i, (TensorWrapperBase, TensorBase)): | |||||
| if device is None: | |||||
| device = i.device | |||||
| elif device != i.device: | |||||
| raise ValueError("ambiguous device: {} vs {}".format(device, i.device)) | |||||
| assert device is not None | |||||
| return device | |||||
| def concatenate(inputs, axis=0, *, device=None): | |||||
| dtype = dtype_promotion(inputs) | |||||
| device = get_device(inputs) | |||||
| def convert(x): | |||||
| return convert_single_value(x, inputs, dtype=dtype) | |||||
| inputs = tuple(map(convert, inputs)) | |||||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inputs) | |||||
| return result | |||||
| def astype(x, dtype): | |||||
| dtype = np.dtype(dtype) | |||||
| if x.dtype != dtype: | |||||
| (x,) = apply(builtin.TypeCvt(param=dtype), x) | |||||
| return x | |||||
| def convert_single_value(v, inputs, *, dtype=None, device=None): | |||||
| tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))] | |||||
| assert len(tensors) > 0 | |||||
| if isinstance(v, (TensorWrapperBase, TensorBase)): | |||||
| v = astype(v, dtype) | |||||
| else: | |||||
| (v,) = Const(v, dtype=dtype, device=device)(*tensors) | |||||
| return v | |||||
| def convert_inputs(*args: TensorBase): | |||||
| dtype = dtype_promotion(args) | |||||
| device = get_device(args) | |||||
| def convert(value): | |||||
| if value is None: | |||||
| return value | |||||
| return convert_single_value(value, args, dtype=dtype, device=device) | |||||
| return tuple(map(convert, args)) | |||||
| def result_type(*args): | |||||
| dtypes = [] | |||||
| for i in args: | |||||
| if isinstance(i, (TensorWrapperBase, TensorBase)): | |||||
| dtypes.append(i.dtype) | |||||
| continue | |||||
| try: | |||||
| dtypes.append(np.dtype(i)) | |||||
| except TypeError: | |||||
| pass | |||||
| return np.result_type(*dtypes) | |||||
| def isscalar(x): | |||||
| try: | |||||
| return x.ndim == 0 | |||||
| except: | |||||
| pass | |||||
| return np.isscalar(x) | |||||
| def astensor1d(x, *reference, dtype=None, device=None): | |||||
| """ | |||||
| Convert something to 1D tensor. Support following types | |||||
| * sequence of scalar literal / tensor | |||||
| * numpy array | |||||
| * tensor (returned as is, regardless of dtype and device) | |||||
| """ | |||||
| try: | |||||
| ndim = x.ndim | |||||
| except AttributeError: | |||||
| pass | |||||
| else: | |||||
| if ndim != 1: | |||||
| raise ValueError("ndim != 1: %d" % ndim) | |||||
| if not isinstance(x, (TensorBase, TensorWrapperBase)): | |||||
| (x,) = Const(x, dtype=dtype, device=device)(*reference) | |||||
| return x | |||||
| if not isinstance(x, collections.Sequence): | |||||
| raise TypeError | |||||
| if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x): | |||||
| x = concatenate(x, device=device) | |||||
| if dtype is not None: | |||||
| x = astype(x, dtype) | |||||
| return x | |||||
| (x,) = Const(x, dtype=dtype, device=device)(*reference) | |||||
| return x | |||||
| @@ -0,0 +1,17 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .collator import Collator | |||||
| from .dataloader import DataLoader | |||||
| from .sampler import ( | |||||
| Infinite, | |||||
| RandomSampler, | |||||
| ReplacementSampler, | |||||
| Sampler, | |||||
| SequentialSampler, | |||||
| ) | |||||
| @@ -0,0 +1,139 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import binascii | |||||
| import os | |||||
| import queue | |||||
| import subprocess | |||||
| from multiprocessing import Queue | |||||
| import pyarrow | |||||
| import pyarrow.plasma as plasma | |||||
| MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB | |||||
| # Each process only need to start one plasma store, so we set it as a global variable. | |||||
| # TODO: how to share between different processes? | |||||
| MGE_PLASMA_STORE_MANAGER = None | |||||
| def _clear_plasma_store(): | |||||
| # `_PlasmaStoreManager.__del__` will not be called automaticly in subprocess, | |||||
| # so this function should be called explicitly | |||||
| global MGE_PLASMA_STORE_MANAGER | |||||
| if MGE_PLASMA_STORE_MANAGER is not None: | |||||
| del MGE_PLASMA_STORE_MANAGER | |||||
| MGE_PLASMA_STORE_MANAGER = None | |||||
| class _PlasmaStoreManager: | |||||
| __initialized = False | |||||
| def __init__(self): | |||||
| self.socket_name = "/tmp/mge_plasma_{}".format( | |||||
| binascii.hexlify(os.urandom(8)).decode() | |||||
| ) | |||||
| debug_flag = bool(os.environ.get("MGE_DATALOADER_PLASMA_DEBUG", 0)) | |||||
| # NOTE: this is a hack. Directly use `plasma_store` may make subprocess | |||||
| # difficult to handle the exception happened in `plasma-store-server`. | |||||
| # For `plasma_store` is just a wrapper of `plasma-store-server`, which use | |||||
| # `os.execv` to call the executable `plasma-store-server`. | |||||
| cmd_path = os.path.join(pyarrow.__path__[0], "plasma-store-server") | |||||
| self.plasma_store = subprocess.Popen( | |||||
| [cmd_path, "-s", self.socket_name, "-m", str(MGE_PLASMA_MEMORY),], | |||||
| stdout=None if debug_flag else subprocess.DEVNULL, | |||||
| stderr=None if debug_flag else subprocess.DEVNULL, | |||||
| ) | |||||
| self.__initialized = True | |||||
| def __del__(self): | |||||
| if self.__initialized and self.plasma_store.returncode is None: | |||||
| self.plasma_store.kill() | |||||
| class PlasmaShmQueue: | |||||
| def __init__(self, maxsize: int = 0): | |||||
| r"""Use pyarrow in-memory plasma store to implement shared memory queue. | |||||
| Compared to native `multiprocess.Queue`, `PlasmaShmQueue` avoid pickle/unpickle | |||||
| and communication overhead, leading to better performance in multi-process | |||||
| application. | |||||
| :type maxsize: int | |||||
| :param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``) | |||||
| """ | |||||
| # Lazy start the plasma store manager | |||||
| global MGE_PLASMA_STORE_MANAGER | |||||
| if MGE_PLASMA_STORE_MANAGER is None: | |||||
| try: | |||||
| MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager() | |||||
| except Exception as e: | |||||
| err_info = ( | |||||
| "Please make sure pyarrow installed correctly!\n" | |||||
| "You can try reinstall pyarrow and see if you can run " | |||||
| "`plasma_store -s /tmp/mge_plasma_xxx -m 1000` normally." | |||||
| ) | |||||
| raise RuntimeError( | |||||
| "Exception happened in starting plasma_store: {}\n" | |||||
| "Tips: {}".format(str(e), err_info) | |||||
| ) | |||||
| self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name | |||||
| # TODO: how to catch the exception happened in `plasma.connect`? | |||||
| self.client = None | |||||
| # Used to store the header for the data.(ObjectIDs) | |||||
| self.queue = Queue(maxsize) # type: Queue | |||||
| def put(self, data, block=True, timeout=None): | |||||
| if self.client is None: | |||||
| self.client = plasma.connect(self.socket_name) | |||||
| try: | |||||
| object_id = self.client.put(data) | |||||
| except plasma.PlasmaStoreFull: | |||||
| raise RuntimeError("plasma store out of memory!") | |||||
| try: | |||||
| self.queue.put(object_id, block, timeout) | |||||
| except queue.Full: | |||||
| self.client.delete([object_id]) | |||||
| raise queue.Full | |||||
| def get(self, block=True, timeout=None): | |||||
| if self.client is None: | |||||
| self.client = plasma.connect(self.socket_name) | |||||
| object_id = self.queue.get(block, timeout) | |||||
| if not self.client.contains(object_id): | |||||
| raise RuntimeError( | |||||
| "ObjectID: {} not found in plasma store".format(object_id) | |||||
| ) | |||||
| data = self.client.get(object_id) | |||||
| self.client.delete([object_id]) | |||||
| return data | |||||
| def qsize(self): | |||||
| return self.queue.qsize() | |||||
| def empty(self): | |||||
| return self.queue.empty() | |||||
| def join(self): | |||||
| self.queue.join() | |||||
| def disconnect_client(self): | |||||
| if self.client is not None: | |||||
| self.client.disconnect() | |||||
| def close(self): | |||||
| self.queue.close() | |||||
| self.disconnect_client() | |||||
| _clear_plasma_store() | |||||
| def cancel_join_thread(self): | |||||
| self.queue.cancel_join_thread() | |||||
| @@ -0,0 +1,76 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # Copyright (c) 2016- Facebook, Inc (Adam Paszke) | |||||
| # Copyright (c) 2014- Facebook, Inc (Soumith Chintala) | |||||
| # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | |||||
| # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) | |||||
| # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | |||||
| # Copyright (c) 2011-2013 NYU (Clement Farabet) | |||||
| # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | |||||
| # Copyright (c) 2006 Idiap Research Institute (Samy Bengio) | |||||
| # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | |||||
| # --------------------------------------------------------------------- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # | |||||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # ---------------------------------------------------------------------- | |||||
| import collections.abc | |||||
| import re | |||||
| import numpy as np | |||||
| np_str_obj_array_pattern = re.compile(r"[aO]") | |||||
| default_collate_err_msg_format = ( | |||||
| "default_collator: inputs must contain numpy arrays, numbers, " | |||||
| "Unicode strings, bytes, dicts or lists; found {}" | |||||
| ) | |||||
| class Collator: | |||||
| r""" | |||||
| Used for merge a list of samples to form a mini-batch of Tenor(s). Used when using batched loading from a dataset. | |||||
| modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py | |||||
| """ | |||||
| def apply(self, inputs): | |||||
| """ | |||||
| input : sequence_N(tuple(CHW, C, CK)) | |||||
| output : tuple(NCHW, NC, NCK) | |||||
| """ | |||||
| elem = inputs[0] | |||||
| elem_type = type(elem) | |||||
| if ( | |||||
| elem_type.__module__ == "numpy" | |||||
| and elem_type.__name__ != "str_" | |||||
| and elem_type.__name__ != "string_" | |||||
| ): | |||||
| elem = inputs[0] | |||||
| if elem_type.__name__ == "ndarray": | |||||
| # array of string classes and object | |||||
| if np_str_obj_array_pattern.search(elem.dtype.str) is not None: | |||||
| raise TypeError(default_collate_err_msg_format.format(elem.dtype)) | |||||
| return np.ascontiguousarray(np.stack(inputs)) | |||||
| elif elem.shape == (): # scalars | |||||
| return np.array(inputs) | |||||
| elif isinstance(elem, float): | |||||
| return np.array(inputs, dtype=np.float64) | |||||
| elif isinstance(elem, int): | |||||
| return np.array(inputs) | |||||
| elif isinstance(elem, (str, bytes)): | |||||
| return inputs | |||||
| elif isinstance(elem, collections.abc.Mapping): | |||||
| return {key: self.apply([d[key] for d in inputs]) for key in elem} | |||||
| elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple | |||||
| return elem_type(*(self.apply(samples) for samples in zip(*inputs))) | |||||
| elif isinstance(elem, collections.abc.Sequence): | |||||
| transposed = zip(*inputs) | |||||
| return [self.apply(samples) for samples in transposed] | |||||
| raise TypeError(default_collate_err_msg_format.format(elem_type)) | |||||
| @@ -0,0 +1,500 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import math | |||||
| import multiprocessing | |||||
| import queue | |||||
| import random | |||||
| import time | |||||
| import numpy as np | |||||
| from ..logger import get_logger | |||||
| from ..random.rng import _random_seed_generator | |||||
| from .collator import Collator | |||||
| from .dataset import Dataset | |||||
| from .sampler import Sampler, SequentialSampler | |||||
| from .transform import PseudoTransform, Transform | |||||
| logger = get_logger(__name__) | |||||
| MP_QUEUE_GET_TIMEOUT = 5 | |||||
| class DataLoader: | |||||
| __initialized = False | |||||
| def __init__( | |||||
| self, | |||||
| dataset: Dataset, | |||||
| sampler: Sampler = None, | |||||
| transform: Transform = None, | |||||
| collator: Collator = None, | |||||
| num_workers: int = 0, | |||||
| timeout: int = 0, | |||||
| divide: bool = False, | |||||
| ): | |||||
| r"""Provides a convenient way to iterate on a given dataset. | |||||
| `DataLoader` combines a dataset with sampler, transform and collator, | |||||
| make it flexible to get minibatch continually from a dataset. | |||||
| :type dataset: Dataset | |||||
| :param dataset: dataset from which to load the minibatch. | |||||
| :type sampler: Sampler | |||||
| :param sampler: defines the strategy to sample data from the dataset. | |||||
| If specified, :attr:`shuffle` must be ``False``. | |||||
| :type transform: Transform | |||||
| :param transform: defined the transforming strategy for a sampled batch. | |||||
| (default: ``None``) | |||||
| :type collator: Collator | |||||
| :param collator: defined the merging strategy for a transformed batch. | |||||
| (default: ``None``) | |||||
| :type num_workers: int | |||||
| :param num_workers: the number of sub-process to load, transform and collate | |||||
| the batch. ``0`` means using single-process. (default: ``0``) | |||||
| :type timeout: int | |||||
| :param timeout: if positive, means the timeout value(second) for collecting a | |||||
| batch from workers. (default: 0) | |||||
| :type divide: bool | |||||
| :param divide: define the paralleling strategy in multi-processing mode. | |||||
| ``True`` means one batch is divided into :attr:`num_workers` pieces, and | |||||
| the workers will process these pieces parallelly. ``False`` means | |||||
| different sub-process will process different batch. (default: ``False``) | |||||
| """ | |||||
| if num_workers < 0: | |||||
| raise ValueError("num_workers should not be negative") | |||||
| if timeout < 0: | |||||
| raise ValueError("timeout should not be negative") | |||||
| if divide and num_workers <= 1: | |||||
| raise ValueError("divide should not be set to True when num_workers <= 1") | |||||
| self.dataset = dataset | |||||
| self.num_workers = num_workers | |||||
| self.timeout = timeout | |||||
| self.divide = divide | |||||
| if sampler is None: | |||||
| self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) | |||||
| else: | |||||
| self.sampler = sampler | |||||
| if divide: | |||||
| if self.sampler.batch_size <= self.num_workers: | |||||
| raise ValueError( | |||||
| "batch size must not smaller than num_workers in divide mode." | |||||
| ) | |||||
| elif self.sampler.batch_size % self.num_workers: | |||||
| logger.warning( | |||||
| "batch size is not divisible by num_workers, may lose performance in divide mode." | |||||
| ) | |||||
| if transform is None: | |||||
| self.transform = PseudoTransform() | |||||
| else: | |||||
| self.transform = transform | |||||
| if collator is None: | |||||
| self.collator = Collator() | |||||
| else: | |||||
| self.collator = collator | |||||
| self.__initialized = True | |||||
| def __iter__(self): | |||||
| if self.num_workers == 0: | |||||
| return _SerialDataLoaderIter(self) | |||||
| else: | |||||
| return _ParallelDataLoaderIter(self) | |||||
| def __len__(self): | |||||
| return len(self.sampler) | |||||
| class _BaseDataLoaderIter: | |||||
| def __init__(self, loader): | |||||
| self.dataset = loader.dataset | |||||
| self.sampler = loader.sampler | |||||
| self.seed = _random_seed_generator().__next__() | |||||
| self.transform = loader.transform | |||||
| self.collator = loader.collator | |||||
| self.num_workers = loader.num_workers | |||||
| self.timeout = loader.timeout | |||||
| self.divide = loader.divide | |||||
| self.num_processed = 0 | |||||
| def _get_next_batch(self): | |||||
| raise NotImplementedError | |||||
| def __len__(self): | |||||
| return len(self.sampler) | |||||
| def __iter__(self): | |||||
| return self | |||||
| def __next__(self): | |||||
| if self.num_processed >= len(self): | |||||
| raise StopIteration | |||||
| minibatch = self._get_next_batch() | |||||
| self.num_processed += 1 | |||||
| return minibatch | |||||
| class _SerialDataLoaderIter(_BaseDataLoaderIter): | |||||
| def __init__(self, loader): | |||||
| super(_SerialDataLoaderIter, self).__init__(loader) | |||||
| self.indices_iter = iter(self.sampler) | |||||
| def _get_next_batch(self): | |||||
| indices = next(self.indices_iter) | |||||
| items = [self.dataset[idx] for idx in indices] | |||||
| trans_items = self.transform.apply_batch(items) | |||||
| return self.collator.apply(trans_items) | |||||
| class _ParallelDataLoaderIter(_BaseDataLoaderIter): | |||||
| __initialized = False | |||||
| def __init__(self, loader): | |||||
| super(_ParallelDataLoaderIter, self).__init__(loader) | |||||
| self.task_queues = [ | |||||
| multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) | |||||
| ] | |||||
| self.feed_batch_idx = multiprocessing.Value("i", 0) | |||||
| self.target_batch_idx = multiprocessing.Value("i", 0) | |||||
| self.shutdown_flag = multiprocessing.Value("i", 0) | |||||
| self.trans_data_queues = [ | |||||
| multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) | |||||
| ] | |||||
| # use shared-memory queue implemented by pyarrow plasma store. | |||||
| from ._queue import PlasmaShmQueue | |||||
| self.batch_queue = PlasmaShmQueue(maxsize=2) | |||||
| self.task_feeding_worker = multiprocessing.Process( | |||||
| target=_task_feeding_loop, | |||||
| args=( | |||||
| iter(self.sampler), | |||||
| self.task_queues, | |||||
| self.num_workers, | |||||
| self.divide, | |||||
| self.shutdown_flag, | |||||
| self.feed_batch_idx, | |||||
| ), | |||||
| daemon=True, | |||||
| ) | |||||
| self.task_feeding_worker.start() | |||||
| self.workers = [] | |||||
| for worker_id in range(self.num_workers): | |||||
| worker = multiprocessing.Process( | |||||
| target=_worker_loop, | |||||
| args=( | |||||
| self.dataset, | |||||
| self.task_queues[worker_id], | |||||
| self.trans_data_queues[worker_id], | |||||
| self.transform, | |||||
| self.seed + worker_id + 1, | |||||
| self.shutdown_flag, | |||||
| ), | |||||
| daemon=True, | |||||
| ) | |||||
| worker.start() | |||||
| self.workers.append(worker) | |||||
| if self.divide: | |||||
| self.data_collecting_worker = multiprocessing.Process( | |||||
| target=_data_gathering_loop, | |||||
| args=( | |||||
| self.trans_data_queues, | |||||
| self.batch_queue, | |||||
| self.collator, | |||||
| len(self), | |||||
| self.num_workers, | |||||
| self.shutdown_flag, | |||||
| self.target_batch_idx, | |||||
| ), | |||||
| daemon=True, | |||||
| ) | |||||
| else: | |||||
| self.data_collecting_worker = multiprocessing.Process( | |||||
| target=_data_selecting_loop, | |||||
| args=( | |||||
| self.trans_data_queues, | |||||
| self.batch_queue, | |||||
| self.collator, | |||||
| len(self), | |||||
| self.num_workers, | |||||
| self.shutdown_flag, | |||||
| self.target_batch_idx, | |||||
| ), | |||||
| daemon=True, | |||||
| ) | |||||
| self.data_collecting_worker.start() | |||||
| self.__initialized = True | |||||
| def _check_workers(self): | |||||
| # Check the status of each worker. | |||||
| if not self.data_collecting_worker.is_alive(): | |||||
| exitcode = self.task_feeding_worker.exitcode | |||||
| if exitcode != 0: | |||||
| raise RuntimeError("data collecting worker died. {}".format(exitcode)) | |||||
| if not self.task_feeding_worker.is_alive(): | |||||
| exitcode = self.task_feeding_worker.exitcode | |||||
| if exitcode != 0: | |||||
| raise RuntimeError("task feeding worker died. {}".format(exitcode)) | |||||
| for worker_id, worker in enumerate(self.workers): | |||||
| if not worker.is_alive(): | |||||
| exitcode = worker.exitcode | |||||
| if exitcode != 0: | |||||
| raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode)) | |||||
| logger.debug("all workers are alive.") | |||||
| def _try_get_next_batch(self): | |||||
| start_time = time.time() | |||||
| while True: | |||||
| self._check_workers() | |||||
| try: | |||||
| return self.batch_queue.get(timeout=1) | |||||
| except queue.Empty: | |||||
| logger.debug("batch queue empty!") | |||||
| waited_time = time.time() - start_time | |||||
| if self.timeout > 0: | |||||
| if waited_time > self.timeout: | |||||
| raise RuntimeError("get_next_batch timeout!") | |||||
| def _get_next_batch(self): | |||||
| batch_data = self._try_get_next_batch() | |||||
| return batch_data | |||||
| def _shutdown(self): | |||||
| with self.shutdown_flag.get_lock(): | |||||
| self.shutdown_flag.value = 1 | |||||
| if self.task_feeding_worker.is_alive(): | |||||
| self.task_feeding_worker.terminate() | |||||
| self.task_feeding_worker.join() | |||||
| if self.data_collecting_worker.is_alive(): | |||||
| self.data_collecting_worker.terminate() | |||||
| self.data_collecting_worker.join() | |||||
| for worker in self.workers: | |||||
| if worker.is_alive(): | |||||
| worker.terminate() | |||||
| worker.join() | |||||
| for q in self.trans_data_queues: | |||||
| q.cancel_join_thread() | |||||
| q.close() | |||||
| for q in self.task_queues: | |||||
| q.cancel_join_thread() | |||||
| q.close() | |||||
| self.batch_queue.cancel_join_thread() | |||||
| self.batch_queue.close() | |||||
| def __del__(self): | |||||
| if self.__initialized: | |||||
| self._shutdown() | |||||
| def _task_feeding_loop( | |||||
| indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx | |||||
| ): | |||||
| # Feed the indices into the task queues | |||||
| while True: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| batch_idx = feed_batch_idx.value | |||||
| try: | |||||
| indices = next(indices_iter) | |||||
| except StopIteration: | |||||
| break | |||||
| if divide: | |||||
| # make sure all task_queues is ready for put | |||||
| while any([q.full() for q in task_queues]): | |||||
| if shutdown_flag.value == 1: | |||||
| return | |||||
| # divide into small pieces, feed to different workers. | |||||
| sub_num = math.ceil(len(indices) / num_workers) | |||||
| for worker_id in range(num_workers): | |||||
| sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num] | |||||
| task_queues[worker_id].put((batch_idx, sub_indices)) | |||||
| else: | |||||
| # distribute tasks to different workers uniformly. | |||||
| target_id = batch_idx % num_workers | |||||
| while task_queues[target_id].full(): | |||||
| if shutdown_flag.value == 1: | |||||
| return | |||||
| task_queues[target_id].put((batch_idx, indices)) | |||||
| with feed_batch_idx.get_lock(): | |||||
| feed_batch_idx.value += 1 | |||||
| def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag): | |||||
| # Get dataset items and do the transform | |||||
| random.seed(seed) | |||||
| np.random.seed(seed) | |||||
| while True: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| try: | |||||
| batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT) | |||||
| except queue.Empty: | |||||
| continue | |||||
| if len(indices) > 0: | |||||
| items = [dataset[idx] for idx in indices] | |||||
| trans_items = transform.apply_batch(items) | |||||
| else: | |||||
| # in case of incomplete last batch | |||||
| trans_items = () | |||||
| while True: | |||||
| try: | |||||
| trans_data_queue.put((batch_idx, trans_items), timeout=1) | |||||
| break | |||||
| except queue.Full: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| logger.debug("batch part queue is full!") | |||||
| def _data_gathering_loop( | |||||
| trans_data_queues, | |||||
| batch_queue, | |||||
| collator, | |||||
| length, | |||||
| num_workers, | |||||
| shutdown_flag, | |||||
| target_idx, | |||||
| ): | |||||
| # Gathering the small pieces of batch data into full batch data | |||||
| while True: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| target_batch_idx = target_idx.value | |||||
| if target_batch_idx >= length: | |||||
| break | |||||
| full_trans_items = [] | |||||
| for worker_id in range(num_workers): | |||||
| while True: | |||||
| try: | |||||
| batch_idx, trans_items = trans_data_queues[worker_id].get( | |||||
| timeout=MP_QUEUE_GET_TIMEOUT | |||||
| ) | |||||
| break | |||||
| except queue.Empty: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| logger.debug( | |||||
| "worker:{} data queue get timeout! target batch idx:{}".format( | |||||
| worker_id, target_batch_idx | |||||
| ) | |||||
| ) | |||||
| if batch_idx != target_batch_idx: | |||||
| raise RuntimeError( | |||||
| "Unexperted batch_idx in data gathering loop. worker_id:{}.".format( | |||||
| worker_id | |||||
| ) | |||||
| ) | |||||
| else: | |||||
| full_trans_items.extend(trans_items) | |||||
| # Merge different parts into a batch. | |||||
| full_batch = collator.apply(full_trans_items) | |||||
| while True: | |||||
| try: | |||||
| batch_queue.put(full_batch, timeout=1) | |||||
| break | |||||
| except queue.Full: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| logger.debug("batch queue is full!") | |||||
| with target_idx.get_lock(): | |||||
| target_idx.value += 1 | |||||
| batch_queue.disconnect_client() | |||||
| def _data_selecting_loop( | |||||
| trans_data_queues, | |||||
| batch_queue, | |||||
| collator, | |||||
| length, | |||||
| num_workers, | |||||
| shutdown_flag, | |||||
| target_idx, | |||||
| ): | |||||
| # Make sure that batch is generated exactly with the same order as generated indices | |||||
| while True: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| target_batch_idx = target_idx.value | |||||
| if target_batch_idx >= length: | |||||
| break | |||||
| target_worker_id = target_batch_idx % num_workers | |||||
| while True: | |||||
| try: | |||||
| batch_idx, trans_items = trans_data_queues[target_worker_id].get( | |||||
| timeout=MP_QUEUE_GET_TIMEOUT | |||||
| ) | |||||
| batch_data = collator.apply(trans_items) | |||||
| break | |||||
| except queue.Empty: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| logger.debug( | |||||
| "worker:{} data queue get timeout! target batch idx:{}".format( | |||||
| target_worker_id, target_batch_idx | |||||
| ) | |||||
| ) | |||||
| if batch_idx != target_batch_idx: | |||||
| raise RuntimeError( | |||||
| "batch_idx {} mismatch the target_batch_idx {}".format( | |||||
| batch_idx, target_batch_idx | |||||
| ) | |||||
| ) | |||||
| while True: | |||||
| try: | |||||
| batch_queue.put(batch_data, timeout=1) | |||||
| break | |||||
| except queue.Full: | |||||
| if shutdown_flag.value == 1: | |||||
| break | |||||
| logger.debug("batch queue is full!") | |||||
| with target_idx.get_lock(): | |||||
| target_idx.value += 1 | |||||
| batch_queue.disconnect_client() | |||||
| @@ -0,0 +1,10 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset | |||||
| from .vision import * | |||||
| @@ -0,0 +1,73 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Tuple | |||||
| class Dataset(ABC): | |||||
| r""" | |||||
| An abstract class for all Datasets | |||||
| """ | |||||
| @abstractmethod | |||||
| def __init__(self): | |||||
| pass | |||||
| class MapDataset(Dataset): | |||||
| r""" | |||||
| An abstract class for map data | |||||
| __getitem__ and __len__ method are aditionally needed | |||||
| """ | |||||
| @abstractmethod | |||||
| def __init__(self): | |||||
| pass | |||||
| @abstractmethod | |||||
| def __getitem__(self, index): | |||||
| pass | |||||
| @abstractmethod | |||||
| def __len__(self): | |||||
| pass | |||||
| class StreamDataset(Dataset): | |||||
| r""" | |||||
| An abstract class for stream data | |||||
| __iter__ method is aditionally needed | |||||
| """ | |||||
| @abstractmethod | |||||
| def __init__(self): | |||||
| pass | |||||
| @abstractmethod | |||||
| def __iter__(self): | |||||
| pass | |||||
| class ArrayDataset(MapDataset): | |||||
| def __init__(self, *arrays): | |||||
| r""" | |||||
| ArrayDataset is a dataset for numpy array data, one or more numpy arrays | |||||
| are needed to initiate the dataset. And the dimensions represented sample number | |||||
| are expected to be the same. | |||||
| """ | |||||
| super().__init__() | |||||
| if not all(len(arrays[0]) == len(array) for array in arrays): | |||||
| raise ValueError("lengths of input arrays are inconsistent") | |||||
| self.arrays = arrays | |||||
| def __getitem__(self, index: int) -> Tuple: | |||||
| return tuple(array[index] for array in self.arrays) | |||||
| def __len__(self) -> int: | |||||
| return len(self.arrays[0]) | |||||
| @@ -0,0 +1,17 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .cifar import CIFAR10, CIFAR100 | |||||
| from .cityscapes import Cityscapes | |||||
| from .coco import COCO | |||||
| from .folder import ImageFolder | |||||
| from .imagenet import ImageNet | |||||
| from .meta_vision import VisionDataset | |||||
| from .mnist import MNIST | |||||
| from .objects365 import Objects365 | |||||
| from .voc import PascalVOC | |||||
| @@ -0,0 +1,171 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import os | |||||
| import pickle | |||||
| import tarfile | |||||
| from typing import Tuple | |||||
| import numpy as np | |||||
| from ....logger import get_logger | |||||
| from .meta_vision import VisionDataset | |||||
| from .utils import _default_dataset_root, load_raw_data_from_url | |||||
| logger = get_logger(__name__) | |||||
| class CIFAR10(VisionDataset): | |||||
| r""" ``Dataset`` for CIFAR10 meta data | |||||
| """ | |||||
| url_path = "http://www.cs.utoronto.ca/~kriz/" | |||||
| raw_file_name = "cifar-10-python.tar.gz" | |||||
| raw_file_md5 = "c58f30108f718f92721af3b95e74349a" | |||||
| raw_file_dir = "cifar-10-batches-py" | |||||
| train_batch = [ | |||||
| "data_batch_1", | |||||
| "data_batch_2", | |||||
| "data_batch_3", | |||||
| "data_batch_4", | |||||
| "data_batch_5", | |||||
| ] | |||||
| test_batch = ["test_batch"] | |||||
| meta_info = {"name": "batches.meta"} | |||||
| def __init__( | |||||
| self, | |||||
| root: str = None, | |||||
| train: bool = True, | |||||
| download: bool = True, | |||||
| timeout: int = 500, | |||||
| ): | |||||
| super().__init__(root, order=("image", "image_category")) | |||||
| self.timeout = timeout | |||||
| # process the root path | |||||
| if root is None: | |||||
| self.root = self._default_root | |||||
| if not os.path.exists(self.root): | |||||
| os.makedirs(self.root) | |||||
| else: | |||||
| self.root = root | |||||
| if not os.path.exists(self.root): | |||||
| if download: | |||||
| logger.debug( | |||||
| "dir %s does not exist, will be automatically created", | |||||
| self.root, | |||||
| ) | |||||
| os.makedirs(self.root) | |||||
| else: | |||||
| raise ValueError("dir %s does not exist" % self.root) | |||||
| self.target_file = os.path.join(self.root, self.raw_file_dir) | |||||
| # check existence of target pickle dir, if exists load the | |||||
| # pickle file no matter what download is set | |||||
| if os.path.exists(self.target_file): | |||||
| if train: | |||||
| self.arrays = self.bytes2array(self.train_batch) | |||||
| else: | |||||
| self.arrays = self.bytes2array(self.test_batch) | |||||
| else: | |||||
| if download: | |||||
| self.download() | |||||
| if train: | |||||
| self.arrays = self.bytes2array(self.train_batch) | |||||
| else: | |||||
| self.arrays = self.bytes2array(self.test_batch) | |||||
| else: | |||||
| raise ValueError( | |||||
| "dir does not contain target file %s, please set download=True" | |||||
| % (self.target_file) | |||||
| ) | |||||
| def __getitem__(self, index: int) -> Tuple: | |||||
| return tuple(array[index] for array in self.arrays) | |||||
| def __len__(self) -> int: | |||||
| return len(self.arrays[0]) | |||||
| @property | |||||
| def _default_root(self): | |||||
| return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||||
| @property | |||||
| def meta(self): | |||||
| meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"]) | |||||
| with open(meta_path, "rb") as f: | |||||
| meta = pickle.load(f, encoding="bytes") | |||||
| return meta | |||||
| def download(self): | |||||
| url = self.url_path + self.raw_file_name | |||||
| load_raw_data_from_url( | |||||
| url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout | |||||
| ) | |||||
| self.process() | |||||
| def untar(self, file_path, dirs): | |||||
| assert file_path.endswith(".tar.gz") | |||||
| logger.debug("untar file %s to %s", file_path, dirs) | |||||
| t = tarfile.open(file_path) | |||||
| t.extractall(path=dirs) | |||||
| def bytes2array(self, filenames): | |||||
| data = [] | |||||
| label = [] | |||||
| for filename in filenames: | |||||
| path = os.path.join(self.root, self.raw_file_dir, filename) | |||||
| logger.debug("unpickle file %s", path) | |||||
| with open(path, "rb") as fo: | |||||
| dic = pickle.load(fo, encoding="bytes") | |||||
| batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) | |||||
| data.extend(list(batch_data[..., [2, 1, 0]])) | |||||
| label.extend(dic[b"labels"]) | |||||
| label = np.array(label, dtype=np.int32) | |||||
| return (data, label) | |||||
| def process(self): | |||||
| logger.info("process raw data ...") | |||||
| self.untar(os.path.join(self.root, self.raw_file_name), self.root) | |||||
| class CIFAR100(CIFAR10): | |||||
| url_path = "http://www.cs.utoronto.ca/~kriz/" | |||||
| raw_file_name = "cifar-100-python.tar.gz" | |||||
| raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85" | |||||
| raw_file_dir = "cifar-100-python" | |||||
| train_batch = ["train"] | |||||
| test_batch = ["test"] | |||||
| meta_info = {"name": "meta"} | |||||
| @property | |||||
| def meta(self): | |||||
| meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"]) | |||||
| with open(meta_path, "rb") as f: | |||||
| meta = pickle.load(f, encoding="bytes") | |||||
| return meta | |||||
| def bytes2array(self, filenames): | |||||
| data = [] | |||||
| fine_label = [] | |||||
| coarse_label = [] | |||||
| for filename in filenames: | |||||
| path = os.path.join(self.root, self.raw_file_dir, filename) | |||||
| logger.debug("unpickle file %s", path) | |||||
| with open(path, "rb") as fo: | |||||
| dic = pickle.load(fo, encoding="bytes") | |||||
| batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) | |||||
| data.extend(list(batch_data[..., [2, 1, 0]])) | |||||
| fine_label.extend(dic[b"fine_labels"]) | |||||
| coarse_label.extend(dic[b"coarse_labels"]) | |||||
| fine_label = np.array(fine_label, dtype=np.int32) | |||||
| coarse_label = np.array(coarse_label, dtype=np.int32) | |||||
| return data, fine_label, coarse_label | |||||
| @@ -0,0 +1,151 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # --------------------------------------------------------------------- | |||||
| # Part of the following code in this file refs to torchvision | |||||
| # BSD 3-Clause License | |||||
| # | |||||
| # Copyright (c) Soumith Chintala 2016, | |||||
| # All rights reserved. | |||||
| # --------------------------------------------------------------------- | |||||
| import json | |||||
| import os | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from .meta_vision import VisionDataset | |||||
| class Cityscapes(VisionDataset): | |||||
| r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset. | |||||
| """ | |||||
| supported_order = ( | |||||
| "image", | |||||
| "mask", | |||||
| "info", | |||||
| ) | |||||
| def __init__(self, root, image_set, mode, *, order=None): | |||||
| super().__init__(root, order=order, supported_order=self.supported_order) | |||||
| city_root = self.root | |||||
| if not os.path.isdir(city_root): | |||||
| raise RuntimeError("Dataset not found or corrupted.") | |||||
| self.mode = mode | |||||
| self.images_dir = os.path.join(city_root, "leftImg8bit", image_set) | |||||
| self.masks_dir = os.path.join(city_root, self.mode, image_set) | |||||
| self.images, self.masks = [], [] | |||||
| # self.target_type = ["instance", "semantic", "polygon", "color"] | |||||
| # for semantic segmentation | |||||
| if mode == "gtFine": | |||||
| valid_modes = ("train", "test", "val") | |||||
| else: | |||||
| valid_modes = ("train", "train_extra", "val") | |||||
| for city in os.listdir(self.images_dir): | |||||
| img_dir = os.path.join(self.images_dir, city) | |||||
| mask_dir = os.path.join(self.masks_dir, city) | |||||
| for file_name in os.listdir(img_dir): | |||||
| mask_name = "{}_{}".format( | |||||
| file_name.split("_leftImg8bit")[0], | |||||
| self._get_target_suffix(self.mode, "semantic"), | |||||
| ) | |||||
| self.images.append(os.path.join(img_dir, file_name)) | |||||
| self.masks.append(os.path.join(mask_dir, mask_name)) | |||||
| def __getitem__(self, index): | |||||
| target = [] | |||||
| for k in self.order: | |||||
| if k == "image": | |||||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||||
| target.append(image) | |||||
| elif k == "mask": | |||||
| mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE) | |||||
| mask = self._trans_mask(mask) | |||||
| mask = mask[:, :, np.newaxis] | |||||
| target.append(mask) | |||||
| elif k == "info": | |||||
| if image is None: | |||||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||||
| info = [image.shape[0], image.shape[1], self.images[index]] | |||||
| target.append(info) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| return tuple(target) | |||||
| def __len__(self): | |||||
| return len(self.images) | |||||
| def _trans_mask(self, mask): | |||||
| trans_labels = [ | |||||
| 7, | |||||
| 8, | |||||
| 11, | |||||
| 12, | |||||
| 13, | |||||
| 17, | |||||
| 19, | |||||
| 20, | |||||
| 21, | |||||
| 22, | |||||
| 23, | |||||
| 24, | |||||
| 25, | |||||
| 26, | |||||
| 27, | |||||
| 28, | |||||
| 31, | |||||
| 32, | |||||
| 33, | |||||
| ] | |||||
| label = np.ones(mask.shape) * 255 | |||||
| for i, tl in enumerate(trans_labels): | |||||
| label[mask == tl] = i | |||||
| return label.astype(np.uint8) | |||||
| def _get_target_suffix(self, mode, target_type): | |||||
| if target_type == "instance": | |||||
| return "{}_instanceIds.png".format(mode) | |||||
| elif target_type == "semantic": | |||||
| return "{}_labelIds.png".format(mode) | |||||
| elif target_type == "color": | |||||
| return "{}_color.png".format(mode) | |||||
| else: | |||||
| return "{}_polygons.json".format(mode) | |||||
| def _load_json(self, path): | |||||
| with open(path, "r") as file: | |||||
| data = json.load(file) | |||||
| return data | |||||
| class_names = ( | |||||
| "road", | |||||
| "sidewalk", | |||||
| "building", | |||||
| "wall", | |||||
| "fence", | |||||
| "pole", | |||||
| "traffic light", | |||||
| "traffic sign", | |||||
| "vegetation", | |||||
| "terrain", | |||||
| "sky", | |||||
| "person", | |||||
| "rider", | |||||
| "car", | |||||
| "truck", | |||||
| "bus", | |||||
| "train", | |||||
| "motorcycle", | |||||
| "bicycle", | |||||
| ) | |||||
| @@ -0,0 +1,366 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # --------------------------------------------------------------------- | |||||
| # Part of the following code in this file refs to maskrcnn-benchmark | |||||
| # MIT License | |||||
| # | |||||
| # Copyright (c) 2018 Facebook | |||||
| # --------------------------------------------------------------------- | |||||
| import json | |||||
| import os | |||||
| from collections import defaultdict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from .meta_vision import VisionDataset | |||||
| min_keypoints_per_image = 10 | |||||
| def _count_visible_keypoints(anno): | |||||
| return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |||||
| def has_valid_annotation(anno, order): | |||||
| # if it"s empty, there is no annotation | |||||
| if len(anno) == 0: | |||||
| return False | |||||
| if "boxes" in order or "boxes_category" in order: | |||||
| if "bbox" not in anno[0]: | |||||
| return False | |||||
| if "keypoints" in order: | |||||
| if "keypoints" not in anno[0]: | |||||
| return False | |||||
| # for keypoint detection tasks, only consider valid images those | |||||
| # containing at least min_keypoints_per_image | |||||
| if _count_visible_keypoints(anno) < min_keypoints_per_image: | |||||
| return False | |||||
| return True | |||||
| class COCO(VisionDataset): | |||||
| r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset. | |||||
| """ | |||||
| supported_order = ( | |||||
| "image", | |||||
| "boxes", | |||||
| "boxes_category", | |||||
| "keypoints", | |||||
| # TODO: need to check | |||||
| # "polygons", | |||||
| "info", | |||||
| ) | |||||
| def __init__( | |||||
| self, root, ann_file, remove_images_without_annotations=False, *, order=None | |||||
| ): | |||||
| super().__init__(root, order=order, supported_order=self.supported_order) | |||||
| with open(ann_file, "r") as f: | |||||
| dataset = json.load(f) | |||||
| self.imgs = dict() | |||||
| for img in dataset["images"]: | |||||
| # for saving memory | |||||
| if "license" in img: | |||||
| del img["license"] | |||||
| if "coco_url" in img: | |||||
| del img["coco_url"] | |||||
| if "date_captured" in img: | |||||
| del img["date_captured"] | |||||
| if "flickr_url" in img: | |||||
| del img["flickr_url"] | |||||
| self.imgs[img["id"]] = img | |||||
| self.img_to_anns = defaultdict(list) | |||||
| for ann in dataset["annotations"]: | |||||
| # for saving memory | |||||
| if ( | |||||
| "boxes" not in self.order | |||||
| and "boxes_category" not in self.order | |||||
| and "bbox" in ann | |||||
| ): | |||||
| del ann["bbox"] | |||||
| if "polygons" not in self.order and "segmentation" in ann: | |||||
| del ann["segmentation"] | |||||
| self.img_to_anns[ann["image_id"]].append(ann) | |||||
| self.cats = dict() | |||||
| for cat in dataset["categories"]: | |||||
| self.cats[cat["id"]] = cat | |||||
| self.ids = list(sorted(self.imgs.keys())) | |||||
| # filter images without detection annotations | |||||
| if remove_images_without_annotations: | |||||
| ids = [] | |||||
| for img_id in self.ids: | |||||
| anno = self.img_to_anns[img_id] | |||||
| # filter crowd annotations | |||||
| anno = [obj for obj in anno if obj["iscrowd"] == 0] | |||||
| anno = [ | |||||
| obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0 | |||||
| ] | |||||
| if has_valid_annotation(anno, order): | |||||
| ids.append(img_id) | |||||
| self.img_to_anns[img_id] = anno | |||||
| else: | |||||
| del self.imgs[img_id] | |||||
| del self.img_to_anns[img_id] | |||||
| self.ids = ids | |||||
| self.json_category_id_to_contiguous_id = { | |||||
| v: i + 1 for i, v in enumerate(self.cats.keys()) | |||||
| } | |||||
| self.contiguous_category_id_to_json_id = { | |||||
| v: k for k, v in self.json_category_id_to_contiguous_id.items() | |||||
| } | |||||
| def __getitem__(self, index): | |||||
| img_id = self.ids[index] | |||||
| anno = self.img_to_anns[img_id] | |||||
| target = [] | |||||
| for k in self.order: | |||||
| if k == "image": | |||||
| file_name = self.imgs[img_id]["file_name"] | |||||
| path = os.path.join(self.root, file_name) | |||||
| image = cv2.imread(path, cv2.IMREAD_COLOR) | |||||
| target.append(image) | |||||
| elif k == "boxes": | |||||
| boxes = [obj["bbox"] for obj in anno] | |||||
| boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||||
| # transfer boxes from xywh to xyxy | |||||
| boxes[:, 2:] += boxes[:, :2] | |||||
| target.append(boxes) | |||||
| elif k == "boxes_category": | |||||
| boxes_category = [obj["category_id"] for obj in anno] | |||||
| boxes_category = [ | |||||
| self.json_category_id_to_contiguous_id[c] for c in boxes_category | |||||
| ] | |||||
| boxes_category = np.array(boxes_category, dtype=np.int32) | |||||
| target.append(boxes_category) | |||||
| elif k == "keypoints": | |||||
| keypoints = [obj["keypoints"] for obj in anno] | |||||
| keypoints = np.array(keypoints, dtype=np.float32).reshape( | |||||
| -1, len(self.keypoint_names), 3 | |||||
| ) | |||||
| target.append(keypoints) | |||||
| elif k == "polygons": | |||||
| polygons = [obj["segmentation"] for obj in anno] | |||||
| polygons = [ | |||||
| [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps] | |||||
| for ps in polygons | |||||
| ] | |||||
| target.append(polygons) | |||||
| elif k == "info": | |||||
| info = self.imgs[img_id] | |||||
| info = [info["height"], info["width"], info["file_name"]] | |||||
| target.append(info) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| return tuple(target) | |||||
| def __len__(self): | |||||
| return len(self.ids) | |||||
| def get_img_info(self, index): | |||||
| img_id = self.ids[index] | |||||
| img_info = self.imgs[img_id] | |||||
| return img_info | |||||
| class_names = ( | |||||
| "person", | |||||
| "bicycle", | |||||
| "car", | |||||
| "motorcycle", | |||||
| "airplane", | |||||
| "bus", | |||||
| "train", | |||||
| "truck", | |||||
| "boat", | |||||
| "traffic light", | |||||
| "fire hydrant", | |||||
| "stop sign", | |||||
| "parking meter", | |||||
| "bench", | |||||
| "bird", | |||||
| "cat", | |||||
| "dog", | |||||
| "horse", | |||||
| "sheep", | |||||
| "cow", | |||||
| "elephant", | |||||
| "bear", | |||||
| "zebra", | |||||
| "giraffe", | |||||
| "backpack", | |||||
| "umbrella", | |||||
| "handbag", | |||||
| "tie", | |||||
| "suitcase", | |||||
| "frisbee", | |||||
| "skis", | |||||
| "snowboard", | |||||
| "sports ball", | |||||
| "kite", | |||||
| "baseball bat", | |||||
| "baseball glove", | |||||
| "skateboard", | |||||
| "surfboard", | |||||
| "tennis racket", | |||||
| "bottle", | |||||
| "wine glass", | |||||
| "cup", | |||||
| "fork", | |||||
| "knife", | |||||
| "spoon", | |||||
| "bowl", | |||||
| "banana", | |||||
| "apple", | |||||
| "sandwich", | |||||
| "orange", | |||||
| "broccoli", | |||||
| "carrot", | |||||
| "hot dog", | |||||
| "pizza", | |||||
| "donut", | |||||
| "cake", | |||||
| "chair", | |||||
| "couch", | |||||
| "potted plant", | |||||
| "bed", | |||||
| "dining table", | |||||
| "toilet", | |||||
| "tv", | |||||
| "laptop", | |||||
| "mouse", | |||||
| "remote", | |||||
| "keyboard", | |||||
| "cell phone", | |||||
| "microwave", | |||||
| "oven", | |||||
| "toaster", | |||||
| "sink", | |||||
| "refrigerator", | |||||
| "book", | |||||
| "clock", | |||||
| "vase", | |||||
| "scissors", | |||||
| "teddy bear", | |||||
| "hair drier", | |||||
| "toothbrush", | |||||
| ) | |||||
| classes_originID = { | |||||
| "person": 1, | |||||
| "bicycle": 2, | |||||
| "car": 3, | |||||
| "motorcycle": 4, | |||||
| "airplane": 5, | |||||
| "bus": 6, | |||||
| "train": 7, | |||||
| "truck": 8, | |||||
| "boat": 9, | |||||
| "traffic light": 10, | |||||
| "fire hydrant": 11, | |||||
| "stop sign": 13, | |||||
| "parking meter": 14, | |||||
| "bench": 15, | |||||
| "bird": 16, | |||||
| "cat": 17, | |||||
| "dog": 18, | |||||
| "horse": 19, | |||||
| "sheep": 20, | |||||
| "cow": 21, | |||||
| "elephant": 22, | |||||
| "bear": 23, | |||||
| "zebra": 24, | |||||
| "giraffe": 25, | |||||
| "backpack": 27, | |||||
| "umbrella": 28, | |||||
| "handbag": 31, | |||||
| "tie": 32, | |||||
| "suitcase": 33, | |||||
| "frisbee": 34, | |||||
| "skis": 35, | |||||
| "snowboard": 36, | |||||
| "sports ball": 37, | |||||
| "kite": 38, | |||||
| "baseball bat": 39, | |||||
| "baseball glove": 40, | |||||
| "skateboard": 41, | |||||
| "surfboard": 42, | |||||
| "tennis racket": 43, | |||||
| "bottle": 44, | |||||
| "wine glass": 46, | |||||
| "cup": 47, | |||||
| "fork": 48, | |||||
| "knife": 49, | |||||
| "spoon": 50, | |||||
| "bowl": 51, | |||||
| "banana": 52, | |||||
| "apple": 53, | |||||
| "sandwich": 54, | |||||
| "orange": 55, | |||||
| "broccoli": 56, | |||||
| "carrot": 57, | |||||
| "hot dog": 58, | |||||
| "pizza": 59, | |||||
| "donut": 60, | |||||
| "cake": 61, | |||||
| "chair": 62, | |||||
| "couch": 63, | |||||
| "potted plant": 64, | |||||
| "bed": 65, | |||||
| "dining table": 67, | |||||
| "toilet": 70, | |||||
| "tv": 72, | |||||
| "laptop": 73, | |||||
| "mouse": 74, | |||||
| "remote": 75, | |||||
| "keyboard": 76, | |||||
| "cell phone": 77, | |||||
| "microwave": 78, | |||||
| "oven": 79, | |||||
| "toaster": 80, | |||||
| "sink": 81, | |||||
| "refrigerator": 82, | |||||
| "book": 84, | |||||
| "clock": 85, | |||||
| "vase": 86, | |||||
| "scissors": 87, | |||||
| "teddy bear": 88, | |||||
| "hair drier": 89, | |||||
| "toothbrush": 90, | |||||
| } | |||||
| keypoint_names = ( | |||||
| "nose", | |||||
| "left_eye", | |||||
| "right_eye", | |||||
| "left_ear", | |||||
| "right_ear", | |||||
| "left_shoulder", | |||||
| "right_shoulder", | |||||
| "left_elbow", | |||||
| "right_elbow", | |||||
| "left_wrist", | |||||
| "right_wrist", | |||||
| "left_hip", | |||||
| "right_hip", | |||||
| "left_knee", | |||||
| "right_knee", | |||||
| "left_ankle", | |||||
| "right_ankle", | |||||
| ) | |||||
| @@ -0,0 +1,90 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # BSD 3-Clause License | |||||
| # Copyright (c) Soumith Chintala 2016, | |||||
| # All rights reserved. | |||||
| # --------------------------------------------------------------------- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # | |||||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # --------------------------------------------------------------------- | |||||
| import os | |||||
| from typing import Dict, List, Tuple | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from .meta_vision import VisionDataset | |||||
| from .utils import is_img | |||||
| class ImageFolder(VisionDataset): | |||||
| def __init__(self, root: str, check_valid_func=None, class_name: bool = False): | |||||
| r""" | |||||
| ImageFolder is a class for loading image data and labels from a organized folder. | |||||
| the folder is expected to be organized as followed | |||||
| root/cls/xxx.img_ext | |||||
| labels are indices of sorted classes in the root directory | |||||
| :param root: root directory of an image folder | |||||
| :param loader: a function used to load image from path, | |||||
| if ``None``, default function that loads | |||||
| images with PILwill be called | |||||
| :param check_valid_func: a function used to check if files in folder are | |||||
| expected image files, if ``None``, default function | |||||
| that checks file extensions will be called | |||||
| :param class_name: if ``True``, return class name instead of class index | |||||
| """ | |||||
| super().__init__(root, order=("image", "image_category")) | |||||
| self.root = root | |||||
| if check_valid_func is not None: | |||||
| self.check_valid = check_valid_func | |||||
| else: | |||||
| self.check_valid = is_img | |||||
| self.class_name = class_name | |||||
| self.class_dict = self.collect_class() | |||||
| self.samples = self.collect_samples() | |||||
| def collect_samples(self) -> List: | |||||
| samples = [] | |||||
| directory = os.path.expanduser(self.root) | |||||
| for key in sorted(self.class_dict.keys()): | |||||
| d = os.path.join(directory, key) | |||||
| if not os.path.isdir(d): | |||||
| continue | |||||
| for r, _, filename in sorted(os.walk(d, followlinks=True)): | |||||
| for name in sorted(filename): | |||||
| path = os.path.join(r, name) | |||||
| if self.check_valid(path): | |||||
| if self.class_name: | |||||
| samples.append((path, key)) | |||||
| else: | |||||
| samples.append((path, self.class_dict[key])) | |||||
| return samples | |||||
| def collect_class(self) -> Dict: | |||||
| classes = [d.name for d in os.scandir(self.root) if d.is_dir()] | |||||
| classes.sort() | |||||
| return {classes[i]: np.int32(i) for i in range(len(classes))} | |||||
| def __getitem__(self, index: int) -> Tuple: | |||||
| path, label = self.samples[index] | |||||
| img = cv2.imread(path, cv2.IMREAD_COLOR) | |||||
| return img, label | |||||
| def __len__(self): | |||||
| return len(self.samples) | |||||
| @@ -0,0 +1,248 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # BSD 3-Clause License | |||||
| # | |||||
| # Copyright (c) Soumith Chintala 2016, | |||||
| # All rights reserved. | |||||
| # --------------------------------------------------------------------- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # | |||||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # --------------------------------------------------------------------- | |||||
| import os | |||||
| import shutil | |||||
| from tqdm import tqdm | |||||
| from ....distributed.group import is_distributed | |||||
| from ....logger import get_logger | |||||
| from ....serialization import load, save | |||||
| from .folder import ImageFolder | |||||
| from .utils import _default_dataset_root, calculate_md5, untar, untargz | |||||
| logger = get_logger(__name__) | |||||
| class ImageNet(ImageFolder): | |||||
| r""" | |||||
| Load ImageNet from raw files or folder, expected folder looks like | |||||
| .. code-block:: bash | |||||
| ${root}/ | |||||
| | [REQUIRED TAR FILES] | |||||
| |- ILSVRC2012_img_train.tar | |||||
| |- ILSVRC2012_img_val.tar | |||||
| |- ILSVRC2012_devkit_t12.tar.gz | |||||
| | [OPTIONAL IMAGE FOLDERS] | |||||
| |- train/cls/xxx.${img_ext} | |||||
| |- val/cls/xxx.${img_ext} | |||||
| |- ILSVRC2012_devkit_t12/data/meta.mat | |||||
| |- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt | |||||
| If the image folders don't exist, raw tar files are required to get extracted and processed. | |||||
| """ | |||||
| raw_file_meta = { | |||||
| "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), | |||||
| "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), | |||||
| "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), | |||||
| } # ImageNet raw files | |||||
| default_train_dir = "train" | |||||
| default_val_dir = "val" | |||||
| default_devkit_dir = "ILSVRC2012_devkit_t12" | |||||
| def __init__(self, root: str = None, train: bool = True, **kwargs): | |||||
| r""" | |||||
| initialization: | |||||
| * if ``root`` contains ``self.target_folder`` depent on ``train``: | |||||
| * initialize ImageFolder with target_folder | |||||
| * else: | |||||
| * if all raw files are in ``root``: | |||||
| * parse ``self.target_folder`` from raw files | |||||
| * initialize ImageFolder with ``self.target_folder`` | |||||
| * else: | |||||
| * raise error | |||||
| :param root: root directory of imagenet data, if root is ``None``, used default_dataset_root | |||||
| :param train: if ``True``, load the train split, otherwise load the validation split | |||||
| """ | |||||
| # process the root path | |||||
| if root is None: | |||||
| self.root = self._default_root | |||||
| else: | |||||
| self.root = root | |||||
| if not os.path.exists(self.root): | |||||
| raise FileNotFoundError("dir %s does not exist" % self.root) | |||||
| self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) | |||||
| if not os.path.exists(self.devkit_dir): | |||||
| logger.warning("devkit directory %s does not exists", self.devkit_dir) | |||||
| self._prepare_devkit() | |||||
| self.train = train | |||||
| if train: | |||||
| self.target_folder = os.path.join(self.root, self.default_train_dir) | |||||
| else: | |||||
| self.target_folder = os.path.join(self.root, self.default_val_dir) | |||||
| if not os.path.exists(self.target_folder): | |||||
| logger.warning( | |||||
| "expected image folder %s does not exist, try to load from raw file", | |||||
| self.target_folder, | |||||
| ) | |||||
| if not self.check_raw_file(): | |||||
| raise FileNotFoundError( | |||||
| "expected image folder %s does not exist, and raw files do not exist in %s" | |||||
| % (self.target_folder, self.root) | |||||
| ) | |||||
| elif is_distributed(): | |||||
| raise RuntimeError( | |||||
| "extracting raw file shouldn't be done in distributed mode, use single process instead" | |||||
| ) | |||||
| elif train: | |||||
| self._prepare_train() | |||||
| else: | |||||
| self._prepare_val() | |||||
| super().__init__(self.target_folder, **kwargs) | |||||
| @property | |||||
| def _default_root(self): | |||||
| return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||||
| @property | |||||
| def valid_ground_truth(self): | |||||
| groud_truth_path = os.path.join( | |||||
| self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt" | |||||
| ) | |||||
| if os.path.exists(groud_truth_path): | |||||
| with open(groud_truth_path, "r") as f: | |||||
| val_labels = f.readlines() | |||||
| return [int(val_label) for val_label in val_labels] | |||||
| else: | |||||
| raise FileNotFoundError( | |||||
| "valid ground truth file %s does not exist" % groud_truth_path | |||||
| ) | |||||
| @property | |||||
| def meta(self): | |||||
| try: | |||||
| return load(os.path.join(self.devkit_dir, "meta.pkl")) | |||||
| except FileNotFoundError: | |||||
| import scipy.io | |||||
| meta_path = os.path.join(self.devkit_dir, "data", "meta.mat") | |||||
| if not os.path.exists(meta_path): | |||||
| raise FileNotFoundError("meta file %s does not exist" % meta_path) | |||||
| meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"] | |||||
| nums_children = list(zip(*meta))[4] | |||||
| meta = [ | |||||
| meta[idx] | |||||
| for idx, num_children in enumerate(nums_children) | |||||
| if num_children == 0 | |||||
| ] | |||||
| idcs, wnids, classes = list(zip(*meta))[:3] | |||||
| classes = [tuple(clss.split(", ")) for clss in classes] | |||||
| idx_to_wnid = dict(zip(idcs, wnids)) | |||||
| wnid_to_classes = dict(zip(wnids, classes)) | |||||
| logger.info( | |||||
| "saving cached meta file to %s", | |||||
| os.path.join(self.devkit_dir, "meta.pkl"), | |||||
| ) | |||||
| save( | |||||
| (idx_to_wnid, wnid_to_classes), | |||||
| os.path.join(self.devkit_dir, "meta.pkl"), | |||||
| ) | |||||
| return idx_to_wnid, wnid_to_classes | |||||
| def check_raw_file(self) -> bool: | |||||
| return all( | |||||
| [ | |||||
| os.path.exists(os.path.join(self.root, value[0])) | |||||
| for _, value in self.raw_file_meta.items() | |||||
| ] | |||||
| ) | |||||
| def _organize_val_data(self): | |||||
| id2wnid = self.meta[0] | |||||
| val_idcs = self.valid_ground_truth | |||||
| val_wnids = [id2wnid[idx] for idx in val_idcs] | |||||
| val_images = sorted( | |||||
| [ | |||||
| os.path.join(self.target_folder, image) | |||||
| for image in os.listdir(self.target_folder) | |||||
| ] | |||||
| ) | |||||
| logger.debug("mkdir for val set wnids") | |||||
| for wnid in set(val_wnids): | |||||
| os.makedirs(os.path.join(self.root, self.default_val_dir, wnid)) | |||||
| logger.debug("mv val images into wnids dir") | |||||
| for wnid, img_file in tqdm(zip(val_wnids, val_images)): | |||||
| shutil.move( | |||||
| img_file, | |||||
| os.path.join( | |||||
| self.root, self.default_val_dir, wnid, os.path.basename(img_file) | |||||
| ), | |||||
| ) | |||||
| def _prepare_val(self): | |||||
| assert not self.train | |||||
| raw_filename, checksum = self.raw_file_meta["val"] | |||||
| raw_file = os.path.join(self.root, raw_filename) | |||||
| logger.info("checksum valid tar file %s ...", raw_file) | |||||
| assert ( | |||||
| calculate_md5(raw_file) == checksum | |||||
| ), "checksum mismatch, {} may be damaged".format(raw_file) | |||||
| logger.info("extract valid tar file... this may take 10-20 minutes") | |||||
| untar(os.path.join(self.root, raw_file), self.target_folder) | |||||
| self._organize_val_data() | |||||
| def _prepare_train(self): | |||||
| assert self.train | |||||
| raw_filename, checksum = self.raw_file_meta["train"] | |||||
| raw_file = os.path.join(self.root, raw_filename) | |||||
| logger.info("checksum train tar file %s ...", raw_file) | |||||
| assert ( | |||||
| calculate_md5(raw_file) == checksum | |||||
| ), "checksum mismatch, {} may be damaged".format(raw_file) | |||||
| logger.info("extract train tar file.. this may take several hours") | |||||
| untar( | |||||
| os.path.join(self.root, raw_file), self.target_folder, | |||||
| ) | |||||
| paths = [ | |||||
| os.path.join(self.target_folder, child_dir) | |||||
| for child_dir in os.listdir(self.target_folder) | |||||
| ] | |||||
| for path in tqdm(paths): | |||||
| untar(path, os.path.splitext(path)[0], remove=True) | |||||
| def _prepare_devkit(self): | |||||
| raw_filename, checksum = self.raw_file_meta["devkit"] | |||||
| raw_file = os.path.join(self.root, raw_filename) | |||||
| logger.info("checksum devkit tar file %s ...", raw_file) | |||||
| assert ( | |||||
| calculate_md5(raw_file) == checksum | |||||
| ), "checksum mismatch, {} may be damaged".format(raw_file) | |||||
| logger.info("extract devkit file..") | |||||
| untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0])) | |||||
| @@ -0,0 +1,41 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections.abc | |||||
| import os | |||||
| from ..meta_dataset import MapDataset | |||||
| class VisionDataset(MapDataset): | |||||
| _repr_indent = 4 | |||||
| def __init__(self, root, *, order=None, supported_order=None): | |||||
| if isinstance(root, (str, bytes)): | |||||
| root = os.path.expanduser(root) | |||||
| self.root = root | |||||
| if order is None: | |||||
| order = ("image",) | |||||
| if not isinstance(order, collections.abc.Sequence): | |||||
| raise ValueError( | |||||
| "order should be a sequence, but got order={}".format(order) | |||||
| ) | |||||
| if supported_order is not None: | |||||
| assert isinstance(supported_order, collections.abc.Sequence) | |||||
| for k in order: | |||||
| if k not in supported_order: | |||||
| raise NotImplementedError("{} is unsupported data type".format(k)) | |||||
| self.order = order | |||||
| def __getitem__(self, index): | |||||
| raise NotImplementedError | |||||
| def __len__(self): | |||||
| raise NotImplementedError | |||||
| @@ -0,0 +1,197 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import gzip | |||||
| import os | |||||
| import struct | |||||
| from typing import Tuple | |||||
| import numpy as np | |||||
| from tqdm import tqdm | |||||
| from ....logger import get_logger | |||||
| from .meta_vision import VisionDataset | |||||
| from .utils import _default_dataset_root, load_raw_data_from_url | |||||
| logger = get_logger(__name__) | |||||
| class MNIST(VisionDataset): | |||||
| r""" ``Dataset`` for MNIST meta data | |||||
| """ | |||||
| url_path = "http://yann.lecun.com/exdb/mnist/" | |||||
| """ | |||||
| url prefix for downloading raw file | |||||
| """ | |||||
| raw_file_name = [ | |||||
| "train-images-idx3-ubyte.gz", | |||||
| "train-labels-idx1-ubyte.gz", | |||||
| "t10k-images-idx3-ubyte.gz", | |||||
| "t10k-labels-idx1-ubyte.gz", | |||||
| ] | |||||
| """ | |||||
| raw file names of both training set and test set (10k) | |||||
| """ | |||||
| raw_file_md5 = [ | |||||
| "f68b3c2dcbeaaa9fbdd348bbdeb94873", | |||||
| "d53e105ee54ea40749a09fcbcd1e9432", | |||||
| "9fb629c4189551a2d022fa330f9573f3", | |||||
| "ec29112dd5afa0611ce80d1b7f02629c", | |||||
| ] | |||||
| """ | |||||
| md5 for checking raw files | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| root: str = None, | |||||
| train: bool = True, | |||||
| download: bool = True, | |||||
| timeout: int = 500, | |||||
| ): | |||||
| r""" | |||||
| :param root: path for mnist dataset downloading or loading, if ``None``, | |||||
| set ``root`` to the ``_default_root`` | |||||
| :param train: if ``True``, loading trainingset, else loading test set | |||||
| :param download: if raw files do not exists and download sets to ``True``, | |||||
| download raw files and process, otherwise raise ValueError, default is True | |||||
| """ | |||||
| super().__init__(root, order=("image", "image_category")) | |||||
| self.timeout = timeout | |||||
| # process the root path | |||||
| if root is None: | |||||
| self.root = self._default_root | |||||
| if not os.path.exists(self.root): | |||||
| os.makedirs(self.root) | |||||
| else: | |||||
| self.root = root | |||||
| if not os.path.exists(self.root): | |||||
| if download: | |||||
| logger.debug( | |||||
| "dir %s does not exist, will be automatically created", | |||||
| self.root, | |||||
| ) | |||||
| os.makedirs(self.root) | |||||
| else: | |||||
| raise ValueError("dir %s does not exist" % self.root) | |||||
| if self._check_raw_files(): | |||||
| self.process(train) | |||||
| elif download: | |||||
| self.download() | |||||
| self.process(train) | |||||
| else: | |||||
| raise ValueError( | |||||
| "root does not contain valid raw files, please set download=True" | |||||
| ) | |||||
| def __getitem__(self, index: int) -> Tuple: | |||||
| return tuple(array[index] for array in self.arrays) | |||||
| def __len__(self) -> int: | |||||
| return len(self.arrays[0]) | |||||
| @property | |||||
| def _default_root(self): | |||||
| return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||||
| @property | |||||
| def meta(self): | |||||
| return self._meta_data | |||||
| def _check_raw_files(self): | |||||
| return all( | |||||
| [ | |||||
| os.path.exists(os.path.join(self.root, path)) | |||||
| for path in self.raw_file_name | |||||
| ] | |||||
| ) | |||||
| def download(self): | |||||
| for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): | |||||
| url = self.url_path + file_name | |||||
| load_raw_data_from_url(url, file_name, md5, self.root, self.timeout) | |||||
| def process(self, train): | |||||
| # load raw files and transform them into meta data and datasets Tuple(np.array) | |||||
| logger.info("process the raw files of %s set...", "train" if train else "test") | |||||
| if train: | |||||
| meta_data_images, images = parse_idx3( | |||||
| os.path.join(self.root, self.raw_file_name[0]) | |||||
| ) | |||||
| meta_data_labels, labels = parse_idx1( | |||||
| os.path.join(self.root, self.raw_file_name[1]) | |||||
| ) | |||||
| else: | |||||
| meta_data_images, images = parse_idx3( | |||||
| os.path.join(self.root, self.raw_file_name[2]) | |||||
| ) | |||||
| meta_data_labels, labels = parse_idx1( | |||||
| os.path.join(self.root, self.raw_file_name[3]) | |||||
| ) | |||||
| self._meta_data = { | |||||
| "images": meta_data_images, | |||||
| "labels": meta_data_labels, | |||||
| } | |||||
| self.arrays = (images, labels.astype(np.int32)) | |||||
| def parse_idx3(idx3_file): | |||||
| # parse idx3 file to meta data and data in numpy array (images) | |||||
| logger.debug("parse idx3 file %s ...", idx3_file) | |||||
| assert idx3_file.endswith(".gz") | |||||
| with gzip.open(idx3_file, "rb") as f: | |||||
| bin_data = f.read() | |||||
| # parse meta data | |||||
| offset = 0 | |||||
| fmt_header = ">iiii" | |||||
| magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset) | |||||
| meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width} | |||||
| # parse images | |||||
| image_size = height * width | |||||
| offset += struct.calcsize(fmt_header) | |||||
| fmt_image = ">" + str(image_size) + "B" | |||||
| images = [] | |||||
| bar = tqdm(total=meta_data["imgs"], ncols=80) | |||||
| for image in struct.iter_unpack(fmt_image, bin_data[offset:]): | |||||
| images.append(np.array(image, dtype=np.uint8).reshape((height, width, 1))) | |||||
| bar.update() | |||||
| bar.close() | |||||
| return meta_data, images | |||||
| def parse_idx1(idx1_file): | |||||
| # parse idx1 file to meta data and data in numpy array (labels) | |||||
| logger.debug("parse idx1 file %s ...", idx1_file) | |||||
| assert idx1_file.endswith(".gz") | |||||
| with gzip.open(idx1_file, "rb") as f: | |||||
| bin_data = f.read() | |||||
| # parse meta data | |||||
| offset = 0 | |||||
| fmt_header = ">ii" | |||||
| magic, imgs = struct.unpack_from(fmt_header, bin_data, offset) | |||||
| meta_data = {"magic": magic, "imgs": imgs} | |||||
| # parse labels | |||||
| offset += struct.calcsize(fmt_header) | |||||
| fmt_image = ">B" | |||||
| labels = np.empty(imgs, dtype=int) | |||||
| bar = tqdm(total=meta_data["imgs"], ncols=80) | |||||
| for i, label in enumerate(struct.iter_unpack(fmt_image, bin_data[offset:])): | |||||
| labels[i] = label[0] | |||||
| bar.update() | |||||
| bar.close() | |||||
| return meta_data, labels | |||||
| @@ -0,0 +1,498 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # --------------------------------------------------------------------- | |||||
| # Part of the following code in this file refs to maskrcnn-benchmark | |||||
| # MIT License | |||||
| # | |||||
| # Copyright (c) 2018 Facebook | |||||
| # --------------------------------------------------------------------- | |||||
| import json | |||||
| import os | |||||
| from collections import defaultdict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from .meta_vision import VisionDataset | |||||
| class Objects365(VisionDataset): | |||||
| r"""`Objects365 <https://www.objects365.org/overview.html>`_ Dataset. | |||||
| """ | |||||
| supported_order = ( | |||||
| "image", | |||||
| "boxes", | |||||
| "boxes_category", | |||||
| "info", | |||||
| ) | |||||
| def __init__( | |||||
| self, root, ann_file, remove_images_without_annotations=False, *, order=None | |||||
| ): | |||||
| super().__init__(root, order=order, supported_order=self.supported_order) | |||||
| with open(ann_file, "r") as f: | |||||
| dataset = json.load(f) | |||||
| self.imgs = dict() | |||||
| for img in dataset["images"]: | |||||
| self.imgs[img["id"]] = img | |||||
| self.img_to_anns = defaultdict(list) | |||||
| for ann in dataset["annotations"]: | |||||
| # for saving memory | |||||
| if ( | |||||
| "boxes" not in self.order | |||||
| and "boxes_category" not in self.order | |||||
| and "bbox" in ann | |||||
| ): | |||||
| del ann["bbox"] | |||||
| self.img_to_anns[ann["image_id"]].append(ann) | |||||
| self.cats = dict() | |||||
| for cat in dataset["categories"]: | |||||
| self.cats[cat["id"]] = cat | |||||
| self.ids = list(sorted(self.imgs.keys())) | |||||
| # filter images without detection annotations | |||||
| if remove_images_without_annotations: | |||||
| ids = [] | |||||
| for img_id in self.ids: | |||||
| anno = self.img_to_anns[img_id] | |||||
| # filter crowd annotations | |||||
| anno = [obj for obj in anno if obj["iscrowd"] == 0] | |||||
| anno = [ | |||||
| obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0 | |||||
| ] | |||||
| if len(anno) > 0: | |||||
| ids.append(img_id) | |||||
| self.img_to_anns[img_id] = anno | |||||
| else: | |||||
| del self.imgs[img_id] | |||||
| del self.img_to_anns[img_id] | |||||
| self.ids = ids | |||||
| self.json_category_id_to_contiguous_id = { | |||||
| v: i + 1 for i, v in enumerate(self.cats.keys()) | |||||
| } | |||||
| self.contiguous_category_id_to_json_id = { | |||||
| v: k for k, v in self.json_category_id_to_contiguous_id.items() | |||||
| } | |||||
| def __getitem__(self, index): | |||||
| img_id = self.ids[index] | |||||
| anno = self.img_to_anns[img_id] | |||||
| target = [] | |||||
| for k in self.order: | |||||
| if k == "image": | |||||
| file_name = self.imgs[img_id]["file_name"] | |||||
| path = os.path.join(self.root, file_name) | |||||
| image = cv2.imread(path, cv2.IMREAD_COLOR) | |||||
| target.append(image) | |||||
| elif k == "boxes": | |||||
| boxes = [obj["bbox"] for obj in anno] | |||||
| boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||||
| # transfer boxes from xywh to xyxy | |||||
| boxes[:, 2:] += boxes[:, :2] | |||||
| target.append(boxes) | |||||
| elif k == "boxes_category": | |||||
| boxes_category = [obj["category_id"] for obj in anno] | |||||
| boxes_category = [ | |||||
| self.json_category_id_to_contiguous_id[c] for c in boxes_category | |||||
| ] | |||||
| boxes_category = np.array(boxes_category, dtype=np.int32) | |||||
| target.append(boxes_category) | |||||
| elif k == "info": | |||||
| info = self.imgs[img_id] | |||||
| info = [info["height"], info["width"], info["file_name"]] | |||||
| target.append(info) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| return tuple(target) | |||||
| def __len__(self): | |||||
| return len(self.ids) | |||||
| def get_img_info(self, index): | |||||
| img_id = self.ids[index] | |||||
| img_info = self.imgs[img_id] | |||||
| return img_info | |||||
| class_names = ( | |||||
| "person", | |||||
| "sneakers", | |||||
| "chair", | |||||
| "hat", | |||||
| "lamp", | |||||
| "bottle", | |||||
| "cabinet/shelf", | |||||
| "cup", | |||||
| "car", | |||||
| "glasses", | |||||
| "picture/frame", | |||||
| "desk", | |||||
| "handbag", | |||||
| "street lights", | |||||
| "book", | |||||
| "plate", | |||||
| "helmet", | |||||
| "leather shoes", | |||||
| "pillow", | |||||
| "glove", | |||||
| "potted plant", | |||||
| "bracelet", | |||||
| "flower", | |||||
| "tv", | |||||
| "storage box", | |||||
| "vase", | |||||
| "bench", | |||||
| "wine glass", | |||||
| "boots", | |||||
| "bowl", | |||||
| "dining table", | |||||
| "umbrella", | |||||
| "boat", | |||||
| "flag", | |||||
| "speaker", | |||||
| "trash bin/can", | |||||
| "stool", | |||||
| "backpack", | |||||
| "couch", | |||||
| "belt", | |||||
| "carpet", | |||||
| "basket", | |||||
| "towel/napkin", | |||||
| "slippers", | |||||
| "barrel/bucket", | |||||
| "coffee table", | |||||
| "suv", | |||||
| "toy", | |||||
| "tie", | |||||
| "bed", | |||||
| "traffic light", | |||||
| "pen/pencil", | |||||
| "microphone", | |||||
| "sandals", | |||||
| "canned", | |||||
| "necklace", | |||||
| "mirror", | |||||
| "faucet", | |||||
| "bicycle", | |||||
| "bread", | |||||
| "high heels", | |||||
| "ring", | |||||
| "van", | |||||
| "watch", | |||||
| "sink", | |||||
| "horse", | |||||
| "fish", | |||||
| "apple", | |||||
| "camera", | |||||
| "candle", | |||||
| "teddy bear", | |||||
| "cake", | |||||
| "motorcycle", | |||||
| "wild bird", | |||||
| "laptop", | |||||
| "knife", | |||||
| "traffic sign", | |||||
| "cell phone", | |||||
| "paddle", | |||||
| "truck", | |||||
| "cow", | |||||
| "power outlet", | |||||
| "clock", | |||||
| "drum", | |||||
| "fork", | |||||
| "bus", | |||||
| "hanger", | |||||
| "nightstand", | |||||
| "pot/pan", | |||||
| "sheep", | |||||
| "guitar", | |||||
| "traffic cone", | |||||
| "tea pot", | |||||
| "keyboard", | |||||
| "tripod", | |||||
| "hockey", | |||||
| "fan", | |||||
| "dog", | |||||
| "spoon", | |||||
| "blackboard/whiteboard", | |||||
| "balloon", | |||||
| "air conditioner", | |||||
| "cymbal", | |||||
| "mouse", | |||||
| "telephone", | |||||
| "pickup truck", | |||||
| "orange", | |||||
| "banana", | |||||
| "airplane", | |||||
| "luggage", | |||||
| "skis", | |||||
| "soccer", | |||||
| "trolley", | |||||
| "oven", | |||||
| "remote", | |||||
| "baseball glove", | |||||
| "paper towel", | |||||
| "refrigerator", | |||||
| "train", | |||||
| "tomato", | |||||
| "machinery vehicle", | |||||
| "tent", | |||||
| "shampoo/shower gel", | |||||
| "head phone", | |||||
| "lantern", | |||||
| "donut", | |||||
| "cleaning products", | |||||
| "sailboat", | |||||
| "tangerine", | |||||
| "pizza", | |||||
| "kite", | |||||
| "computer box", | |||||
| "elephant", | |||||
| "toiletries", | |||||
| "gas stove", | |||||
| "broccoli", | |||||
| "toilet", | |||||
| "stroller", | |||||
| "shovel", | |||||
| "baseball bat", | |||||
| "microwave", | |||||
| "skateboard", | |||||
| "surfboard", | |||||
| "surveillance camera", | |||||
| "gun", | |||||
| "life saver", | |||||
| "cat", | |||||
| "lemon", | |||||
| "liquid soap", | |||||
| "zebra", | |||||
| "duck", | |||||
| "sports car", | |||||
| "giraffe", | |||||
| "pumpkin", | |||||
| "piano", | |||||
| "stop sign", | |||||
| "radiator", | |||||
| "converter", | |||||
| "tissue ", | |||||
| "carrot", | |||||
| "washing machine", | |||||
| "vent", | |||||
| "cookies", | |||||
| "cutting/chopping board", | |||||
| "tennis racket", | |||||
| "candy", | |||||
| "skating and skiing shoes", | |||||
| "scissors", | |||||
| "folder", | |||||
| "baseball", | |||||
| "strawberry", | |||||
| "bow tie", | |||||
| "pigeon", | |||||
| "pepper", | |||||
| "coffee machine", | |||||
| "bathtub", | |||||
| "snowboard", | |||||
| "suitcase", | |||||
| "grapes", | |||||
| "ladder", | |||||
| "pear", | |||||
| "american football", | |||||
| "basketball", | |||||
| "potato", | |||||
| "paint brush", | |||||
| "printer", | |||||
| "billiards", | |||||
| "fire hydrant", | |||||
| "goose", | |||||
| "projector", | |||||
| "sausage", | |||||
| "fire extinguisher", | |||||
| "extension cord", | |||||
| "facial mask", | |||||
| "tennis ball", | |||||
| "chopsticks", | |||||
| "electronic stove and gas stove", | |||||
| "pie", | |||||
| "frisbee", | |||||
| "kettle", | |||||
| "hamburger", | |||||
| "golf club", | |||||
| "cucumber", | |||||
| "clutch", | |||||
| "blender", | |||||
| "tong", | |||||
| "slide", | |||||
| "hot dog", | |||||
| "toothbrush", | |||||
| "facial cleanser", | |||||
| "mango", | |||||
| "deer", | |||||
| "egg", | |||||
| "violin", | |||||
| "marker", | |||||
| "ship", | |||||
| "chicken", | |||||
| "onion", | |||||
| "ice cream", | |||||
| "tape", | |||||
| "wheelchair", | |||||
| "plum", | |||||
| "bar soap", | |||||
| "scale", | |||||
| "watermelon", | |||||
| "cabbage", | |||||
| "router/modem", | |||||
| "golf ball", | |||||
| "pine apple", | |||||
| "crane", | |||||
| "fire truck", | |||||
| "peach", | |||||
| "cello", | |||||
| "notepaper", | |||||
| "tricycle", | |||||
| "toaster", | |||||
| "helicopter", | |||||
| "green beans", | |||||
| "brush", | |||||
| "carriage", | |||||
| "cigar", | |||||
| "earphone", | |||||
| "penguin", | |||||
| "hurdle", | |||||
| "swing", | |||||
| "radio", | |||||
| "CD", | |||||
| "parking meter", | |||||
| "swan", | |||||
| "garlic", | |||||
| "french fries", | |||||
| "horn", | |||||
| "avocado", | |||||
| "saxophone", | |||||
| "trumpet", | |||||
| "sandwich", | |||||
| "cue", | |||||
| "kiwi fruit", | |||||
| "bear", | |||||
| "fishing rod", | |||||
| "cherry", | |||||
| "tablet", | |||||
| "green vegetables", | |||||
| "nuts", | |||||
| "corn", | |||||
| "key", | |||||
| "screwdriver", | |||||
| "globe", | |||||
| "broom", | |||||
| "pliers", | |||||
| "volleyball", | |||||
| "hammer", | |||||
| "eggplant", | |||||
| "trophy", | |||||
| "dates", | |||||
| "board eraser", | |||||
| "rice", | |||||
| "tape measure/ruler", | |||||
| "dumbbell", | |||||
| "hamimelon", | |||||
| "stapler", | |||||
| "camel", | |||||
| "lettuce", | |||||
| "goldfish", | |||||
| "meat balls", | |||||
| "medal", | |||||
| "toothpaste", | |||||
| "antelope", | |||||
| "shrimp", | |||||
| "rickshaw", | |||||
| "trombone", | |||||
| "pomegranate", | |||||
| "coconut", | |||||
| "jellyfish", | |||||
| "mushroom", | |||||
| "calculator", | |||||
| "treadmill", | |||||
| "butterfly", | |||||
| "egg tart", | |||||
| "cheese", | |||||
| "pig", | |||||
| "pomelo", | |||||
| "race car", | |||||
| "rice cooker", | |||||
| "tuba", | |||||
| "crosswalk sign", | |||||
| "papaya", | |||||
| "hair drier", | |||||
| "green onion", | |||||
| "chips", | |||||
| "dolphin", | |||||
| "sushi", | |||||
| "urinal", | |||||
| "donkey", | |||||
| "electric drill", | |||||
| "spring rolls", | |||||
| "tortoise/turtle", | |||||
| "parrot", | |||||
| "flute", | |||||
| "measuring cup", | |||||
| "shark", | |||||
| "steak", | |||||
| "poker card", | |||||
| "binoculars", | |||||
| "llama", | |||||
| "radish", | |||||
| "noodles", | |||||
| "yak", | |||||
| "mop", | |||||
| "crab", | |||||
| "microscope", | |||||
| "barbell", | |||||
| "bread/bun", | |||||
| "baozi", | |||||
| "lion", | |||||
| "red cabbage", | |||||
| "polar bear", | |||||
| "lighter", | |||||
| "seal", | |||||
| "mangosteen", | |||||
| "comb", | |||||
| "eraser", | |||||
| "pitaya", | |||||
| "scallop", | |||||
| "pencil case", | |||||
| "saw", | |||||
| "table tennis paddle", | |||||
| "okra", | |||||
| "starfish", | |||||
| "eagle", | |||||
| "monkey", | |||||
| "durian", | |||||
| "game board", | |||||
| "rabbit", | |||||
| "french horn", | |||||
| "ambulance", | |||||
| "asparagus", | |||||
| "hoverboard", | |||||
| "pasta", | |||||
| "target", | |||||
| "hotair balloon", | |||||
| "chainsaw", | |||||
| "lobster", | |||||
| "iron", | |||||
| "flashlight", | |||||
| ) | |||||
| @@ -0,0 +1,89 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import hashlib | |||||
| import os | |||||
| import tarfile | |||||
| from ....distributed.group import is_distributed | |||||
| from ....logger import get_logger | |||||
| from ....utils.http_download import download_from_url | |||||
| IMG_EXT = (".jpg", ".png", ".jpeg", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") | |||||
| logger = get_logger(__name__) | |||||
| def _default_dataset_root(): | |||||
| default_dataset_root = os.path.expanduser( | |||||
| os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "megengine") | |||||
| ) | |||||
| return default_dataset_root | |||||
| def load_raw_data_from_url( | |||||
| url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int | |||||
| ): | |||||
| cached_file = os.path.join(raw_data_dir, filename) | |||||
| logger.debug( | |||||
| "load_raw_data_from_url: downloading to or using cached %s ...", cached_file | |||||
| ) | |||||
| if not os.path.exists(cached_file): | |||||
| if is_distributed(): | |||||
| logger.warning( | |||||
| "Downloading raw data in DISTRIBUTED mode\n" | |||||
| " File may be downloaded multiple times. We recommend\n" | |||||
| " users to download in single process first." | |||||
| ) | |||||
| md5 = download_from_url(url, cached_file, http_read_timeout=timeout) | |||||
| else: | |||||
| md5 = calculate_md5(cached_file) | |||||
| if target_md5 == md5: | |||||
| logger.debug("%s exists with correct md5: %s", filename, target_md5) | |||||
| else: | |||||
| os.remove(cached_file) | |||||
| raise RuntimeError("{} exists but fail to match md5".format(filename)) | |||||
| def calculate_md5(filename): | |||||
| m = hashlib.md5() | |||||
| with open(filename, "rb") as f: | |||||
| while True: | |||||
| data = f.read(4096) | |||||
| if not data: | |||||
| break | |||||
| m.update(data) | |||||
| return m.hexdigest() | |||||
| def is_img(filename): | |||||
| return filename.lower().endswith(IMG_EXT) | |||||
| def untar(path, to=None, remove=False): | |||||
| if to is None: | |||||
| to = os.path.dirname(path) | |||||
| with tarfile.open(path, "r") as tar: | |||||
| tar.extractall(path=to) | |||||
| if remove: | |||||
| os.remove(path) | |||||
| def untargz(path, to=None, remove=False): | |||||
| if path.endswith(".tar.gz"): | |||||
| if to is None: | |||||
| to = os.path.dirname(path) | |||||
| with tarfile.open(path, "r:gz") as tar: | |||||
| tar.extractall(path=to) | |||||
| else: | |||||
| raise ValueError("path %s does not end with .tar" % path) | |||||
| if remove: | |||||
| os.remove(path) | |||||
| @@ -0,0 +1,195 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # --------------------------------------------------------------------- | |||||
| # Part of the following code in this file refs to torchvision | |||||
| # BSD 3-Clause License | |||||
| # | |||||
| # Copyright (c) Soumith Chintala 2016, | |||||
| # All rights reserved. | |||||
| # --------------------------------------------------------------------- | |||||
| import collections.abc | |||||
| import os | |||||
| import xml.etree.ElementTree as ET | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from .meta_vision import VisionDataset | |||||
| class PascalVOC(VisionDataset): | |||||
| r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset. | |||||
| """ | |||||
| supported_order = ( | |||||
| "image", | |||||
| "boxes", | |||||
| "boxes_category", | |||||
| "mask", | |||||
| "info", | |||||
| ) | |||||
| def __init__(self, root, image_set, *, order=None): | |||||
| if ("boxes" in order or "boxes_category" in order) and "mask" in order: | |||||
| raise ValueError( | |||||
| "PascalVOC only supports boxes & boxes_category or mask, not both." | |||||
| ) | |||||
| super().__init__(root, order=order, supported_order=self.supported_order) | |||||
| if not os.path.isdir(self.root): | |||||
| raise RuntimeError("Dataset not found or corrupted.") | |||||
| self.image_set = image_set | |||||
| image_dir = os.path.join(self.root, "JPEGImages") | |||||
| if "boxes" in order or "boxes_category" in order: | |||||
| annotation_dir = os.path.join(self.root, "Annotations") | |||||
| splitdet_dir = os.path.join(self.root, "ImageSets/Main") | |||||
| split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt") | |||||
| with open(os.path.join(split_f), "r") as f: | |||||
| self.file_names = [x.strip() for x in f.readlines()] | |||||
| self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names] | |||||
| self.annotations = [ | |||||
| os.path.join(annotation_dir, x + ".xml") for x in self.file_names | |||||
| ] | |||||
| assert len(self.images) == len(self.annotations) | |||||
| elif "mask" in order: | |||||
| if "aug" in image_set: | |||||
| mask_dir = os.path.join(self.root, "SegmentationClass_aug") | |||||
| else: | |||||
| mask_dir = os.path.join(self.root, "SegmentationClass") | |||||
| splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation") | |||||
| split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt") | |||||
| with open(os.path.join(split_f), "r") as f: | |||||
| self.file_names = [x.strip() for x in f.readlines()] | |||||
| self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names] | |||||
| self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names] | |||||
| assert len(self.images) == len(self.masks) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| def __getitem__(self, index): | |||||
| target = [] | |||||
| for k in self.order: | |||||
| if k == "image": | |||||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||||
| target.append(image) | |||||
| elif k == "boxes": | |||||
| anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot()) | |||||
| boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]] | |||||
| # boxes type xyxy | |||||
| boxes = [ | |||||
| (bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"]) for bb in boxes | |||||
| ] | |||||
| boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||||
| target.append(boxes) | |||||
| elif k == "boxes_category": | |||||
| anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot()) | |||||
| boxes_category = [obj["name"] for obj in anno["annotation"]["object"]] | |||||
| boxes_category = [ | |||||
| self.class_names.index(bc) + 1 for bc in boxes_category | |||||
| ] | |||||
| boxes_category = np.array(boxes_category, dtype=np.int32) | |||||
| target.append(boxes_category) | |||||
| elif k == "mask": | |||||
| if "aug" in self.image_set: | |||||
| mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE) | |||||
| else: | |||||
| mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR) | |||||
| mask = self._trans_mask(mask) | |||||
| mask = mask[:, :, np.newaxis] | |||||
| target.append(mask) | |||||
| elif k == "info": | |||||
| if image is None: | |||||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||||
| info = [image.shape[0], image.shape[1], self.file_names[index]] | |||||
| target.append(info) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| return tuple(target) | |||||
| def __len__(self): | |||||
| return len(self.images) | |||||
| def _trans_mask(self, mask): | |||||
| label = np.ones(mask.shape[:2]) * 255 | |||||
| for i in range(len(self.class_colors)): | |||||
| b, g, r = self.class_colors[i] | |||||
| label[ | |||||
| (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r) | |||||
| ] = i | |||||
| return label.astype(np.uint8) | |||||
| def parse_voc_xml(self, node): | |||||
| voc_dict = {} | |||||
| children = list(node) | |||||
| if children: | |||||
| def_dic = collections.defaultdict(list) | |||||
| for dc in map(self.parse_voc_xml, children): | |||||
| for ind, v in dc.items(): | |||||
| def_dic[ind].append(v) | |||||
| if node.tag == "annotation": | |||||
| def_dic["object"] = [def_dic["object"]] | |||||
| voc_dict = { | |||||
| node.tag: { | |||||
| ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items() | |||||
| } | |||||
| } | |||||
| if node.text: | |||||
| text = node.text.strip() | |||||
| if not children: | |||||
| voc_dict[node.tag] = text | |||||
| return voc_dict | |||||
| class_names = ( | |||||
| "aeroplane", | |||||
| "bicycle", | |||||
| "bird", | |||||
| "boat", | |||||
| "bottle", | |||||
| "bus", | |||||
| "car", | |||||
| "cat", | |||||
| "chair", | |||||
| "cow", | |||||
| "diningtable", | |||||
| "dog", | |||||
| "horse", | |||||
| "motorbike", | |||||
| "person", | |||||
| "pottedplant", | |||||
| "sheep", | |||||
| "sofa", | |||||
| "train", | |||||
| "tvmonitor", | |||||
| ) | |||||
| class_colors = [ | |||||
| [0, 0, 128], | |||||
| [0, 128, 0], | |||||
| [0, 128, 128], | |||||
| [128, 0, 0], | |||||
| [128, 0, 128], | |||||
| [128, 128, 0], | |||||
| [128, 128, 128], | |||||
| [0, 0, 64], | |||||
| [0, 0, 192], | |||||
| [0, 128, 64], | |||||
| [0, 128, 192], | |||||
| [128, 0, 64], | |||||
| [128, 0, 192], | |||||
| [128, 128, 64], | |||||
| [128, 128, 192], | |||||
| [0, 64, 0], | |||||
| [0, 64, 128], | |||||
| [0, 192, 0], | |||||
| [0, 192, 128], | |||||
| [128, 64, 0], | |||||
| ] | |||||
| @@ -0,0 +1,274 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections.abc | |||||
| import math | |||||
| from abc import ABC | |||||
| from typing import Any, Generator, Iterator, List, Union | |||||
| import numpy as np | |||||
| import megengine.distributed as dist | |||||
| class Sampler(ABC): | |||||
| def __init__( | |||||
| self, | |||||
| dataset, | |||||
| batch_size=1, | |||||
| drop_last=False, | |||||
| num_samples=None, | |||||
| world_size=None, | |||||
| rank=None, | |||||
| seed=None, | |||||
| ): | |||||
| r""" | |||||
| An abstract class for all sampler | |||||
| :type dataset: `dataset` | |||||
| :param dataset: dataset to sample from | |||||
| :type batch_size: positive integer | |||||
| :param batch_size: batch size for batch method | |||||
| :type drop_last: bool | |||||
| :param drop_last: set ``True`` to drop the last incomplete batch, | |||||
| if the dataset size is not divisible by the batch size. If ``False`` and | |||||
| the size of dataset is not divisible by the batch_size, then the last batch will | |||||
| be smaller. (default: ``False``) | |||||
| :type num_samples: positive integer | |||||
| :param num_samples: number of samples assigned to one rank | |||||
| :type world_size: positive integer | |||||
| :param world_size: number of ranks | |||||
| :type rank: non-negative integer within 0 and world_size | |||||
| :param rank: rank id, non-negative interger within 0 and ``world_size`` | |||||
| :type seed: non-negative integer | |||||
| :param seed: seed for random operators | |||||
| """ | |||||
| if ( | |||||
| not isinstance(batch_size, int) | |||||
| or isinstance(batch_size, bool) | |||||
| or batch_size <= 0 | |||||
| ): | |||||
| raise ValueError( | |||||
| "batch_size should be a positive integer value, " | |||||
| "but got batch_size={}".format(batch_size) | |||||
| ) | |||||
| if not isinstance(drop_last, bool): | |||||
| raise ValueError( | |||||
| "drop_last should be a boolean value, but got " | |||||
| "drop_last={}".format(drop_last) | |||||
| ) | |||||
| if num_samples is not None and ( | |||||
| not isinstance(num_samples, int) | |||||
| or isinstance(num_samples, bool) | |||||
| or num_samples <= 0 | |||||
| ): | |||||
| raise ValueError( | |||||
| "num_samples should be a positive integer " | |||||
| "value, but got num_samples={}".format(num_samples) | |||||
| ) | |||||
| self.batch_size = batch_size | |||||
| self.dataset = dataset | |||||
| self.drop_last = drop_last | |||||
| if world_size is None: | |||||
| world_size = dist.get_world_size() if dist.is_distributed() else 1 | |||||
| self.world_size = world_size | |||||
| if rank is None: | |||||
| rank = dist.get_rank() if dist.is_distributed() else 0 | |||||
| self.rank = rank | |||||
| if num_samples is None: | |||||
| num_samples = len(self.dataset) | |||||
| self.num_samples = int(math.ceil(num_samples / self.world_size)) | |||||
| # Make sure seeds are the same at each rank | |||||
| if seed is None and self.world_size > 1: | |||||
| seed = 0 | |||||
| self.rng = np.random.RandomState(seed) | |||||
| def __iter__(self) -> Union[Generator, Iterator]: | |||||
| return self.batch() | |||||
| def __len__(self) -> int: | |||||
| if self.drop_last: | |||||
| return self.num_samples // self.batch_size | |||||
| else: | |||||
| return int(math.ceil(self.num_samples / self.batch_size)) | |||||
| def sample(self): | |||||
| """ | |||||
| return a list contains all sample indices | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def scatter(self, indices) -> List: | |||||
| r""" | |||||
| scatter method is used for splitting indices into subset, each subset | |||||
| will be assigned to a rank. Indices are evenly splitted by default. | |||||
| If customized indices assignment method is needed, please rewrite this method | |||||
| """ | |||||
| total_size = self.num_samples * self.world_size | |||||
| # add extra indices to make it evenly divisible | |||||
| indices += indices[: (total_size - len(indices))] | |||||
| assert len(indices) == total_size | |||||
| # subsample | |||||
| indices = indices[self.rank : total_size : self.world_size] | |||||
| assert len(indices) == self.num_samples | |||||
| return indices | |||||
| def batch(self) -> Iterator[List[Any]]: | |||||
| r""" | |||||
| batch method provides a batch indices generator | |||||
| """ | |||||
| indices = list(self.sample()) | |||||
| # user might pass the world_size parameter without dist, | |||||
| # so dist.is_distributed() should not be used | |||||
| if self.world_size > 1: | |||||
| indices = self.scatter(indices) | |||||
| step, length = self.batch_size, len(indices) | |||||
| batch_index = [indices[i : i + step] for i in range(0, length, step)] | |||||
| if self.drop_last and len(batch_index[-1]) < self.batch_size: | |||||
| batch_index.pop() | |||||
| return iter(batch_index) | |||||
| class SequentialSampler(Sampler): | |||||
| def __init__( | |||||
| self, | |||||
| dataset, | |||||
| batch_size=1, | |||||
| drop_last=False, | |||||
| indices=None, | |||||
| world_size=None, | |||||
| rank=None, | |||||
| ): | |||||
| r""" | |||||
| Sample elements sequentially | |||||
| """ | |||||
| super().__init__(dataset, batch_size, drop_last, None, world_size, rank) | |||||
| if indices is not None and not isinstance(indices, collections.abc.Sequence): | |||||
| raise ValueError( | |||||
| "indices should be None or a sequence, " | |||||
| "but got indices={}".format(indices) | |||||
| ) | |||||
| self.indices = indices | |||||
| def sample(self) -> Iterator[Any]: | |||||
| r""" | |||||
| return a generator | |||||
| """ | |||||
| if self.indices is None: | |||||
| return iter(range(len(self.dataset))) | |||||
| else: | |||||
| return self.indices | |||||
| class RandomSampler(Sampler): | |||||
| def __init__( | |||||
| self, | |||||
| dataset, | |||||
| batch_size=1, | |||||
| drop_last=False, | |||||
| indices=None, | |||||
| world_size=None, | |||||
| rank=None, | |||||
| seed=None, | |||||
| ): | |||||
| r""" | |||||
| Sample elements randomly without replacement | |||||
| """ | |||||
| super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed) | |||||
| if indices is not None and not isinstance(indices, collections.abc.Sequence): | |||||
| raise ValueError( | |||||
| "indices should be None or a sequence, " | |||||
| "but got indices={}".format(indices) | |||||
| ) | |||||
| self.indices = indices | |||||
| def sample(self) -> List: | |||||
| if self.indices is None: | |||||
| return self.rng.permutation(len(self.dataset)).tolist() | |||||
| else: | |||||
| return self.rng.permutation(self.indices).tolist() | |||||
| class ReplacementSampler(Sampler): | |||||
| def __init__( | |||||
| self, | |||||
| dataset, | |||||
| batch_size=1, | |||||
| drop_last=False, | |||||
| num_samples=None, | |||||
| weights=None, | |||||
| world_size=None, | |||||
| rank=None, | |||||
| seed=None, | |||||
| ): | |||||
| r""" | |||||
| Sample elements randomly with replacement | |||||
| :type weights: List | |||||
| :param weights: weights for sampling indices, it could be unnormalized weights | |||||
| """ | |||||
| super().__init__( | |||||
| dataset, batch_size, drop_last, num_samples, world_size, rank, seed | |||||
| ) | |||||
| if weights is not None: | |||||
| if not isinstance(weights, collections.abc.Sequence): | |||||
| raise ValueError( | |||||
| "weights should be None or a sequence, " | |||||
| "but got weights={}".format(weights) | |||||
| ) | |||||
| if len(weights) != len(dataset): | |||||
| raise ValueError( | |||||
| "len(dataset)={} should be equal to" | |||||
| "len(weights)={}".format(len(dataset), len(weights)) | |||||
| ) | |||||
| self.weights = weights | |||||
| if self.weights is not None: | |||||
| self.weights = np.array(weights) / sum(weights) | |||||
| def sample(self) -> List: | |||||
| n = len(self.dataset) | |||||
| if self.weights is None: | |||||
| return self.rng.randint(n, size=self.num_samples).tolist() | |||||
| else: | |||||
| return self.rng.multinomial(n, self.weights, self.num_samples).tolist() | |||||
| class Infinite(Sampler): | |||||
| r"""Infinite Sampler warper for basic sampler""" | |||||
| def sample(self): | |||||
| raise NotImplementedError("sample method not supported in Infinite") | |||||
| def __init__(self, sampler): | |||||
| self.sampler = sampler | |||||
| self.sampler_iter = iter(self.sampler) | |||||
| def __iter__(self): | |||||
| return self | |||||
| def __next__(self): | |||||
| try: | |||||
| index = next(self.sampler_iter) | |||||
| except StopIteration: | |||||
| self.sampler_iter = iter(self.sampler) | |||||
| index = next(self.sampler_iter) | |||||
| return index | |||||
| def __len__(self): | |||||
| return np.iinfo(np.int64).max | |||||
| @@ -0,0 +1,10 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .meta_transform import PseudoTransform, Transform | |||||
| from .vision import * | |||||
| @@ -0,0 +1,31 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Sequence, Tuple | |||||
| class Transform(ABC): | |||||
| """ | |||||
| rewrite apply method in subclass | |||||
| """ | |||||
| def apply_batch(self, inputs: Sequence[Tuple]): | |||||
| return tuple(self.apply(input) for input in inputs) | |||||
| @abstractmethod | |||||
| def apply(self, input: Tuple): | |||||
| pass | |||||
| def __repr__(self): | |||||
| return self.__class__.__name__ | |||||
| class PseudoTransform(Transform): | |||||
| def apply(self, input: Tuple): | |||||
| return input | |||||
| @@ -0,0 +1,9 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .transform import * | |||||
| @@ -0,0 +1,111 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections.abc | |||||
| import functools | |||||
| import random | |||||
| import cv2 | |||||
| import numpy as np | |||||
| def wrap_keepdims(func): | |||||
| """Wraper to keep the dimension of input images unchanged""" | |||||
| @functools.wraps(func) | |||||
| def wrapper(image, *args, **kwargs): | |||||
| if len(image.shape) != 3: | |||||
| raise ValueError( | |||||
| "image must have 3 dims, but got {} dims".format(len(image.shape)) | |||||
| ) | |||||
| ret = func(image, *args, **kwargs) | |||||
| if len(ret.shape) == 2: | |||||
| ret = ret[:, :, np.newaxis] | |||||
| return ret | |||||
| return wrapper | |||||
| @wrap_keepdims | |||||
| def to_gray(image): | |||||
| r""" | |||||
| Change BGR format image's color space to gray | |||||
| :param image: Input BGR format image, with (H, W, C) shape | |||||
| :return: Gray format image, with (H, W, C) shape | |||||
| """ | |||||
| return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) | |||||
| @wrap_keepdims | |||||
| def to_bgr(image): | |||||
| r""" | |||||
| Change gray format image's color space to BGR | |||||
| :param image: input Gray format image, with (H, W, C) shape | |||||
| :return: BGR format image, with (H, W, C) shape | |||||
| """ | |||||
| return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) | |||||
| @wrap_keepdims | |||||
| def pad(input, size, value): | |||||
| r""" | |||||
| Pad input data with *value* and given *size* | |||||
| :param input: Input data, with (H, W, C) shape | |||||
| :param size: Padding size of input data, it could be integer or sequence. | |||||
| If it's an integer, the input data will be padded in four directions. | |||||
| If it's a sequence contains two integer, the bottom and right side | |||||
| of input data will be padded. | |||||
| If it's a sequence contains four integer, the top, bottom, left, right | |||||
| side of input data will be padded with given size. | |||||
| :param value: Padding value of data, could be a sequence of int or float. | |||||
| if it's float value, the dtype of image will be casted to float32 also. | |||||
| :return: Padded image | |||||
| """ | |||||
| if isinstance(size, int): | |||||
| size = (size, size, size, size) | |||||
| elif isinstance(size, collections.abc.Sequence) and len(size) == 2: | |||||
| size = (0, size[0], 0, size[1]) | |||||
| if np.array(value).dtype == float: | |||||
| input = input.astype(np.float32) | |||||
| return cv2.copyMakeBorder(input, *size, cv2.BORDER_CONSTANT, value=value) | |||||
| @wrap_keepdims | |||||
| def flip(image, flipCode): | |||||
| r""" | |||||
| Accordding to the flipCode (the type of flip), flip the input image | |||||
| :param image: Input image, with (H, W, C) shape | |||||
| :param flipCode: code that indicates the type of flip. | |||||
| 1 : Flip horizontally | |||||
| 0 : Flip vertically | |||||
| -1 : Flip horizontally and vertically | |||||
| :return: BGR format image, with (H, W, C) shape | |||||
| """ | |||||
| return cv2.flip(image, flipCode=flipCode) | |||||
| @wrap_keepdims | |||||
| def resize(input, size, interpolation=cv2.INTER_LINEAR): | |||||
| r""" | |||||
| resize the input data to given size | |||||
| :param input: Input data, could be image or masks, with (H, W, C) shape | |||||
| :param size: Target size of input data, with (height, width) shape. | |||||
| :param interpolation: Interpolation method. | |||||
| :return: Resized data, with (H, W, C) shape | |||||
| """ | |||||
| if len(size) != 2: | |||||
| raise ValueError("resize needs (h, w), but got {}".format(size)) | |||||
| if isinstance(interpolation, collections.abc.Sequence): | |||||
| interpolation = random.choice(interpolation) | |||||
| return cv2.resize(input, size[::-1], interpolation=interpolation) | |||||
| @@ -0,0 +1,89 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import os | |||||
| from .core._imperative_rt.common import CompNode, DeviceType | |||||
| __all__ = [ | |||||
| "is_cuda_available", | |||||
| "get_device_count", | |||||
| "get_default_device", | |||||
| "set_default_device", | |||||
| ] | |||||
| _default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux") | |||||
| def _valid_device(inp): | |||||
| if isinstance(inp, str) and len(inp) == 4: | |||||
| if inp[0] in {"x", "c", "g"} and inp[1:3] == "pu": | |||||
| if inp[3] == "x" or inp[3].isdigit(): | |||||
| return True | |||||
| return False | |||||
| def _str2device_type(type_str: str, allow_unspec: bool = True): | |||||
| type_str = type_str.upper() | |||||
| if type_str == "CPU": | |||||
| return DeviceType.CPU | |||||
| elif type_str == "GPU" or type_str == "CUDA": | |||||
| return DeviceType.CUDA | |||||
| else: | |||||
| assert allow_unspec and str == "XPU", "bad device type" | |||||
| return DeviceType.UNSPEC | |||||
| def get_device_count(device_type: str) -> int: | |||||
| """Gets number of devices installed on this system. | |||||
| :param device_type: device type, one of 'gpu' or 'cpu' | |||||
| """ | |||||
| device_type_set = ("cpu", "gpu") | |||||
| assert device_type in device_type_set, "device must be one of {}".format( | |||||
| device_type_set | |||||
| ) | |||||
| device_type = _str2device_type(device_type) | |||||
| return CompNode._get_device_count(device_type, False) | |||||
| def is_cuda_available() -> bool: | |||||
| """Returns whether cuda device is available on this system. | |||||
| """ | |||||
| t = _str2device_type("gpu") | |||||
| return CompNode._get_device_count(t, False) > 0 | |||||
| def set_default_device(device: str = "xpux"): | |||||
| r"""Sets default computing node. | |||||
| :param device: default device type. The type can be 'cpu0', 'cpu1', etc., | |||||
| or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use. | |||||
| 'cpux' and 'gpux' can also be used to specify any number of cpu or gpu devices. | |||||
| 'multithread' device type is avaliable when inference, which implements | |||||
| multi-threading parallelism at the operator level. For example, | |||||
| 'multithread4' will compute with 4 threads. which implements | |||||
| The default value is 'xpux' to specify any device available. | |||||
| It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | |||||
| """ | |||||
| global _default_device # pylint: disable=global-statement | |||||
| assert _valid_device(device), "Invalid device name {}".format(device) | |||||
| _default_device = device | |||||
| def get_default_device() -> str: | |||||
| r"""Gets default computing node. | |||||
| It returns the value set by :func:`~.set_default_device`. | |||||
| """ | |||||
| return _default_device | |||||
| @@ -0,0 +1,25 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .group import ( | |||||
| WORLD, | |||||
| get_backend, | |||||
| get_client, | |||||
| get_mm_server_addr, | |||||
| get_py_server_addr, | |||||
| get_rank, | |||||
| get_world_size, | |||||
| group_barrier, | |||||
| init_process_group, | |||||
| is_distributed, | |||||
| new_group, | |||||
| ) | |||||
| from .helper import synchronized | |||||
| from .launcher import launcher | |||||
| from .server import Client, Server | |||||
| from .util import get_free_ports | |||||
| @@ -0,0 +1,176 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import List, Optional, Tuple | |||||
| from ..device import set_default_device | |||||
| from .server import Client, Server | |||||
| class StaticData: | |||||
| server = None | |||||
| client = None | |||||
| master_ip = None | |||||
| py_server_port = None | |||||
| mm_server_port = None | |||||
| world_size = None | |||||
| proc_rank = None | |||||
| device = None | |||||
| backend = None | |||||
| next_stream = None | |||||
| _sd = None | |||||
| class Group: | |||||
| def __init__(self, proc_ranks): | |||||
| if len(proc_ranks) == 0: # empty group | |||||
| self.proc_ranks = None | |||||
| self.stream = None | |||||
| else: | |||||
| self.reset(proc_ranks) | |||||
| def reset(self, proc_ranks): | |||||
| self.check(proc_ranks) | |||||
| self.proc_ranks = proc_ranks | |||||
| self.stream = _sd.next_stream | |||||
| _sd.next_stream += 1 | |||||
| def check(self, proc_ranks): | |||||
| assert _sd is not None, "please call init_process_group first" | |||||
| for rank in proc_ranks: | |||||
| assert isinstance(rank, int) | |||||
| assert rank >= 0 and rank < _sd.world_size | |||||
| assert _sd.proc_rank in proc_ranks | |||||
| @property | |||||
| def size(self): | |||||
| assert len(self.proc_ranks) > 0, "invalid group" | |||||
| return len(self.proc_ranks) | |||||
| @property | |||||
| def key(self): | |||||
| assert len(self.proc_ranks) > 0, "invalid group" | |||||
| return ",".join(map(str, self.proc_ranks)) | |||||
| @property | |||||
| def rank(self): | |||||
| assert len(self.proc_ranks) > 0, "invalid group" | |||||
| return self.proc_ranks.index(_sd.proc_rank) | |||||
| @property | |||||
| def comp_node(self): | |||||
| assert len(self.proc_ranks) > 0, "invalid group" | |||||
| return "gpu{}:{}".format(_sd.device, self.stream) | |||||
| WORLD = Group([]) | |||||
| def init_process_group( | |||||
| master_ip: str, | |||||
| port: int, | |||||
| world_size: int, | |||||
| rank: int, | |||||
| device: int, | |||||
| backend: Optional[str] = "nccl", | |||||
| ) -> None: | |||||
| """Initialize the distributed process group and specify the device used in the current process | |||||
| :param master_ip: IP address of the master node | |||||
| :param port: Port available for all processes to communicate | |||||
| :param world_size: Total number of processes participating in the job | |||||
| :param rank: Rank of the current process | |||||
| :param device: The GPU device id to bind this process to | |||||
| :param backend: Communicator backend, currently support 'nccl' and 'ucx' | |||||
| """ | |||||
| if not isinstance(master_ip, str): | |||||
| raise TypeError("Expect type str but got {}".format(type(master_ip))) | |||||
| if not isinstance(port, int): | |||||
| raise TypeError("Expect type int but got {}".format(type(port))) | |||||
| if not isinstance(world_size, int): | |||||
| raise TypeError("Expect type int but got {}".format(type(world_size))) | |||||
| if not isinstance(rank, int): | |||||
| raise TypeError("Expect type int but got {}".format(type(rank))) | |||||
| if not isinstance(device, int): | |||||
| raise TypeError("Expect type int but got {}".format(type(backend))) | |||||
| if not isinstance(backend, str): | |||||
| raise TypeError("Expect type str but got {}".format(type(backend))) | |||||
| global _sd | |||||
| assert _sd is None, "init_process_group should be called only once" | |||||
| _sd = StaticData() | |||||
| assert world_size > 1 | |||||
| assert rank >= 0 and rank < world_size | |||||
| assert port > 0 | |||||
| _sd.client = Client(master_ip, port) | |||||
| _sd.master_ip = master_ip | |||||
| _sd.py_server_port = port | |||||
| _sd.mm_server_port = _sd.client.get_mm_server_port() | |||||
| _sd.world_size = world_size | |||||
| _sd.proc_rank = rank | |||||
| _sd.device = device | |||||
| _sd.backend = backend | |||||
| _sd.next_stream = 1 | |||||
| WORLD.reset(list(range(world_size))) | |||||
| set_default_device("gpu{}".format(device)) | |||||
| def is_distributed() -> bool: | |||||
| """Return True if the distributed process group has been initialized""" | |||||
| return _sd is not None | |||||
| def get_rank() -> int: | |||||
| """Get the rank of the current process""" | |||||
| return _sd.proc_rank if _sd is not None else 0 | |||||
| def get_world_size() -> int: | |||||
| """Get the total number of processes participating in the job""" | |||||
| return _sd.world_size if _sd is not None else 1 | |||||
| def get_backend() -> str: | |||||
| """Get the backend str""" | |||||
| assert _sd is not None, "please call init_process_group first" | |||||
| return _sd.backend if _sd is not None else None | |||||
| def get_py_server_addr() -> Tuple[str, int]: | |||||
| """Get master_ip and port of python XML RPC server""" | |||||
| assert _sd is not None, "please call init_process_group first" | |||||
| return _sd.master_ip, _sd.py_server_port | |||||
| def get_mm_server_addr() -> Tuple[str, int]: | |||||
| """Get master_ip and port of C++ mm_server""" | |||||
| assert _sd is not None, "please call init_process_group first" | |||||
| return _sd.master_ip, _sd.mm_server_port | |||||
| def get_client() -> Client: | |||||
| """Get client of python XML RPC server""" | |||||
| assert _sd is not None, "please call init_process_group first" | |||||
| return _sd.client | |||||
| def new_group(proc_ranks: List[int]) -> Group: | |||||
| """Build a subgroup containing certain ranks""" | |||||
| return Group(proc_ranks) | |||||
| def group_barrier(group: Optional[Group] = WORLD) -> None: | |||||
| """Block until all ranks in the group reach this barrier""" | |||||
| assert isinstance(group, Group) | |||||
| _sd.client.group_barrier(group.key, group.size) | |||||
| @@ -0,0 +1,28 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| from typing import Callable | |||||
| from .group import group_barrier, is_distributed | |||||
| def synchronized(func: Callable): | |||||
| """Decorator. Decorated function will synchronize when finished. | |||||
| Specifically, we use this to prevent data race during hub.load""" | |||||
| @functools.wraps(func) | |||||
| def wrapper(*args, **kwargs): | |||||
| if not is_distributed(): | |||||
| return func(*args, **kwargs) | |||||
| ret = func(*args, **kwargs) | |||||
| group_barrier() | |||||
| return ret | |||||
| return wrapper | |||||
| @@ -0,0 +1,68 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import multiprocessing as mp | |||||
| from ..device import get_device_count | |||||
| from .group import init_process_group | |||||
| from .server import Server | |||||
| from .util import get_free_ports | |||||
| def _get_device_count(): | |||||
| """use subprocess to avoid cuda environment initialization in the main process""" | |||||
| def run(q): | |||||
| count = get_device_count("gpu") | |||||
| q.put(count) | |||||
| q = mp.Queue() | |||||
| p = mp.Process(target=run, args=(q,)) | |||||
| p.start() | |||||
| p.join() | |||||
| return q.get() | |||||
| def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs): | |||||
| """init distributed process group and run wrapped function""" | |||||
| init_process_group( | |||||
| master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev | |||||
| ) | |||||
| func(*args, **kwargs) | |||||
| def launcher(n_gpus): | |||||
| """decorator for launching multiple processes in single-machine multi-gpu training""" | |||||
| count = _get_device_count() | |||||
| assert isinstance(n_gpus, int) and n_gpus > 1, "invalid n_gpus" | |||||
| assert n_gpus <= count, "{} gpus required, {} gpus provided".format(n_gpus, count) | |||||
| def decorator(func): | |||||
| def wrapper(*args, **kwargs): | |||||
| master_ip = "localhost" | |||||
| port = get_free_ports(1)[0] | |||||
| server = Server(port) | |||||
| procs = [] | |||||
| for rank in range(n_gpus): | |||||
| p = mp.Process( | |||||
| target=_run_wrapped, | |||||
| args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs), | |||||
| ) | |||||
| p.start() | |||||
| procs.append(p) | |||||
| for rank in range(n_gpus): | |||||
| procs[rank].join() | |||||
| code = procs[rank].exitcode | |||||
| assert code == 0, "subprocess {} exit with code {}".format(rank, code) | |||||
| return wrapper | |||||
| return decorator | |||||
| @@ -0,0 +1,170 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import multiprocessing as mp | |||||
| import threading | |||||
| import time | |||||
| from collections import defaultdict | |||||
| from functools import partial | |||||
| from socketserver import ThreadingMixIn | |||||
| from xmlrpc.client import ServerProxy | |||||
| from xmlrpc.server import SimpleXMLRPCServer | |||||
| from ..core._imperative_rt.utils import create_mm_server | |||||
| from .util import get_free_ports | |||||
| class Future: | |||||
| def __init__(self, ack=True): | |||||
| self.ready = threading.Event() | |||||
| self.ack = threading.Event() if ack else None | |||||
| def set(self, value): | |||||
| self.value = value | |||||
| self.ready.set() | |||||
| if self.ack: | |||||
| self.ack.wait() | |||||
| def get(self): | |||||
| self.ready.wait() | |||||
| if self.ack: | |||||
| self.ack.set() | |||||
| return self.value | |||||
| class Methods: | |||||
| def __init__(self, mm_server_port): | |||||
| self.lock = threading.Lock() | |||||
| self.mm_server_port = mm_server_port | |||||
| self.dict_is_grad = defaultdict(partial(Future, True)) | |||||
| self.dict_remote_tracer = defaultdict(partial(Future, True)) | |||||
| self.dict_pack_list = defaultdict(partial(Future, False)) | |||||
| self.dict_barrier_counter = defaultdict(int) | |||||
| self.dict_barrier_event = defaultdict(threading.Event) | |||||
| def connect(self): | |||||
| return True | |||||
| def get_mm_server_port(self): | |||||
| return self.mm_server_port | |||||
| def set_is_grad(self, rank_peer, is_grad): | |||||
| with self.lock: | |||||
| future = self.dict_is_grad[rank_peer] | |||||
| future.set(is_grad) | |||||
| return True | |||||
| def check_is_grad(self, rank_peer): | |||||
| with self.lock: | |||||
| future = self.dict_is_grad[rank_peer] | |||||
| ret = future.get() | |||||
| with self.lock: | |||||
| del self.dict_is_grad[rank_peer] | |||||
| return ret | |||||
| def set_remote_tracer(self, rank_peer, tracer_set): | |||||
| with self.lock: | |||||
| future = self.dict_remote_tracer[rank_peer] | |||||
| future.set(tracer_set) | |||||
| return True | |||||
| def check_remote_tracer(self, rank_peer): | |||||
| with self.lock: | |||||
| future = self.dict_remote_tracer[rank_peer] | |||||
| ret = future.get() | |||||
| with self.lock: | |||||
| del self.dict_remote_tracer[rank_peer] | |||||
| return ret | |||||
| def set_pack_list(self, key, pack_list): | |||||
| with self.lock: | |||||
| future = self.dict_pack_list[key] | |||||
| future.set(pack_list) | |||||
| return True | |||||
| def get_pack_list(self, key): | |||||
| with self.lock: | |||||
| future = self.dict_pack_list[key] | |||||
| return future.get() | |||||
| def group_barrier(self, key, size): | |||||
| with self.lock: | |||||
| self.dict_barrier_counter[key] += 1 | |||||
| counter = self.dict_barrier_counter[key] | |||||
| event = self.dict_barrier_event[key] | |||||
| if counter == size: | |||||
| del self.dict_barrier_counter[key] | |||||
| del self.dict_barrier_event[key] | |||||
| event.set() | |||||
| else: | |||||
| event.wait() | |||||
| return True | |||||
| class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): | |||||
| pass | |||||
| def start_server(py_server_port, mm_server_port): | |||||
| server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) | |||||
| server.register_instance(Methods(mm_server_port)) | |||||
| server.serve_forever() | |||||
| class Server: | |||||
| def __init__(self, port): | |||||
| self.py_server_port = get_free_ports(1)[0] if port == 0 else port | |||||
| self.mm_server_port = create_mm_server("0.0.0.0", 0) | |||||
| self.proc = mp.Process( | |||||
| target=start_server, | |||||
| args=(self.py_server_port, self.mm_server_port), | |||||
| daemon=True, | |||||
| ) | |||||
| self.proc.start() | |||||
| class Client: | |||||
| def __init__(self, master_ip, port): | |||||
| self.master_ip = master_ip | |||||
| self.port = port | |||||
| self.connect() | |||||
| def connect(self): | |||||
| while True: | |||||
| try: | |||||
| self.proxy = ServerProxy( | |||||
| "http://{}:{}".format(self.master_ip, self.port) | |||||
| ) | |||||
| if self.proxy.connect(): | |||||
| break | |||||
| except: | |||||
| time.sleep(1) | |||||
| def get_mm_server_port(self): | |||||
| return self.proxy.get_mm_server_port() | |||||
| def set_is_grad(self, rank_peer, is_grad): | |||||
| self.proxy.set_is_grad(rank_peer, is_grad) | |||||
| def check_is_grad(self, rank_peer): | |||||
| return self.proxy.check_is_grad(rank_peer) | |||||
| def set_remote_tracer(self, rank_peer, tracer_set): | |||||
| self.proxy.set_remote_tracer(rank_peer, tracer_set) | |||||
| def check_remote_tracer(self, rank_peer): | |||||
| return self.proxy.check_remote_tracer(rank_peer) | |||||
| def set_pack_list(self, key, pack_list): | |||||
| self.proxy.set_pack_list(key, pack_list) | |||||
| def get_pack_list(self, key): | |||||
| return self.proxy.get_pack_list(key) | |||||
| def group_barrier(self, key, size): | |||||
| self.proxy.group_barrier(key, size) | |||||
| @@ -0,0 +1,25 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| import socket | |||||
| from typing import List | |||||
| def get_free_ports(num: int) -> List[int]: | |||||
| """Get one or more free ports. | |||||
| """ | |||||
| socks, ports = [], [] | |||||
| for i in range(num): | |||||
| sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||||
| sock.bind(("", 0)) | |||||
| socks.append(sock) | |||||
| ports.append(sock.getsockname()[1]) | |||||
| for sock in socks: | |||||
| sock.close() | |||||
| return ports | |||||
| @@ -0,0 +1,32 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # pylint: disable=redefined-builtin | |||||
| from . import distributed | |||||
| from .elemwise import * | |||||
| from .graph import add_update | |||||
| from .loss import ( | |||||
| binary_cross_entropy, | |||||
| cross_entropy, | |||||
| cross_entropy_with_softmax, | |||||
| hinge_loss, | |||||
| l1_loss, | |||||
| nll_loss, | |||||
| smooth_l1_loss, | |||||
| square_loss, | |||||
| triplet_margin_loss, | |||||
| ) | |||||
| from .math import * | |||||
| from .nn import * | |||||
| from .quantized import conv_bias_activation | |||||
| from .tensor import * | |||||
| from .utils import accuracy, zero_grad | |||||
| # delete namespace | |||||
| # pylint: disable=undefined-variable | |||||
| # del elemwise, graph, loss, math, nn, tensor # type: ignore[name-defined] | |||||
| @@ -0,0 +1,49 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import os | |||||
| _conv_execution_strategy = os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY", "HEURISTIC") | |||||
| def get_conv_execution_strategy() -> str: | |||||
| """Returns the execuation strategy of :class:`~.Conv2d`. | |||||
| See :func:`~.set_conv_execution_strategy` for possible return values | |||||
| """ | |||||
| return _conv_execution_strategy | |||||
| def set_conv_execution_strategy(option: str): | |||||
| """Sets the execuation strategy of :class:`~.Conv2d`. | |||||
| :param option: Decides how :class:`~.Conv2d` algorithm is chosen. | |||||
| Available values: | |||||
| * 'HEURISTIC' uses heuristic to choose the fastest algorithm. | |||||
| * 'PROFILE' runs possible algorithms on real device to find the best. | |||||
| * 'PROFILE_HEURISTIC' uses profile result and heuristic to choose the fastest algorithm. | |||||
| * 'PROFILE_REPRODUCIBLE' uses the fastest of profile result that is also reproducible. | |||||
| * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. | |||||
| The default strategy is 'HEURISTIC'. | |||||
| It can also be set through the environmental variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'. | |||||
| """ | |||||
| valid_option = ( | |||||
| "HEURISTIC", | |||||
| "PROFILE", | |||||
| "PROFILE_HEURISTIC", | |||||
| "PROFILE_REPRODUCIBLE", | |||||
| "HEURISTIC_REPRODUCIBLE", | |||||
| ) | |||||
| if not option in valid_option: | |||||
| raise ValueError("Valid option can only be one of {}".format(valid_option)) | |||||
| global _conv_execution_strategy # pylint: disable=global-statement | |||||
| _conv_execution_strategy = option | |||||
| @@ -0,0 +1,299 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import Optional, Tuple | |||||
| from ..core._imperative_rt.ops import CollectiveCommDefModeEnum | |||||
| from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||||
| from ..core.autodiff.grad import ( | |||||
| Tracer, | |||||
| check_backward_allow_noinput, | |||||
| get_grad_managers, | |||||
| get_op_has_grad_fn, | |||||
| tracer_apply, | |||||
| ) | |||||
| from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||||
| from ..core.tensor.core import apply | |||||
| from ..core.tensor.tensor import Tensor, tensor_apply | |||||
| from ..distributed.group import ( | |||||
| WORLD, | |||||
| Group, | |||||
| get_backend, | |||||
| get_client, | |||||
| get_mm_server_addr, | |||||
| get_rank, | |||||
| ) | |||||
| from ..tensor import tensor | |||||
| __all__ = [ | |||||
| "reduce_sum", | |||||
| "broadcast", | |||||
| "all_gather", | |||||
| "reduce_scatter_sum", | |||||
| "all_reduce_sum", | |||||
| "all_reduce_max", | |||||
| "all_reduce_min", | |||||
| "gather", | |||||
| "scatter", | |||||
| "all_to_all", | |||||
| "remote_send", | |||||
| "remote_recv", | |||||
| ] | |||||
| @apply.add | |||||
| def _(op: RemoteSend, *args: Tensor): | |||||
| ret = tensor_apply(op, *args) | |||||
| # set extra information | |||||
| tracer_set = dict() | |||||
| for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))): | |||||
| tracer_set[k.name] = True | |||||
| # check tracer_set in remote_recv | |||||
| get_client().set_remote_tracer(op.key, tracer_set) | |||||
| return ret | |||||
| @builtin_op_get_backward_fn.register(RemoteSend) | |||||
| def _(op: RemoteSend, inputs, outputs, input_requires_grad): | |||||
| def backward(*args): | |||||
| return [ | |||||
| remote_recv( | |||||
| op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device) | |||||
| ) | |||||
| ] | |||||
| return backward, [True] | |||||
| @get_op_has_grad_fn.register(RemoteSend) | |||||
| def _(op: RemoteSend): | |||||
| def has_grad(opnode, reached): | |||||
| return get_client().check_is_grad(op.key) | |||||
| return has_grad | |||||
| @check_backward_allow_noinput.register(RemoteSend) | |||||
| def _(op: RemoteSend): | |||||
| return True | |||||
| @builtin_op_get_backward_fn.register(RemoteRecv) | |||||
| def _(op: RemoteRecv, inputs, outputs, input_requires_grad): | |||||
| def backward(*output_grads): | |||||
| return [remote_send(output_grads[0], op.rank_from)] | |||||
| return backward, [True] | |||||
| @get_op_has_grad_fn.register(RemoteRecv) | |||||
| def _(op: RemoteRecv): | |||||
| def has_grad(opnode, reached): | |||||
| ret = False | |||||
| for v in opnode.outputs: | |||||
| if v() in reached: | |||||
| ret = True | |||||
| break | |||||
| get_client().set_is_grad(op.key, ret) | |||||
| return ret | |||||
| return has_grad | |||||
| def collective_comm(inp, mode, group, device): | |||||
| """Helper function for applying collective communication functions""" | |||||
| assert isinstance(group, Group) | |||||
| if group is None: | |||||
| return inp | |||||
| op = CollectiveComm() | |||||
| op.key = group.key | |||||
| op.nr_devices = group.size | |||||
| op.rank = group.rank | |||||
| op.is_root = op.rank == 0 | |||||
| op.local_grad = False | |||||
| op.addr, op.port = get_mm_server_addr() | |||||
| op.mode = mode | |||||
| op.dtype = inp.dtype | |||||
| op.backend = get_backend() | |||||
| op.comp_node = device | |||||
| return apply(op, inp)[0] | |||||
| def reduce_sum( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create reduce_sum operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.REDUCE_SUM | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def broadcast( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create broadcast operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.BROADCAST | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def all_gather( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create all_gather operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.ALL_GATHER | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def reduce_scatter_sum( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create reduce_scatter_sum operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.REDUCE_SCATTER_SUM | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def all_reduce_sum( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create all_reduce_sum operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.ALL_REDUCE_SUM | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def all_reduce_max( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create all_reduce_max operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.ALL_REDUCE_MAX | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def all_reduce_min( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create all_reduce_min operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.ALL_REDUCE_MIN | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def gather( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create gather operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.GATHER | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def scatter( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create scatter operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.SCATTER | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def all_to_all( | |||||
| inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||||
| ) -> Tensor: | |||||
| """Create all_to_all operator for collective communication | |||||
| :param inp: input tensor | |||||
| :param group: communication group | |||||
| :param device: execute placement | |||||
| """ | |||||
| mode = CollectiveCommDefModeEnum.ALL_TO_ALL | |||||
| return collective_comm(inp, mode, group, device) | |||||
| def remote_send(inp: Tensor, dest_rank: int) -> Tensor: | |||||
| """Send a Tensor to a remote process | |||||
| :param inp: tensor to send | |||||
| :param dest_rank: destination process rank | |||||
| """ | |||||
| op = RemoteSend() | |||||
| op.key = "{}->{}".format(get_rank(), dest_rank) | |||||
| op.addr, op.port = get_mm_server_addr() | |||||
| op.rank_to = dest_rank | |||||
| return apply(op, inp)[0] | |||||
| def remote_recv( | |||||
| src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0" | |||||
| ) -> Tensor: | |||||
| """Receive a Tensor from a remote process | |||||
| :param src_rank: source process rank | |||||
| :param shape: the shape of the tensor to receive | |||||
| :param dtype: the data type of the tensor to receive | |||||
| :param cn: the comp node to place the received tensor | |||||
| """ | |||||
| key = "{}->{}".format(src_rank, get_rank()) | |||||
| # dummpy input | |||||
| inp = tensor([0]) | |||||
| tracer_set = get_client().check_remote_tracer(key) | |||||
| for grad_manager in get_grad_managers(): | |||||
| if grad_manager.name in tracer_set: | |||||
| grad_manager.wrt(inp) | |||||
| op = RemoteRecv() | |||||
| op.key = key | |||||
| op.cn = cn | |||||
| op.shape = shape | |||||
| op.dtype = dtype | |||||
| op.addr, op.port = get_mm_server_addr() | |||||
| op.rank_from = src_rank | |||||
| return apply(op, inp)[0] | |||||
| @@ -0,0 +1,481 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||||
| import functools | |||||
| from ..core.ops import builtin | |||||
| from ..core.tensor import utils | |||||
| from ..core.tensor.core import apply | |||||
| from ..tensor import Tensor | |||||
| __all__ = [ | |||||
| "abs", | |||||
| "add", | |||||
| "acos", | |||||
| "asin", | |||||
| "atan", | |||||
| "atan2", | |||||
| "asinh", | |||||
| "acosh", | |||||
| "atanh", | |||||
| "bitwise_and", # TODO | |||||
| "bitwise_not", # TODO | |||||
| "bitwise_or", # TODO | |||||
| "bitwise_xor", # TODO | |||||
| "ceil", | |||||
| "clamp", | |||||
| "cos", | |||||
| "cosh", | |||||
| "div", | |||||
| "eq", | |||||
| "exp", | |||||
| "expm1", | |||||
| "floor", | |||||
| "floor_div", | |||||
| "gt", | |||||
| "ge", | |||||
| "hswish", | |||||
| "hsigmoid", | |||||
| "left_shift", | |||||
| "lt", | |||||
| "le", | |||||
| "log", | |||||
| "log1p", | |||||
| "logical_and", | |||||
| "logical_not", | |||||
| "logical_or", | |||||
| "logical_xor", | |||||
| "maximum", | |||||
| "minimum", | |||||
| "mod", | |||||
| "mul", | |||||
| "neg", | |||||
| "ne", | |||||
| "pow", | |||||
| "relu", | |||||
| "relu6", | |||||
| "right_shift", | |||||
| "round", | |||||
| "sigmoid", | |||||
| "sin", | |||||
| "sinh", | |||||
| "sqrt", | |||||
| "square", | |||||
| "sub", | |||||
| "tan", | |||||
| "tanh", | |||||
| "fast_tanh", | |||||
| ] | |||||
| def _elwise(*args, mode): | |||||
| op = builtin.Elemwise(mode=mode) | |||||
| args = utils.convert_inputs(*args) | |||||
| (result,) = apply(op, *args) | |||||
| return result | |||||
| def _logical(*args, mode): | |||||
| op = builtin.CondExecPredLogical(mode=mode) | |||||
| args = utils.convert_inputs(*args) | |||||
| (result,) = apply(op, *args) | |||||
| return result | |||||
| def _elemwise_multi_type(*args, mode, **kwargs): | |||||
| op = builtin.ElemwiseMultiType(mode=mode, **kwargs) | |||||
| args = utils.convert_inputs(*args) | |||||
| (result,) = apply(op, *args) | |||||
| return result | |||||
| # math operations | |||||
| def add(x, y): | |||||
| """Element-wise addition. | |||||
| At least one operand should be tensor. | |||||
| same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minmium. | |||||
| """ | |||||
| return _elwise(x, y, mode="add") | |||||
| def sub(x, y): | |||||
| """Element-wise subtract.""" | |||||
| return _elwise(x, y, mode="sub") | |||||
| def mul(x, y): | |||||
| """Element-wise multiplication.""" | |||||
| return _elwise(x, y, mode="mul") | |||||
| def div(x, y): | |||||
| """Element-wise (x / y).""" | |||||
| return _elwise(x, y, mode="true_div") | |||||
| def floor_div(x, y): | |||||
| """Element-wise floor(x / y).""" | |||||
| return _elwise(x, y, mode="floor_divide") | |||||
| def neg(x): | |||||
| """Element-wise negation.""" | |||||
| return _elwise(x, mode="negate") | |||||
| def pow(x, y): | |||||
| """Element-wise power.""" | |||||
| return _elwise(x, y, mode="pow") | |||||
| def mod(x, y): | |||||
| """Element-wise remainder of division.""" | |||||
| return _elwise(x, y, mode="mod") | |||||
| def abs(x): | |||||
| """Element-wise absolute value.""" | |||||
| return _elwise(x, mode="abs") | |||||
| def exp(x): | |||||
| """Element-wise exponential.""" | |||||
| return _elwise(x, mode="exp") | |||||
| def expm1(x): | |||||
| """Element-wise exp(x)-1.""" | |||||
| return _elwise(x, mode="expm1") | |||||
| def log(x): | |||||
| """Element-wise logarithm (base `e`).""" | |||||
| return _elwise(x, mode="log") | |||||
| def log1p(x): | |||||
| """Element-wise log(x+1) (base `e`).""" | |||||
| return _elwise(x, mode="log1p") | |||||
| def sqrt(inp: Tensor) -> Tensor: | |||||
| """ | |||||
| Return a new tensor with the square-root of the elements of ``inp``. | |||||
| For negative value, return nan. | |||||
| :param inp: The input tensor | |||||
| :return: The computed tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.functional as F | |||||
| data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||||
| out = F.sqrt(data) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[0. 1. 1.4142] | |||||
| [1.7321 2. 2.2361 ]] | |||||
| """ | |||||
| return inp ** 0.5 | |||||
| def square(inp: Tensor) -> Tensor: | |||||
| """ | |||||
| Return a new tensor with the square of the elements of ``inp`` | |||||
| :param inp: The input tensor | |||||
| :return: The computed tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.functional as F | |||||
| data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||||
| out = F.square(data) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[0. 1. 4.] | |||||
| [9. 16. 25.]] | |||||
| """ | |||||
| return inp ** 2 | |||||
| def round(x): | |||||
| """Round tensor to int element-wise.""" | |||||
| return _elwise(x, mode="round") | |||||
| def ceil(x): | |||||
| """Return the ceil of the input, element-wise.""" | |||||
| return _elwise(x, mode="ceil") | |||||
| def floor(x): | |||||
| """Calculate the floor element-wise""" | |||||
| return _elwise(x, mode="floor") | |||||
| # trigonometric functions | |||||
| def cos(x): | |||||
| """Cosine, element-wise.""" | |||||
| return _elwise(x, mode="cos") | |||||
| def sin(x): | |||||
| """Sine, element-wise.""" | |||||
| return _elwise(x, mode="sin") | |||||
| def tan(x): | |||||
| return sin(x) / cos(x) | |||||
| def acos(x): | |||||
| """Inverse cosine, element-wise.""" | |||||
| return _elwise(x, mode="acos") | |||||
| def asin(x): | |||||
| """Inverse sine, element-wise.""" | |||||
| return _elwise(x, mode="asin") | |||||
| def atan(x): | |||||
| return _elwise(x, 1, mode="atan2") | |||||
| def atan2(y, x): | |||||
| return _elwise(y, x, mode="atan2") | |||||
| def cosh(x): | |||||
| r"""Compute element-wise hyperbolic cosine.""" | |||||
| return 0.5 * (exp(x) + exp(-x)) | |||||
| def sinh(x): | |||||
| r"""Compute element-wise hyperbolic sine.""" | |||||
| u = expm1(x) | |||||
| return 0.5 * u / (u + 1) * (u + 2) | |||||
| def tanh(x): | |||||
| r"""Compute element-wise hyperbolic tangent.""" | |||||
| return _elwise(x, mode="tanh") | |||||
| def asinh(x): | |||||
| r"""Compute element-wise inverse hyperbolic sine.""" | |||||
| return log(x + (x ** 2 + 1) ** 0.5) | |||||
| def acosh(x): | |||||
| r"""Compute element-wise inverse hyperbolic cosine.""" | |||||
| return log(x + (x ** 2 - 1) ** 0.5) | |||||
| def atanh(x): | |||||
| r"""Compute element-wise inverse hyperbolic tangent.""" | |||||
| return log1p(2 * x / (1 - x)) / 2 | |||||
| def fast_tanh(x): | |||||
| r"""Compute element-wise fast tanh; this is an approximation: | |||||
| .. math:: | |||||
| \text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x) | |||||
| """ | |||||
| return _elwise(x, mode="fast_tanh") | |||||
| # bit-twiddling functions | |||||
| def left_shift(x, y): | |||||
| return _elwise(x, y, mode="shl") | |||||
| def right_shift(x, y): | |||||
| return _elwise(x, y, mode="shl") | |||||
| def bitwise_and(x, y): | |||||
| raise NotImplementedError | |||||
| def bitwise_not(x): | |||||
| raise NotImplementedError | |||||
| def bitwise_or(x, y): | |||||
| raise NotImplementedError | |||||
| def bitwise_xor(x, y): | |||||
| raise NotImplementedError | |||||
| # logical functions | |||||
| def logical_and(x, y): | |||||
| return _elwise(x, y, mode="AND") | |||||
| def logical_not(x): | |||||
| return _elwise(x, mode="NOT") | |||||
| def logical_or(x, y): | |||||
| return _elwise(x, y, mode="OR") | |||||
| def logical_xor(x, y): | |||||
| return _elwise(x, y, mode="XOR") | |||||
| # comparison functions | |||||
| def eq(x, y): | |||||
| """Return (x == y) element-wise.""" | |||||
| return _elwise(x, y, mode="eq") | |||||
| def ne(x, y): | |||||
| return x != y | |||||
| def lt(x, y): | |||||
| """Return (x < y) element-wise.""" | |||||
| return _elwise(x, y, mode="lt") | |||||
| def le(x, y): | |||||
| """Return (x =< y) element-wise.""" | |||||
| return _elwise(x, y, mode="leq") | |||||
| def gt(x, y): | |||||
| """Return (x > y) element-wise.""" | |||||
| return _elwise(y, x, mode="lt") | |||||
| def ge(x, y): | |||||
| """Return (x >= y) element-wise""" | |||||
| return _elwise(y, x, mode="leq") | |||||
| def hswish(x): | |||||
| """Return x * relu6(x + 3) / 6 element-wise""" | |||||
| return _elwise(x, mode="h_swish") | |||||
| def hsigmoid(x): | |||||
| """Return relu6(x + 3) / 6 element-wise""" | |||||
| return relu6(x + 3) / 6 | |||||
| def relu(x): | |||||
| """Return `max(x, 0)` element-wise.""" | |||||
| return _elwise(x, mode="relu") | |||||
| def relu6(x): | |||||
| """Return min(max(x, 0), 6) element-wise.""" | |||||
| return minimum(maximum(x, 0), 6) | |||||
| def sigmoid(x): | |||||
| """Return 1 / ( 1 + exp( -x ) ) element-wise.""" | |||||
| return _elwise(x, mode="sigmoid") | |||||
| def maximum(x, y): | |||||
| """Element-wise maximum of array elements.""" | |||||
| return _elwise(x, y, mode="max") | |||||
| def minimum(x, y): | |||||
| """Element-wise minimum of array elements.""" | |||||
| return _elwise(x, y, mode="min") | |||||
| def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: | |||||
| r""" | |||||
| Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return | |||||
| a resulting tensor: | |||||
| .. math:: | |||||
| y_i = \begin{cases} | |||||
| \text{lower} & \text{if } x_i < \text{lower} \\ | |||||
| x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\ | |||||
| \text{upper} & \text{if } x_i > \text{upper} | |||||
| \end{cases} | |||||
| :param inp: the input tensor. | |||||
| :param lower: lower-bound of the range to be clamped to | |||||
| :param upper: upper-bound of the range to be clamped to | |||||
| Example: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| a = tensor(np.arange(5).astype(np.int32)) | |||||
| print(F.clamp(a, 2, 4).numpy()) | |||||
| print(F.clamp(a, lower=3).numpy()) | |||||
| print(F.clamp(a, upper=3).numpy()) | |||||
| .. testoutput:: | |||||
| [2 2 2 3 4] | |||||
| [3 3 3 3 4] | |||||
| [0 1 2 3 3] | |||||
| """ | |||||
| assert ( | |||||
| lower is not None or upper is not None | |||||
| ), "At least one of 'lower' or 'upper' must not be None" | |||||
| if lower is not None: | |||||
| if upper is not None: | |||||
| assert lower <= upper, "clamp lower bound is bigger that upper bound" | |||||
| return minimum(maximum(inp, lower), upper) | |||||
| else: | |||||
| return maximum(inp, lower) | |||||
| else: | |||||
| return minimum(inp, upper) | |||||
| @@ -0,0 +1,44 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # pylint: disable=too-many-lines | |||||
| from typing import List | |||||
| from ..core import Tensor | |||||
| def cambricon_subgraph( | |||||
| inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool, | |||||
| ) -> List[Tensor]: | |||||
| """Load a serialized Cambricon subgraph (i.e. cnrtModel_t) and | |||||
| execute the operations defined in the subgraph. | |||||
| :param inputs: List of input tensors of the subgraph. | |||||
| :param data: The serialized subgraph. | |||||
| :param symbol: The name of the function in the subgraph. | |||||
| The function is corresponding to a cnmlFusionOp | |||||
| which is added to the cnmlModel_t/cnrtModel_t. | |||||
| :param tensor_dim_mutable: Whether the input tensors' shapes are mutalbe | |||||
| in cnrtModel_t | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def extern_opr_subgraph( | |||||
| inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, | |||||
| ) -> List[Tensor]: | |||||
| """Load a serialized extern opr subgraph and fake execute the operator | |||||
| :param inputs: Tensor or list of input tensors. | |||||
| :param output_shapes: The output shapes. | |||||
| :param dump_name: The serialized subgraph name. | |||||
| :param dump_data: The serialized subgraph. | |||||
| :return: List of tensors | |||||
| """ | |||||
| raise NotImplementedError | |||||
| @@ -0,0 +1,41 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| from typing import Iterable, Optional, Union | |||||
| from ..core.tensor import Tensor | |||||
| def add_update( | |||||
| dest: Tensor, | |||||
| delta: Tensor, | |||||
| *, | |||||
| alpha: Union[Tensor, float, int] = 1.0, | |||||
| beta: Union[Tensor, float, int] = 1.0, | |||||
| bias: Union[Tensor, float, int] = 0.0 | |||||
| ): | |||||
| r"""Inplace modify ``dest`` as follows: | |||||
| .. math:: | |||||
| dest = alpha * dest + beta * delta + bias | |||||
| :param dest: input data that will be inplace modified. | |||||
| :param delta: update value that will be added to ``dest``. | |||||
| :param alpha: weight ratio of ``dest``. Default: 1.0 | |||||
| :param beta: weight ratio of ``delta``. Default: 1.0 | |||||
| :param bias: bias value appended to the result. Default: 0.0 | |||||
| """ | |||||
| if beta is not None and beta != 1.0: | |||||
| delta = delta * beta | |||||
| if bias is not None and bias != 0.0: | |||||
| delta = delta + bias | |||||
| if alpha is not None and alpha != 1.0: | |||||
| dest *= alpha | |||||
| dest += delta | |||||
| return dest | |||||
| @@ -0,0 +1,388 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| from ..tensor import Tensor | |||||
| from .elemwise import abs, eq, exp, log, maximum, pow, relu | |||||
| from .nn import assert_equal, indexing_one_hot | |||||
| from .tensor import where | |||||
| from .utils import zero_grad | |||||
| def l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||||
| r""" | |||||
| Calculates the mean absolute error (MAE) between | |||||
| each element in the pred :math:`x` and label :math:`y`. | |||||
| The mean absolute error can be described as: | |||||
| .. math:: \ell(x,y) = mean\left(L \right) | |||||
| where | |||||
| .. math:: | |||||
| L = \{l_1,\dots,l_N\}, \quad | |||||
| l_n = \left| x_n - y_n \right|, | |||||
| :math:`x` and :math:`y` are tensors of arbitrary shapes with a total | |||||
| of :math:`N` elements each. :math:`N` is the batch size. | |||||
| :param pred: The predicted result from model. | |||||
| :param label: The ground truth to compare. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.functional as F | |||||
| ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32)) | |||||
| tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32)) | |||||
| loss = F.l1_loss(ipt,tgt) | |||||
| print(loss.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [2.75] | |||||
| """ | |||||
| diff = pred - label | |||||
| return abs(diff).mean() | |||||
| def square_loss(pred: Tensor, label: Tensor) -> Tensor: | |||||
| r""" | |||||
| Calculates the mean squared error (squared L2 norm) between | |||||
| each element in the pred :math:`x` and label :math:`y`. | |||||
| The mean squared error can be described as: | |||||
| .. math:: \ell(x, y) = mean\left( L \right) | |||||
| where | |||||
| .. math:: | |||||
| L = \{l_1,\dots,l_N\}, \quad | |||||
| l_n = \left( x_n - y_n \right)^2, | |||||
| :math:`x` and :math:`y` are tensors of arbitrary shapes with a total | |||||
| of :math:`N` elements each. :math:`N` is the batch size. | |||||
| :param pred: The predicted result from model. | |||||
| :param label: The ground truth to compare. | |||||
| Shape: | |||||
| - pred: :math:`(N, *)` where :math:`*` means any number of additional | |||||
| dimensions | |||||
| - label: :math:`(N, *)`. Same shape as ``pred`` | |||||
| """ | |||||
| diff = pred - label | |||||
| return (diff ** 2).mean() | |||||
| def cross_entropy( | |||||
| inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1 | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Returns the cross entropy loss in a classification problem. | |||||
| .. math:: \textrm{CrossEntropy}(x, y) = - \sum_{i} y_i\log(x_i) | |||||
| :param inp: The input tensor representing the predicted probability. | |||||
| :param label: The input tensor representing the classification label. | |||||
| :param axis: An axis along which cross_entropy will be applied. Default: 1 | |||||
| :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1 | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data_shape = (1, 2) | |||||
| label_shape = (1, ) | |||||
| pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape)) | |||||
| label = tensor(np.ones(label_shape, dtype=np.int32)) | |||||
| loss = F.cross_entropy(pred, label) | |||||
| print(loss.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [0.69] | |||||
| """ | |||||
| raise NotImplementedError | |||||
| # n0 = inp.ndim | |||||
| # n1 = target.ndim | |||||
| # assert n0 == n1 + 1, ( | |||||
| # "target ndim must be one less than input ndim; input_ndim={} " | |||||
| # "target_ndim={}".format(n0, n1) | |||||
| # ) | |||||
| # if ignore_index != -1: | |||||
| # mask = 1 - equal(target, ignore_index) | |||||
| # target = target * mask | |||||
| # loss = -log(indexing_one_hot(inp, target, axis)) * mask | |||||
| # return loss.sum() / maximum(mask.sum(), 1.0) | |||||
| # else: | |||||
| # return -log(indexing_one_hot(inp, target, axis)).mean() | |||||
| def cross_entropy_with_softmax( | |||||
| pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0 | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`. | |||||
| It has better numerical stability compared with sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`. | |||||
| When using label smoothing, the label distribution is as follows: | |||||
| .. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K | |||||
| where :math:`y^{LS}` and :math:`y` are new label distribution and origin label distribution respectively. | |||||
| k is the index of label distribution. :math:`\alpha` is label_smooth and :math:`K` is the number of classes. | |||||
| :param pred: The input tensor representing the predicted probability. | |||||
| :param label: The input tensor representing the classification label. | |||||
| :param axis: An axis along which softmax will be applied. Default: 1. | |||||
| :param label_smooth: A label smoothing of parameter that can re-distribute target distribution. Default: 0. | |||||
| """ | |||||
| n0 = pred.ndim | |||||
| n1 = label.ndim | |||||
| assert n0 == n1 + 1, ( | |||||
| "target ndim must be one less than input ndim; input_ndim={} " | |||||
| "target_ndim={}".format(n0, n1) | |||||
| ) | |||||
| num_classes = pred.shape[axis] | |||||
| # Denominator of the softmax | |||||
| offset = pred.max(axis=axis).detach() | |||||
| pred = pred - offset | |||||
| down = exp(pred).sum(axis=axis) | |||||
| up = pred[np.arange(pred.shape[0]), label] | |||||
| if label_smooth != 0: | |||||
| factor = label_smooth / num_classes | |||||
| up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor | |||||
| return (log(down) - up).mean() | |||||
| def triplet_margin_loss( | |||||
| anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2 | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Creates a criterion that measures the triplet loss given an input tensors. | |||||
| .. math:: | |||||
| L(a, p, n) = max\left\{d\left(a_{i},p_{i}\right)-d\left(a_{i}, n_{i}\right)+margin, 0\right\},\ | |||||
| d\left(x_{i},y_{i}\right)=\left\|x_{i}-y_{i}\right\|_{p} | |||||
| :param anchor: The input tensor representing the anchor samples. | |||||
| :param positive: The input tensor representing the positive samples. | |||||
| :param negative: The input tensor representing the negative samples. | |||||
| :param margin: Default: 1.0 | |||||
| :param p: The norm degree for pairwise distance. Default: 2.0 | |||||
| """ | |||||
| s0 = anchor.shapeof() | |||||
| s1 = positive.shapeof() | |||||
| s2 = negative.shapeof() | |||||
| assert_equal(s0, s1) | |||||
| assert_equal(s1, s2) | |||||
| n0 = anchor.ndim | |||||
| n1 = positive.ndim | |||||
| n2 = negative.ndim | |||||
| assert n0 == 2 and n1 == 2 and n2 == 2, ( | |||||
| "anchor ndim, positive ndim, and negative ndim must be 2; " | |||||
| "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2) | |||||
| ) | |||||
| assert p > 0, "a margin with a value greater than 0; p={}".format(p) | |||||
| diff0 = abs(anchor - positive) | |||||
| diff1 = abs(anchor - negative) | |||||
| d1 = power(power(diff0, p).sum(axis=1, keepdims=True), 1 / p) | |||||
| d2 = power(power(diff1, p).sum(axis=1, keepdims=True), 1 / p) | |||||
| loss = maximum(d1 - d2 + margin, 0) | |||||
| return loss.mean() | |||||
| def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor: | |||||
| r"""Function that measures the Binary Cross Entropy between the target and the prediction. | |||||
| :param pred: (N,*) where * means, any number of additional dimensions. | |||||
| :param label: (N,*), same shape as the input. | |||||
| """ | |||||
| assert pred.shape == label.shape | |||||
| return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean() | |||||
| def nll_loss( | |||||
| pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1 | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| The negative log likelihood loss. | |||||
| :param pred: The predicted result from model. | |||||
| :param label: The ground truth to compare. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data_shape = (2, 2) | |||||
| label_shape = (2, ) | |||||
| data = tensor( | |||||
| np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape), | |||||
| ) | |||||
| label = tensor( | |||||
| np.ones(label_shape, dtype=np.int32) | |||||
| ) | |||||
| pred = F.log(F.softmax(data)) | |||||
| loss1 = F.nll_loss(pred, label) | |||||
| loss2 = F.cross_entropy_with_softmax(data, label) | |||||
| print(loss1.numpy(), loss2.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [0.6576154] [0.6576154] | |||||
| """ | |||||
| raise NotImplementedError | |||||
| # n0 = pred.ndim | |||||
| # n1 = label.ndim | |||||
| # assert n0 == n1 + 1, ( | |||||
| # "target ndim must be one less than input ndim; input_ndim={} " | |||||
| # "target_ndim={}".format(n0, n1) | |||||
| # ) | |||||
| # mask = 1.0 - equal(label, ignore_index) | |||||
| # label = label * mask | |||||
| # loss = indexing_one_hot(pred, label, axis) * mask | |||||
| # return -1.0 * loss.sum() / maximum(mask.sum(), 1.0) | |||||
| def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor: | |||||
| r""" | |||||
| Caculate the hinge loss which is often used in SVMs. | |||||
| The hinge loss can be described as: | |||||
| .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_i_j*y_i_j)) | |||||
| :param pred: The input tensor representing the predicted probability, shape is (N, C). | |||||
| :param label: The input tensor representing the binary classification label, shape is (N, C). | |||||
| :param norm: Specify the norm to caculate the loss, should be "L1" or "L2". | |||||
| Examples: | |||||
| .. testcode:: | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32") | |||||
| label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32") | |||||
| loss = F.hinge_loss(pred, label) | |||||
| print(loss.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [1.5] | |||||
| """ | |||||
| assert norm in ["L1", "L2"], "norm must be L1 or L2" | |||||
| # Converts binary labels to -1/1 labels. | |||||
| loss = relu(1.0 - pred * label) | |||||
| if norm == "L1": | |||||
| return loss.sum(axis=1).mean() | |||||
| else: | |||||
| return (loss ** 2).sum(axis=1).mean() | |||||
| def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||||
| r""" | |||||
| Caculate the smooth l1 loss proposed in `Fast R-CNN paper by Ross Girshick`. | |||||
| The smooth l1 loss can be described as: | |||||
| .. math:: | |||||
| \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i} | |||||
| where :math:`l_{i}` is given by: | |||||
| .. math:: | |||||
| l_{i} = | |||||
| \begin{cases} | |||||
| 0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\ | |||||
| |x_i - y_i| - 0.5, & \text{otherwise } | |||||
| \end{cases} | |||||
| :param pred: The predicted result from model. | |||||
| :param label: The ground truth to compare. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]]) | |||||
| label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]]) | |||||
| loss = F.smooth_l1_loss(pred, label) | |||||
| print(loss.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [0.5608334] | |||||
| """ | |||||
| raise NotImplementedError | |||||
| # diff = abs(pred - label) | |||||
| # l2_loss = 0.5 * (diff ** 2) | |||||
| # l1_loss = diff - 0.5 | |||||
| # mask = diff < 1 | |||||
| # loss = where(mask, l2_loss, l1_loss) | |||||
| # return loss.mean() | |||||
| @@ -0,0 +1,696 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import functools | |||||
| import math | |||||
| import numbers | |||||
| from typing import Optional, Sequence, Tuple, Union | |||||
| from ..core.ops import builtin | |||||
| from ..core.ops._internal import param_defs as P | |||||
| from ..core.tensor import utils | |||||
| from ..core.tensor.core import apply | |||||
| from ..tensor import Tensor | |||||
| from .elemwise import clamp, exp, log, log1p | |||||
| from .tensor import remove_axis, reshape | |||||
| __all__ = [ | |||||
| "all", # TODO | |||||
| "all_close", # TODO | |||||
| "any", # TODO | |||||
| "argmax", | |||||
| "argmin", | |||||
| "argsort", | |||||
| "isinf", | |||||
| "isnan", # TODO | |||||
| "max", | |||||
| "mean", | |||||
| "median", # TODO | |||||
| "min", | |||||
| "norm", | |||||
| "normalize", | |||||
| "prod", | |||||
| "sign", # TODO | |||||
| "sort", | |||||
| "std", | |||||
| "sum", | |||||
| "topk", | |||||
| "unique", # TODO | |||||
| "var", | |||||
| ] | |||||
| def all(inp): | |||||
| raise NotImplementedError | |||||
| def all_close(inp): | |||||
| raise NotImplementedError | |||||
| def any(inp): | |||||
| raise NotImplementedError | |||||
| def unique(inp): | |||||
| raise NotImplementedError | |||||
| def isnan(inp: Tensor) -> Tensor: | |||||
| r"""Returns a new tensor representing if each element is NaN or not. | |||||
| :param: inp | |||||
| :return: a new tensor representing if each element in :attr:`inp` is NaN or not. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor([1, float("nan"), 0]) | |||||
| print(F.isnan(x)) | |||||
| .. testoutput:: | |||||
| Tensor([0 1 0], dtype=uint8) | |||||
| """ | |||||
| raise NotImplementedError | |||||
| # return (inp != inp).astype("uint8") | |||||
| def isinf(inp: Tensor) -> Tensor: | |||||
| r"""Returns a new tensor representing if each element is Inf or not. | |||||
| :param: inp | |||||
| :return: a new tensor representing if each element in :attr:`inp` is Inf or not. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor([1, float("inf"), 0]) | |||||
| print(F.isinf(x)) | |||||
| .. testoutput:: | |||||
| Tensor([0 1 0], dtype=uint8) | |||||
| """ | |||||
| return (abs(inp).astype("float32") == float("inf")).astype("uint8") | |||||
| def sign(inp: Tensor): | |||||
| raise NotImplementedError | |||||
| def _reduce( | |||||
| data, | |||||
| *, | |||||
| mode, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False | |||||
| ): | |||||
| (data,) = utils.convert_inputs(data) | |||||
| if axis is None: | |||||
| data = data.reshape(-1) | |||||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||||
| op = builtin.Reduce(mode=mode, axis=0) | |||||
| (result,) = apply(op, data) | |||||
| elif isinstance(axis, collections.Iterable): | |||||
| axis = list(axis) | |||||
| axis.sort(reverse=True) | |||||
| for ai in axis: | |||||
| op = builtin.Reduce(mode=mode, axis=ai) | |||||
| (data,) = apply(op, data) | |||||
| if not keepdims: | |||||
| data = remove_axis(data, ai) | |||||
| result = data | |||||
| else: | |||||
| op = builtin.Reduce(mode=mode, axis=axis) | |||||
| (result,) = apply(op, data) | |||||
| if not keepdims: | |||||
| result = remove_axis(result, axis) | |||||
| return result | |||||
| def sum( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| r"""Returns the sum of each row of the ``inp`` tensor in the given ``axis``. | |||||
| :param inp: The input tensor. | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. | |||||
| Default: None | |||||
| :param keepdims: Whether the output tensor has ``axis`` retained or not. | |||||
| Default: False | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||||
| out = F.sum(data) | |||||
| print(out.numpy()) | |||||
| .. testoutput:: | |||||
| [21] | |||||
| """ | |||||
| return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims) | |||||
| def prod( | |||||
| inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Returns the element product of input tensor along given *axis*. | |||||
| :param inp: The input tensor | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None`` | |||||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False`` | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||||
| out = F.prod(data) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [720] | |||||
| """ | |||||
| return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims) | |||||
| def mean( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| """Returns the mean value of each row of the ``inp`` tensor in | |||||
| the given ``axis``. If axis is a list of dimensions, | |||||
| reduce over all of them. | |||||
| :param inp: The input tensor | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||||
| :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||||
| out = F.mean(data) | |||||
| print(out.numpy()) | |||||
| .. testoutput:: | |||||
| [3.5] | |||||
| """ | |||||
| return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims) | |||||
| def median( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| raise NotImplementedError | |||||
| def var( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| """Returns the variance value of input tensor along | |||||
| given ``axis``. If axis is a list of dimensions, | |||||
| reduce over all of them. | |||||
| :param inp: The input tensor. | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``. | |||||
| :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``. | |||||
| :return: The output tensor. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3)) | |||||
| out = F.var(data) | |||||
| print(out.numpy()) | |||||
| .. testoutput:: | |||||
| [2.9166667] | |||||
| """ | |||||
| if axis is None: | |||||
| m = mean(inp, axis=axis, keepdims=False) | |||||
| else: | |||||
| m = mean(inp, axis=axis, keepdims=True) | |||||
| v = inp - m | |||||
| return mean(v ** 2, axis=axis, keepdims=keepdims) | |||||
| def std( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| """Returns the standard deviation of input tensor along | |||||
| given ``axis``. If axis is a list of dimensions, | |||||
| reduce over all of them. | |||||
| :param inp: The input tensor. | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``. | |||||
| :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``. | |||||
| :return: The output tensor. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3)) | |||||
| out = F.std(data, axis=1) | |||||
| print(out.numpy()) | |||||
| .. testoutput:: | |||||
| [0.8164966 0.8164966] | |||||
| """ | |||||
| return var(inp, axis=axis, keepdims=keepdims) ** 0.5 | |||||
| def min( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Returns the min value of input tensor along given *axis*. | |||||
| :param inp: The input tensor | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||||
| y = F.min(x) | |||||
| print(y.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [1] | |||||
| """ | |||||
| return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims) | |||||
| def max( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| r"""Returns the max value of the input tensor along given *axis*. | |||||
| :param inp: The input tensor | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||||
| y = F.max(x) | |||||
| print(y.numpy()) | |||||
| .. testoutput:: | |||||
| [6] | |||||
| """ | |||||
| return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims) | |||||
| def norm( | |||||
| inp: Tensor, | |||||
| p: int = 2, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims=False, | |||||
| ): | |||||
| """Calculate ``p``-norm of input tensor along certain axis. | |||||
| :param inp: The input tensor | |||||
| :param p: power of value ``p`` applied to ``inp``. Default: 2 | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||||
| :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2,3)) | |||||
| y = F.norm(x) | |||||
| print(y.numpy()) | |||||
| .. testoutput:: | |||||
| [4.358899] | |||||
| """ | |||||
| if p == 0: | |||||
| return sum(inp != 0, axis=axis, keepdims=keepdims) | |||||
| if p == math.inf: | |||||
| return max(abs(inp)) | |||||
| if p == -math.inf: | |||||
| return min(abs(inp)) | |||||
| return sum(abs(inp) ** p, axis=axis, keepdims=keepdims) ** (1.0 / p) | |||||
| def argmin( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| r"""Returns the indices of the minimum values along an axis | |||||
| :param inp: The input tensor | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||||
| y = F.argmin(x) | |||||
| print(y.numpy()) | |||||
| .. testoutput:: | |||||
| [0] | |||||
| """ | |||||
| if isinstance(axis, collections.Iterable): | |||||
| axis = list(axis) | |||||
| axis.sort(reverse=True) | |||||
| for ai in axis: | |||||
| op = builtin.Argmin(axis=ai) | |||||
| (inp,) = apply(op, inp) | |||||
| if not keepdims: | |||||
| inp = remove_axis(inp, ai) | |||||
| return inp | |||||
| if axis is None: | |||||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||||
| inp = inp.flatten() | |||||
| axis = 0 | |||||
| op = builtin.Argmin(axis=axis) | |||||
| (result,) = apply(op, inp) | |||||
| if not keepdims: | |||||
| result = remove_axis(result, axis) | |||||
| return result | |||||
| def argmax( | |||||
| inp: Tensor, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| keepdims: bool = False, | |||||
| ) -> Tensor: | |||||
| r"""Returns the indices of the maximum values along an axis | |||||
| :param inp: The input tensor | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||||
| y = F.argmax(x) | |||||
| print(y.numpy()) | |||||
| .. testoutput:: | |||||
| [5] | |||||
| """ | |||||
| if isinstance(axis, collections.Iterable): | |||||
| axis = list(axis) | |||||
| axis.sort(reverse=True) | |||||
| for ai in axis: | |||||
| op = builtin.Argmax(axis=ai) | |||||
| (inp,) = apply(op, inp) | |||||
| if not keepdims: | |||||
| inp = remove_axis(inp, ai) | |||||
| return inp | |||||
| if axis is None: | |||||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||||
| inp = inp.flatten() | |||||
| axis = 0 | |||||
| op = builtin.Argmax(axis=axis) | |||||
| (result,) = apply(op, inp) | |||||
| if not keepdims: | |||||
| result = remove_axis(result, axis) | |||||
| return result | |||||
| def normalize( | |||||
| inp: Tensor, | |||||
| p: int = 2, | |||||
| axis: Optional[Union[int, Sequence[int]]] = None, | |||||
| eps: float = 1e-12, | |||||
| ) -> Tensor: | |||||
| r"""Perform :math:`L_p` normalization of input tensor along certain axis. | |||||
| For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each | |||||
| :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as: | |||||
| .. math:: | |||||
| v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. | |||||
| :param inp: the input tensor | |||||
| :param p: power of value ``p`` applied to ``inp``. Default: 2 | |||||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced | |||||
| to calculate the norm. Default: None | |||||
| :param eps: a small value to avoid division by zero. Default: 1e-12 | |||||
| :return: the normalized output tensor | |||||
| """ | |||||
| if axis is None: | |||||
| return inp / clamp(norm(inp, p, axis), lower=eps) | |||||
| else: | |||||
| return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps) | |||||
| def argsort(inp: Tensor, descending: bool = False) -> Tensor: | |||||
| r""" | |||||
| Sort the target 2d matrix by row, return both the sorted tensor and indices. | |||||
| :param inp: The input tensor, if 2d, each row will be sorted | |||||
| :param descending: Sort in descending order, where the largest comes first. Default: ``False`` | |||||
| :return: Tuple of two tensors (sorted_tensor, indices_of_int32) | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data = tensor(np.array([1,2], dtype=np.float32)) | |||||
| indices = F.argsort(data) | |||||
| print(indices.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [0 1] | |||||
| """ | |||||
| assert len(inp.shape) <= 2, "Input should be 1d or 2d" | |||||
| if descending: | |||||
| order = P.Argsort.Order.DESCENDING | |||||
| else: | |||||
| order = P.Argsort.Order.ASCENDING | |||||
| op = builtin.Argsort(order=order) | |||||
| if len(inp.shape) == 1: | |||||
| inp = inp.reshape(1, -1) | |||||
| _, result = apply(op, inp) | |||||
| return result[0] | |||||
| _, result = apply(op, inp) | |||||
| return result | |||||
| def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: | |||||
| assert len(inp.shape) <= 2, "Input should be 1d or 2d" | |||||
| if descending: | |||||
| order = P.Argsort.Order.DESCENDING | |||||
| else: | |||||
| order = P.Argsort.Order.ASCENDING | |||||
| op = builtin.Argsort(order=order) | |||||
| if len(inp.shape) == 1: | |||||
| inp = inp.reshape(1, -1) | |||||
| tns, ind = apply(op, inp) | |||||
| return tns[0], ind[0] | |||||
| tns, ind = apply(op, inp) | |||||
| return tns, ind | |||||
| def topk( | |||||
| inp: Tensor, | |||||
| k: int, | |||||
| descending: bool = False, | |||||
| kth_only: bool = False, | |||||
| no_sort: bool = False, | |||||
| ) -> Tuple[Tensor, Tensor]: | |||||
| r""" | |||||
| Selected the Top-K (by default) smallest elements of 2d matrix by row. | |||||
| :param inp: The input tensor, if 2d, each row will be sorted | |||||
| :param k: The number of elements needed | |||||
| :param descending: If true, return the largest elements instead. Default: ``False`` | |||||
| :param kth_only: If true, only the k-th element will be returned. Default: ``False`` | |||||
| :param no_sort: If true, the returned elements can be unordered. Default: ``False`` | |||||
| :return: Tuple of two tensors (topk_tensor, indices_of_int32) | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | |||||
| top, indices = F.topk(data, 5) | |||||
| print(top.numpy(), indices.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [1. 2. 3. 4. 5.] [7 0 6 1 5] | |||||
| """ | |||||
| if descending: | |||||
| inp = -inp | |||||
| Mode = P.TopK.Mode | |||||
| if kth_only: | |||||
| mode = Mode.KTH_ONLY | |||||
| elif no_sort: | |||||
| mode = Mode.VALUE_IDX_NOSORT | |||||
| else: | |||||
| mode = Mode.VALUE_IDX_SORTED | |||||
| op = builtin.TopK(mode=mode) | |||||
| if len(inp.shape) == 1: | |||||
| inp = inp.reshape(1, -1) | |||||
| res = apply(op, inp, Tensor(k, dtype="int32")) | |||||
| if kth_only: | |||||
| tns = res[0] | |||||
| else: | |||||
| tns, ind = res[0][0], res[1][0] | |||||
| else: | |||||
| res = apply(op, inp, Tensor(k, dtype="int32")) | |||||
| if kth_only: | |||||
| tns = res | |||||
| else: | |||||
| tns, ind = res[0], res[1] | |||||
| if descending: | |||||
| tns = -tns | |||||
| return tns, ind | |||||
| @@ -0,0 +1,83 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # pylint: disable=too-many-lines | |||||
| from typing import Tuple, Union | |||||
| from ..core.ops import builtin | |||||
| from ..core.tensor.core import apply | |||||
| from ..tensor import Tensor | |||||
| from .debug_param import get_conv_execution_strategy | |||||
| from .types import _pair, _pair_nonzero | |||||
| def conv_bias_activation( | |||||
| inp: Tensor, | |||||
| weight: Tensor, | |||||
| bias: Tensor, | |||||
| dtype=None, | |||||
| stride: Union[int, Tuple[int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | |||||
| groups: int = 1, | |||||
| format="NCHW", | |||||
| nonlinear_mode="IDENTITY", | |||||
| conv_mode="CROSS_CORRELATION", | |||||
| compute_mode="DEFAULT", | |||||
| ) -> Tensor: | |||||
| """ convolution bias with activation operation, only for inference. | |||||
| :param inp: The feature map of the convolution operation | |||||
| :param weight: The convolution kernel | |||||
| :param bias: The bias added to the result of convolution | |||||
| :param stride: Stride of the 2D convolution operation. Default: 1 | |||||
| :param padding: Size of the paddings added to the input on both sides of its | |||||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
| :param dilation: Dilation of the 2D convolution operation. Default: 1 | |||||
| :param groups: number of groups to divide input and output channels into, | |||||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||||
| and the shape of weight should be ``(groups, out_channel // groups, | |||||
| in_channels // groups, height, width)``. | |||||
| :type conv_mode: string or :class:`P.Convolution.Mode` | |||||
| :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||||
| 'CROSS_CORRELATION'. | |||||
| :param dtype: Support for np.dtype, Default: | |||||
| np.int8. | |||||
| :param scale: scale if use quantization, Default: | |||||
| 0.0. | |||||
| :param zero_point: scale if use quantization quint8, Default: | |||||
| 0.0. | |||||
| :type compute_mode: string or | |||||
| :class:`P.Convolution.ComputeMode` | |||||
| :param compute_mode: When set to 'DEFAULT', no special requirements will be | |||||
| placed on the precision of intermediate results. When set to 'FLOAT32', | |||||
| Float32 would be used for accumulator and intermediate result, but only | |||||
| effective when input and output are of Float16 dtype. | |||||
| """ | |||||
| ph, pw = _pair(padding) | |||||
| sh, sw = _pair_nonzero(stride) | |||||
| dh, dw = _pair_nonzero(dilation) | |||||
| sparse_type = "DENSE" if groups == 1 else "GROUP" | |||||
| op = builtin.ConvBiasForward( | |||||
| stride_h=sh, | |||||
| stride_w=sw, | |||||
| pad_h=ph, | |||||
| pad_w=pw, | |||||
| dilate_h=dh, | |||||
| dilate_w=dw, | |||||
| dtype=dtype, | |||||
| format=format, | |||||
| strategy=get_conv_execution_strategy(), | |||||
| nonlineMode=nonlinear_mode, | |||||
| mode=conv_mode, | |||||
| compute_mode=compute_mode, | |||||
| sparse=sparse_type, | |||||
| ) | |||||
| (outputs,) = apply(op, inp, weight, bias) | |||||
| return outputs | |||||
| @@ -0,0 +1,934 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| import math | |||||
| from itertools import accumulate | |||||
| from typing import Iterable, List, Optional, Sequence, Tuple, Union | |||||
| import numpy as np | |||||
| from ..core._imperative_rt import CompNode | |||||
| from ..core.ops import builtin | |||||
| from ..core.ops._internal import param_defs as P | |||||
| from ..core.ops.special import Const | |||||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||||
| from ..core.tensor.utils import ( | |||||
| astensor1d, | |||||
| convert_inputs, | |||||
| convert_single_value, | |||||
| dtype_promotion, | |||||
| get_device, | |||||
| ) | |||||
| from ..device import get_default_device | |||||
| from ..tensor import Tensor | |||||
| from .elemwise import ceil | |||||
| __all__ = [ | |||||
| "add_axis", # expand_dims | |||||
| "arange", | |||||
| "broadcast", | |||||
| "concat", | |||||
| "cond_take", | |||||
| "dimshuffle", # transpose, permute | |||||
| "expand_dims", | |||||
| "full", | |||||
| "full_like", | |||||
| "gather", | |||||
| "eye", | |||||
| "linspace", | |||||
| "ones", | |||||
| "ones_like", | |||||
| "remove_axis", # squeeze | |||||
| "split", | |||||
| "squeeze", | |||||
| "stack", | |||||
| "reshape", | |||||
| "scatter", | |||||
| "where", | |||||
| "zeros", | |||||
| "zeros_like", | |||||
| "param_pack_split", | |||||
| "param_pack_concat", | |||||
| ] | |||||
| def eye(n: int, *, dtype=None, device: Optional[CompNode] = None) -> Tensor: | |||||
| """ | |||||
| Returns a 2D tensor with ones on the diagonal and zeros elsewhere. | |||||
| :param n: The number of rows | |||||
| :param m: The number of columns. Default: None | |||||
| :param dtype: The data type. Default: None | |||||
| :param device: Compute node of the matrix. Default: None | |||||
| :param comp_graph: Compute graph of the matrix. Default: None | |||||
| :return: The eye matrix | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine.functional as F | |||||
| data_shape = (4, 6) | |||||
| n, m = data_shape | |||||
| out = F.eye(n, m, dtype=np.float32) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[1. 0. 0. 0. 0. 0.] | |||||
| [0. 1. 0. 0. 0. 0.] | |||||
| [0. 0. 1. 0. 0. 0.] | |||||
| [0. 0. 0. 1. 0. 0.]] | |||||
| """ | |||||
| op = builtin.Eye(k=0, dtype=dtype, comp_node=device) | |||||
| (result,) = apply(op, Tensor(n, dtype="int32", device=device)) | |||||
| return result | |||||
| def full(shape, value, dtype="float32", device=None): | |||||
| if device is None: | |||||
| device = get_default_device() | |||||
| (x,) = Const(value, dtype=dtype, device=device)( | |||||
| Tensor(value, dtype=dtype, device=device) | |||||
| ) | |||||
| return broadcast(x, shape) | |||||
| def ones(shape, dtype="float32", device=None): | |||||
| return full(shape, 1.0, dtype=dtype, device=device) | |||||
| def zeros(shape, dtype="float32", device=None): | |||||
| return full(shape, 0.0, dtype=dtype, device=device) | |||||
| def zeros_like(inp: Tensor) -> Tensor: | |||||
| r""" | |||||
| Returns a zero tensor with the same shape as input tensor | |||||
| :param inp: input tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||||
| out = F.zeros_like(inp) | |||||
| print(out.numpy()) | |||||
| .. testoutput:: | |||||
| [[0 0 0] | |||||
| [0 0 0]] | |||||
| """ | |||||
| return zeros(inp.shape, dtype=inp.dtype, device=inp.device) | |||||
| def ones_like(inp: Tensor) -> Tensor: | |||||
| r""" | |||||
| Returns a identity tensor with the same shape as input tensor | |||||
| """ | |||||
| return ones(inp.shape, dtype=inp.dtype, device=inp.device) | |||||
| def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||||
| r""" | |||||
| Returns a tensor filled with value val with the same shape as input tensor | |||||
| """ | |||||
| return full(inp.shape, value, dtype=inp.dtype, device=inp.device) | |||||
| def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | |||||
| """ | |||||
| Broadcast a tensor to ``shape`` | |||||
| :param inp: The input tensor | |||||
| :param shape: The target shape | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||||
| out = F.broadcast(data, (4, 2, 3)) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[0. 1. 2.] | |||||
| [3. 4. 5.]] | |||||
| [[0. 1. 2.] | |||||
| [3. 4. 5.]] | |||||
| [[0. 1. 2.] | |||||
| [3. 4. 5.]] | |||||
| [[0. 1. 2.] | |||||
| [3. 4. 5.]]] | |||||
| """ | |||||
| shape = astensor1d(shape, inp, dtype="int32", device=inp.device) | |||||
| (result,) = apply(builtin.Broadcast(), inp, shape) | |||||
| return result | |||||
| def concat( | |||||
| inps: Iterable[Tensor], axis: int = 0, device: Optional[CompNode] = None, | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Concat some tensors | |||||
| :param inps: Input tensors to concat | |||||
| :param axis: the dimension over which the tensors are concatenated. Default: 0 | |||||
| :param device: The comp node output on. Default: None | |||||
| :param comp_graph: The graph in which output is. Default: None | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3))) | |||||
| data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3))) | |||||
| out = F.concat([data1, data2]) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[ 0. 1. 2.] | |||||
| [ 3. 4. 5.] | |||||
| [ 6. 7. 8.] | |||||
| [ 9. 10. 11.]] | |||||
| """ | |||||
| dtype = dtype_promotion(inps) | |||||
| device = get_device(inps) | |||||
| def convert(x): | |||||
| return convert_single_value(x, inps, dtype=dtype) | |||||
| inps = tuple(map(convert, inps)) | |||||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | |||||
| return result | |||||
| def stack(inps, axis=0): | |||||
| """Concats a sequence of tensors along a new axis. | |||||
| The input tensors must have the same shape. | |||||
| :param inps: The input tensors. | |||||
| :param axis: Which axis will be concatenated. | |||||
| :return: The output concatenated tensor. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3))) | |||||
| x2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3))) | |||||
| out = F.stack([x1, x2], axis=0) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[ 0. 1. 2.] | |||||
| [ 3. 4. 5.]] | |||||
| [[ 6. 7. 8.] | |||||
| [ 9. 10. 11.]]] | |||||
| """ | |||||
| shapes = {arr.shape for arr in inps} | |||||
| if len(shapes) != 1: | |||||
| raise ValueError("All input tensors must have the same shape") | |||||
| inps = [add_axis(inp, axis=axis) for inp in inps] | |||||
| return concat(inps, axis=axis) | |||||
| def split(inp, nsplits_or_sections, axis=0): | |||||
| """Splits the input tensor into several smaller tensors. | |||||
| When nsplits_or_sections is int, the last tensor may be smaller than others. | |||||
| :param inp: The input tensor. | |||||
| :param nsplits_or_sections: Number of sub tensors or section information list. | |||||
| :param axis: Which axis will be splited. | |||||
| :return: The output tensor list. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.random.random((2,3,4,5)), dtype=np.float32) | |||||
| out = F.split(x, 2, axis=3) | |||||
| print(out[0].shape, out[1].shape) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| (2, 3, 4, 3) (2, 3, 4, 2) | |||||
| """ | |||||
| sub_tensors = [] | |||||
| sections = [] | |||||
| def swapaxis(inp, src, dst): | |||||
| if src == dst: | |||||
| return inp | |||||
| shape = [i for i in range(len(inp.shape))] | |||||
| shape[src] = dst | |||||
| shape[dst] = src | |||||
| return inp.transpose(shape) | |||||
| inp = swapaxis(inp, 0, axis) | |||||
| if isinstance(nsplits_or_sections, int): | |||||
| incr_step = math.ceil(inp.shape[0] / nsplits_or_sections) | |||||
| while incr_step < inp.shape[0]: | |||||
| sections.append(incr_step) | |||||
| incr_step += nsplits_or_sections | |||||
| else: | |||||
| sections = nsplits_or_sections | |||||
| st = 0 | |||||
| for se in sections: | |||||
| sub_tensors.append(swapaxis(inp[st:se], axis, 0)) | |||||
| st = se | |||||
| if st < inp.shape[0]: | |||||
| sub_tensors.append(swapaxis(inp[st:], axis, 0)) | |||||
| return sub_tensors | |||||
| def _get_idx(index, axis): | |||||
| index_dims = len(index.shape) | |||||
| idx = [] | |||||
| for i in range(index_dims): | |||||
| if i != axis: | |||||
| shape = [1] * index_dims | |||||
| shape[i] = index.shape[i] | |||||
| arange = linspace( | |||||
| 0, index.shape[i] - 1, index.shape[i], device=index.device, | |||||
| ) | |||||
| arange = ( | |||||
| arange.reshape(*shape) | |||||
| .broadcast(index.shape) | |||||
| .reshape(-1) | |||||
| .astype(np.int32) | |||||
| ) | |||||
| idx.append(arange) | |||||
| else: | |||||
| idx.append(index.reshape(-1)) | |||||
| return tuple(idx) | |||||
| def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: | |||||
| r""" | |||||
| Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`. | |||||
| For a 3-D tensor, the output is specified by:: | |||||
| out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0 | |||||
| out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1 | |||||
| out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2 | |||||
| if :attr:`inp` is an n-dimensional tensor with size | |||||
| :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i, | |||||
| then :attr:`index` must be an n-dimensional tensor with size | |||||
| :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and | |||||
| output will have the same size as :attr:`index`. | |||||
| :param inp: the source tensor | |||||
| :param axis: the axis along which to index | |||||
| :param index: the indices of elements to gather | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import megengine.functional as F | |||||
| from megengine import tensor | |||||
| inp = tensor([ | |||||
| [1,2], [3,4], [5,6], | |||||
| ]) | |||||
| index = tensor([[0,2], [1,0]]) | |||||
| oup = F.gather(inp, 0, index) | |||||
| print(oup.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[1 6] | |||||
| [3 2]] | |||||
| """ | |||||
| input_shape = inp.shape | |||||
| index_shape = index.shape | |||||
| input_dims = len(input_shape) | |||||
| index_dims = len(index_shape) | |||||
| if input_dims != index_dims: | |||||
| raise ValueError( | |||||
| "The index tensor must have same dimensions as input tensor, " | |||||
| "But the input dims:{}, the index dims:{}".format(input_dims, index_dims) | |||||
| ) | |||||
| if axis < 0 or axis >= input_dims: | |||||
| raise ValueError( | |||||
| "Index axis {} is output of bounds, should in range [0 {})".format( | |||||
| axis, input_dims | |||||
| ) | |||||
| ) | |||||
| for i in range(input_dims): | |||||
| if i != axis and input_shape[i] != index_shape[i]: | |||||
| raise ValueError( | |||||
| "The input {} and index {} must have the same size apart from axis {}".format( | |||||
| input_shape, index_shape, axis | |||||
| ) | |||||
| ) | |||||
| idx = _get_idx(index, axis) | |||||
| return inp[idx].reshape(index.shape) # pylint: disable=no-member | |||||
| def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: | |||||
| r""" | |||||
| Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor. | |||||
| For each value in :attr:`source`, its output index is specified by its index | |||||
| in :attr:`source` for ``axis != dimension`` and by the corresponding value in | |||||
| :attr:`index` for ``axis = dimension``. | |||||
| For a 3-D tensor, :attr:`inp` is updated as:: | |||||
| inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0 | |||||
| inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1 | |||||
| inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2 | |||||
| :attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions. | |||||
| It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)`` | |||||
| for all dimensions ``d``. | |||||
| Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive. | |||||
| .. note:: | |||||
| Please notice that, due to performance issues, the result is uncertain on the GPU device | |||||
| if scatter difference positions from source to the same destination position | |||||
| regard to index tensor. | |||||
| Show the case using the following examples, the oup[0][2] is maybe | |||||
| from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5339 | |||||
| if set the index[1][2] from 1 to 0. | |||||
| :param inp: the inp tensor which to be scattered | |||||
| :param axis: the axis along which to index | |||||
| :param index: the indices of elements to scatter | |||||
| :param source: the source element(s) to scatter | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine.functional as F | |||||
| from megengine import tensor | |||||
| inp = tensor(np.zeros(shape=(3,5),dtype=np.float32)) | |||||
| source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]]) | |||||
| index = tensor([[0,2,0,2,1],[2,0,1,1,2]]) | |||||
| oup = F.scatter(inp, 0, index,source) | |||||
| print(oup.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[0.9935 0.0718 0.2256 0. 0. ] | |||||
| [0. 0. 0.5939 0.357 0.4396] | |||||
| [0.7723 0.9465 0. 0.8926 0.4576]] | |||||
| """ | |||||
| input_shape = inp.shape | |||||
| index_shape = index.shape | |||||
| source_shape = source.shape | |||||
| input_dims = len(input_shape) | |||||
| index_dims = len(index_shape) | |||||
| source_dims = len(source_shape) | |||||
| if input_dims != index_dims or input_dims != source_dims: | |||||
| raise ValueError("The input, source and index tensor must have same dimensions") | |||||
| if axis < 0 or axis >= input_dims: | |||||
| raise ValueError( | |||||
| "Index axis {} is output of bounds, should in range [0 {})".format( | |||||
| axis, input_dims | |||||
| ) | |||||
| ) | |||||
| for i in range(source_dims): | |||||
| if source_shape[i] > input_shape[i]: | |||||
| raise ValueError( | |||||
| "The each shape size for source {} must be less than or equal to input {} ".format( | |||||
| source_shape, input_shape | |||||
| ) | |||||
| ) | |||||
| for i in range(index_dims): | |||||
| if index_shape[i] != source_shape[i]: | |||||
| raise ValueError( | |||||
| "The each shape size for index {} must be equal to source {} ".format( | |||||
| index_shape, source_shape | |||||
| ) | |||||
| ) | |||||
| for i in range(index_dims): | |||||
| if i != axis and index_shape[i] > input_shape[i]: | |||||
| raise ValueError( | |||||
| "The index {} must be less than or equal to input {} size apart from axis {}".format( | |||||
| index_shape, input_shape, axis | |||||
| ) | |||||
| ) | |||||
| idx = _get_idx(index, axis) | |||||
| inp[idx] = source.flatten() | |||||
| return inp | |||||
| def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: | |||||
| r""" | |||||
| Select elements either from Tensor x or Tensor y, according to mask. | |||||
| .. math:: | |||||
| \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i | |||||
| :param mask: a mask used for choosing x or y | |||||
| :param x: the first choice | |||||
| :param y: the second choice | |||||
| Examples: | |||||
| .. testcode:: | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32)) | |||||
| x = tensor(np.array([[1, np.inf], [np.nan, 4]], | |||||
| dtype=np.float32)) | |||||
| y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)) | |||||
| out = F.where(mask, x, y) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[1. 6.] | |||||
| [7. 4.]] | |||||
| """ | |||||
| raise NotImplementedError | |||||
| # v0, index0 = mgb.opr.cond_take( | |||||
| # x, mask, mode=P.CondTake.Mode.EQ, val=1 | |||||
| # ) | |||||
| # v1, index1 = mgb.opr.cond_take( | |||||
| # y, mask, mode=P.CondTake.Mode.EQ, val=0 | |||||
| # ) | |||||
| # out = x.flatten() | |||||
| # index = mgb.opr.concat(index0, index1, axis=0) | |||||
| # v = mgb.opr.concat(v0, v1, axis=0) | |||||
| # out = mgb.opr.set_advanced_indexing(out, v)[index] | |||||
| # out = out.reshape(x.shape) | |||||
| # return out | |||||
| def cond_take(mask: Tensor, x: Tensor) -> Tensor: | |||||
| r""" | |||||
| Take elements from data if specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; they are both 1-dimensional. High-dimension input would first be flattened. | |||||
| :param mask: condition param; must be the same shape with data | |||||
| :param x: input tensor from which to take elements | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||||
| x = tensor(np.array([[1, np.inf], [np.nan, 4]], | |||||
| dtype=np.float32)) | |||||
| v, index = F.cond_take(mask, x) | |||||
| print(v.numpy(), index.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| Tensor([1. 4.]) Tensor([0 3], dtype=int32) | |||||
| """ | |||||
| if not isinstance(x, (TensorWrapperBase, TensorBase)): | |||||
| raise TypeError("input must be a tensor") | |||||
| if not isinstance(mask, (TensorWrapperBase, TensorBase)): | |||||
| raise TypeError("mask must be a tensor") | |||||
| if mask.dtype != np.bool_: | |||||
| raise ValueError("mask must be bool") | |||||
| if x.device != mask.device: | |||||
| raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device)) | |||||
| op = builtin.CondTake() | |||||
| v, index = apply(op, x, mask) | |||||
| return v, index | |||||
| def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||||
| r""" | |||||
| Swap shapes and strides according to given pattern | |||||
| :param inp: Input tensor | |||||
| :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: | |||||
| * (``'x'``) -> make a 0d (scalar) into a 1d vector | |||||
| * (0, 1) -> identity for 2d vectors | |||||
| * (1, 0) -> inverts the first and second dimensions | |||||
| * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN) | |||||
| * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1) | |||||
| * (2, 0, 1) -> AxBxC to CxAxB | |||||
| * (0, ``'x'``, 1) -> AxB to Ax1xB | |||||
| * (1, ``'x'``, 0) -> AxB to Bx1xA | |||||
| * (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A) | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32)) | |||||
| out = F.dimshuffle(x, (1, 0)) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[1 0] | |||||
| [1 0]] | |||||
| """ | |||||
| op = builtin.Dimshuffle(pattern) | |||||
| (inp,) = convert_inputs(inp) | |||||
| (result,) = apply(op, inp) | |||||
| return result | |||||
| def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | |||||
| r""" | |||||
| Reshape a tensor to given target shape; total number of logical elements must | |||||
| remain unchanged | |||||
| :param inp: Input tensor | |||||
| :param target_shape: target shape, the components would be concatenated to form the | |||||
| target shape, and it can contain an element of -1 representing unspec_axis. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.arange(12, dtype=np.int32)) | |||||
| out = F.reshape(x, (3, 2, 2)) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[ 0 1] | |||||
| [ 2 3]] | |||||
| [[ 4 5] | |||||
| [ 6 7]] | |||||
| [[ 8 9] | |||||
| [10 11]]] | |||||
| """ | |||||
| if isinstance(target_shape, (TensorBase, TensorWrapperBase)): | |||||
| target_shape = target_shape.numpy() | |||||
| target_shape = tuple(map(int, target_shape)) | |||||
| unspec_axis = None | |||||
| for i, s in enumerate(target_shape): | |||||
| if s < 0: | |||||
| if s != -1: | |||||
| raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||||
| if unspec_axis is not None: | |||||
| raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | |||||
| unspec_axis = i | |||||
| # TODO: device should be None (cpu) | |||||
| (target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp) | |||||
| if unspec_axis is None: | |||||
| op = builtin.Reshape() | |||||
| else: | |||||
| op = builtin.Reshape(unspec_axis=unspec_axis) | |||||
| (x,) = apply(op, inp, target_shape) | |||||
| return x | |||||
| transpose = dimshuffle | |||||
| AxisAddRemove = builtin.AxisAddRemove | |||||
| AxisDesc = AxisAddRemove.AxisDesc | |||||
| def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
| r""" | |||||
| Add dimension before given axis. | |||||
| :param inp: Input tensor | |||||
| :param axis: Place of new axes | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor([1, 2]) | |||||
| out = F.add_axis(x, 0) | |||||
| print(out.shape) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| (1, 2) | |||||
| """ | |||||
| Param = AxisAddRemove.Param | |||||
| def get_axes(): | |||||
| try: | |||||
| return [int(axis)] | |||||
| except (TypeError, ValueError): | |||||
| pass | |||||
| return list(map(int, axis)) | |||||
| axis = get_axes() | |||||
| ndim = inp.ndim + len(axis) | |||||
| axis = sorted(i + ndim if i < 0 else i for i in axis) | |||||
| param = Param(*map(AxisDesc.make_add, axis)) | |||||
| op = AxisAddRemove(param=param) | |||||
| (result,) = apply(op, inp) | |||||
| return result | |||||
| expand_dims = add_axis | |||||
| def remove_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
| r""" | |||||
| Remove dimension of shape 1. | |||||
| :param inp: Input tensor | |||||
| :param axis: Place of axis to be removed | |||||
| :return: The output tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) | |||||
| out = F.remove_axis(x, 3) | |||||
| print(out.shape) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| (1, 1, 2) | |||||
| """ | |||||
| Param = AxisAddRemove.Param | |||||
| def get_axes(): | |||||
| if axis is None: | |||||
| return [i for i, s in enumerate(inp.shape) if s == 1] | |||||
| try: | |||||
| return [int(axis)] | |||||
| except (TypeError, ValueError): | |||||
| pass | |||||
| return list(map(int, axis)) | |||||
| axis = get_axes() | |||||
| axis = sorted(i + inp.ndim if i < 0 else i for i in axis) | |||||
| axis = [a - i for i, a in enumerate(axis)] | |||||
| param = Param(*map(AxisDesc.make_remove, axis)) | |||||
| op = AxisAddRemove(param=param) | |||||
| (result,) = apply(op, inp) | |||||
| return result | |||||
| squeeze = remove_axis | |||||
| def linspace( | |||||
| start: Union[int, float, Tensor], | |||||
| stop: Union[int, float, Tensor], | |||||
| num: Union[int, Tensor], | |||||
| dtype="float32", | |||||
| device: Optional[CompNode] = None, | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Return equally spaced numbers over a specified interval | |||||
| :param start: Starting value of the squence, shoule be scalar | |||||
| :param stop: The last value of the squence, shoule be scalar | |||||
| :param num: number of values to generate | |||||
| :param dtype: result data type | |||||
| :return: The generated tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine.functional as F | |||||
| a = F.linspace(3,10,5) | |||||
| print(a.numpy()) | |||||
| .. testoutput:: | |||||
| [ 3. 4.75 6.5 8.25 10. ] | |||||
| """ | |||||
| start = Tensor(start, device=device) | |||||
| stop = Tensor(stop, device=device) | |||||
| num = Tensor(num, device=device) | |||||
| device = device if device is None else device.to_c() | |||||
| op = builtin.Linspace(comp_node=device) | |||||
| (result,) = apply(op, start, stop, num) | |||||
| if np.dtype(dtype) == np.int32: | |||||
| return result.astype(dtype) | |||||
| return result | |||||
| def arange( | |||||
| start: Union[int, float, Tensor], | |||||
| end: Union[int, float, Tensor], | |||||
| step: Union[int, float, Tensor] = 1, | |||||
| dtype="float32", | |||||
| device: Optional[CompNode] = None, | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Returns a Tensor with values from `start` to `end` with adjacent interval `step` | |||||
| :param start: starting value of the squence, shoule be scalar | |||||
| :param end: ending value of the squence, shoule be scalar | |||||
| :param step: the gap between each pair of adjacent values. Default 1 | |||||
| :param dtype: result data type | |||||
| :return: The generated tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine.functional as F | |||||
| a = F.arange(1, 5, 1) | |||||
| print(a.numpy()) | |||||
| .. testoutput:: | |||||
| [1. 2. 3. 4.] | |||||
| """ | |||||
| if isinstance(start, Tensor): | |||||
| start = start.astype("float32") | |||||
| if isinstance(end, Tensor): | |||||
| end = end.astype("float32") | |||||
| if isinstance(step, Tensor): | |||||
| step = step.astype("float32") | |||||
| num = ceil(Tensor((end - start) / step, device=device)) | |||||
| stop = start + step * (num - 1) | |||||
| result = linspace(start, stop, num, device=device) | |||||
| if np.dtype(dtype) == np.int32: | |||||
| return result.astype(dtype) | |||||
| return result | |||||
| def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor: | |||||
| op = builtin.ParamPackSplit() | |||||
| op.offsets = offsets | |||||
| op.shapes = shapes | |||||
| return apply(op, inp) | |||||
| def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor: | |||||
| op = builtin.ParamPackConcat() | |||||
| op.offsets = offsets_val | |||||
| return apply(op, *inps, offsets)[0] | |||||
| @@ -0,0 +1,37 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import functools | |||||
| def get_ndtuple(value, *, n, allow_zero=True): | |||||
| r"""Converts possibly 1D tuple to nd tuple | |||||
| :type allow_zero: bool | |||||
| :param allow_zero: whether to allow zero tuple value""" | |||||
| if not isinstance(value, collections.Iterable): | |||||
| value = int(value) | |||||
| value = tuple([value for i in range(n)]) | |||||
| else: | |||||
| assert len(value) == n, "tuple len is not equal to n: {}".format(value) | |||||
| spatial_axis = map(int, value) | |||||
| value = tuple(spatial_axis) | |||||
| if allow_zero: | |||||
| minv = 0 | |||||
| else: | |||||
| minv = 1 | |||||
| assert min(value) >= minv, "invalid value: {}".format(value) | |||||
| return value | |||||
| _single = functools.partial(get_ndtuple, n=1, allow_zero=True) | |||||
| _pair = functools.partial(get_ndtuple, n=2, allow_zero=True) | |||||
| _pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False) | |||||
| _triple = functools.partial(get_ndtuple, n=3, allow_zero=True) | |||||
| _quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True) | |||||
| @@ -0,0 +1,80 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| from typing import Iterable, Union | |||||
| import numpy as np | |||||
| from ..core.ops.builtin import Copy | |||||
| from ..core.tensor import Tensor | |||||
| from ..core.tensor.core import apply | |||||
| from .math import topk as _topk | |||||
| from .tensor import dimshuffle as _dimshuffle | |||||
| def accuracy( | |||||
| logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 | |||||
| ) -> Union[Tensor, Iterable[Tensor]]: | |||||
| r""" | |||||
| Calculate the classification accuracy given predicted logits and ground-truth labels. | |||||
| :param logits: Model predictions of shape [batch_size, num_classes], | |||||
| representing the probability (likelyhood) of each class. | |||||
| :param target: Ground-truth labels, 1d tensor of int32 | |||||
| :param topk: Specifies the topk values, could be an int or tuple of ints. Default: 1 | |||||
| :return: Tensor(s) of classification accuracy between 0.0 and 1.0 | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10)) | |||||
| target = tensor(np.arange(8, dtype=np.int32)) | |||||
| top1, top5 = F.accuracy(logits, target, (1, 5)) | |||||
| print(top1.numpy(), top5.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [0.] [0.375] | |||||
| """ | |||||
| if isinstance(topk, int): | |||||
| topk = (topk,) | |||||
| _, pred = _topk(logits, k=max(topk), descending=True) | |||||
| accs = [] | |||||
| for k in topk: | |||||
| correct = pred[:, :k].detach() == _dimshuffle(target, (0, "x")).broadcast( | |||||
| target.shape[0], k | |||||
| ) | |||||
| accs.append(correct.astype(np.float32).sum() / target.shape[0]) | |||||
| if len(topk) == 1: # type: ignore[arg-type] | |||||
| accs = accs[0] | |||||
| return accs | |||||
| def zero_grad(inp: Tensor) -> Tensor: | |||||
| r""" | |||||
| Returns a tensor which is treated as constant during backward gradient calcuation, | |||||
| i.e. its gradient is zero. | |||||
| :param inp: Input tensor. | |||||
| See implementation of :func:`~.softmax` for example. | |||||
| """ | |||||
| print("zero_grad is obsoleted, please use detach instead") | |||||
| raise NotImplementedError | |||||
| def copy(inp, cn): | |||||
| return apply(Copy(comp_node=cn), inp)[0] | |||||
| @@ -0,0 +1,16 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .hub import ( | |||||
| help, | |||||
| import_module, | |||||
| list, | |||||
| load, | |||||
| load_serialized_obj_from_url, | |||||
| pretrained, | |||||
| ) | |||||
| @@ -0,0 +1,17 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| DEFAULT_BRANCH_NAME = "master" | |||||
| HUBCONF = "hubconf.py" | |||||
| HUBDEPENDENCY = "dependencies" | |||||
| DEFAULT_GIT_HOST = "github.com" | |||||
| ENV_MGE_HOME = "MGE_HOME" | |||||
| ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" | |||||
| DEFAULT_CACHE_DIR = "~/.cache" | |||||
| DEFAULT_PROTOCOL = "HTTPS" | |||||
| HTTP_READ_TIMEOUT = 120 | |||||
| @@ -0,0 +1,30 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| class FetcherError(Exception): | |||||
| """Base class for fetch related error.""" | |||||
| class InvalidRepo(FetcherError): | |||||
| """The repo provided was somehow invalid.""" | |||||
| class InvalidGitHost(FetcherError): | |||||
| """The git host provided was somehow invalid.""" | |||||
| class GitPullError(FetcherError): | |||||
| """A git pull error occurred""" | |||||
| class GitCheckoutError(FetcherError): | |||||
| """A git checkout error occurred""" | |||||
| class InvalidProtocol(FetcherError): | |||||
| """The protocol provided was somehow invalid""" | |||||
| @@ -0,0 +1,300 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import hashlib | |||||
| import os | |||||
| import re | |||||
| import shutil | |||||
| import subprocess | |||||
| from tempfile import NamedTemporaryFile | |||||
| from typing import Tuple | |||||
| from zipfile import ZipFile | |||||
| import requests | |||||
| from tqdm import tqdm | |||||
| from megengine.utils.http_download import ( | |||||
| CHUNK_SIZE, | |||||
| HTTP_CONNECTION_TIMEOUT, | |||||
| HTTPDownloadError, | |||||
| ) | |||||
| from ..distributed import is_distributed, synchronized | |||||
| from ..logger import get_logger | |||||
| from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT | |||||
| from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo | |||||
| from .tools import cd | |||||
| logger = get_logger(__name__) | |||||
| HTTP_TIMEOUT = (HTTP_CONNECTION_TIMEOUT, HTTP_READ_TIMEOUT) | |||||
| pattern = re.compile( | |||||
| r"^(?:[a-z0-9]" # First character of the domain | |||||
| r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)" # Sub domain + hostname | |||||
| r"+[a-z0-9][a-z0-9-_]{0,61}" # First 61 characters of the gTLD | |||||
| r"[a-z]$" # Last character of the gTLD | |||||
| ) | |||||
| class RepoFetcherBase: | |||||
| @classmethod | |||||
| def fetch( | |||||
| cls, | |||||
| git_host: str, | |||||
| repo_info: str, | |||||
| use_cache: bool = False, | |||||
| commit: str = None, | |||||
| silent: bool = True, | |||||
| ) -> str: | |||||
| raise NotImplementedError() | |||||
| @classmethod | |||||
| def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]: | |||||
| try: | |||||
| branch_info = DEFAULT_BRANCH_NAME | |||||
| if ":" in repo_info: | |||||
| prefix_info, branch_info = repo_info.split(":") | |||||
| else: | |||||
| prefix_info = repo_info | |||||
| repo_owner, repo_name = prefix_info.split("/") | |||||
| return repo_owner, repo_name, branch_info | |||||
| except ValueError: | |||||
| raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info)) | |||||
| @classmethod | |||||
| def _check_git_host(cls, git_host): | |||||
| return cls._is_valid_domain(git_host) or cls._is_valid_host(git_host) | |||||
| @classmethod | |||||
| def _is_valid_domain(cls, s): | |||||
| try: | |||||
| return pattern.match(s.encode("idna").decode("ascii")) | |||||
| except UnicodeError: | |||||
| return False | |||||
| @classmethod | |||||
| def _is_valid_host(cls, s): | |||||
| nums = s.split(".") | |||||
| if len(nums) != 4 or any(not _.isdigit() for _ in nums): | |||||
| return False | |||||
| return all(0 <= int(_) < 256 for _ in nums) | |||||
| @classmethod | |||||
| def _gen_repo_dir(cls, repo_dir: str) -> str: | |||||
| return hashlib.sha1(repo_dir.encode()).hexdigest()[:16] | |||||
| class GitSSHFetcher(RepoFetcherBase): | |||||
| @classmethod | |||||
| @synchronized | |||||
| def fetch( | |||||
| cls, | |||||
| git_host: str, | |||||
| repo_info: str, | |||||
| use_cache: bool = False, | |||||
| commit: str = None, | |||||
| silent: bool = True, | |||||
| ) -> str: | |||||
| """ | |||||
| Fetches git repo by SSH protocol | |||||
| :param git_host: | |||||
| host address of git repo. | |||||
| example: github.com | |||||
| :param repo_info: | |||||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||||
| tag/branch. The default branch is ``master`` if not specified. | |||||
| example: ``"brain_sdk/MegBrain[:hub]"`` | |||||
| :param use_cache: | |||||
| whether to use locally fetched code or completely re-fetch | |||||
| :param commit: | |||||
| commit id on github or gitlab | |||||
| :param silent: | |||||
| whether to accept the stdout and stderr of the subprocess with PIPE, instead of | |||||
| displaying on the screen | |||||
| :return: | |||||
| directory where the repo code is stored | |||||
| """ | |||||
| if not cls._check_git_host(git_host): | |||||
| raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) | |||||
| repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info) | |||||
| normalized_branch_info = branch_info.replace("/", "_") | |||||
| repo_dir_raw = "{}_{}_{}".format( | |||||
| repo_owner, repo_name, normalized_branch_info | |||||
| ) + ("_{}".format(commit) if commit else "") | |||||
| repo_dir = cls._gen_repo_dir(repo_dir_raw) | |||||
| git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name) | |||||
| if use_cache and os.path.exists(repo_dir): # use cache | |||||
| logger.debug("Cache Found in %s", repo_dir) | |||||
| return repo_dir | |||||
| if is_distributed(): | |||||
| logger.warning( | |||||
| "When using `hub.load` or `hub.list` to fetch git repositories\n" | |||||
| " in DISTRIBUTED mode for the first time, processes are synchronized to\n" | |||||
| " ensure that target repository is ready to use for each process.\n" | |||||
| " Users are expected to see this warning no more than ONCE, otherwise\n" | |||||
| " (very little chance) you may need to remove corrupt cache\n" | |||||
| " `%s` and fetch again.", | |||||
| repo_dir, | |||||
| ) | |||||
| shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache | |||||
| logger.debug( | |||||
| "Git Clone from Repo:%s Branch: %s to %s", | |||||
| git_url, | |||||
| normalized_branch_info, | |||||
| repo_dir, | |||||
| ) | |||||
| kwargs = ( | |||||
| {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {} | |||||
| ) | |||||
| if commit is None: | |||||
| # shallow clone repo by branch/tag | |||||
| p = subprocess.Popen( | |||||
| [ | |||||
| "git", | |||||
| "clone", | |||||
| "-b", | |||||
| normalized_branch_info, | |||||
| git_url, | |||||
| repo_dir, | |||||
| "--depth=1", | |||||
| ], | |||||
| **kwargs, | |||||
| ) | |||||
| cls._check_clone_pipe(p) | |||||
| else: | |||||
| # clone repo and checkout to commit_id | |||||
| p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs) | |||||
| cls._check_clone_pipe(p) | |||||
| with cd(repo_dir): | |||||
| logger.debug("git checkout to %s", commit) | |||||
| p = subprocess.Popen(["git", "checkout", commit], **kwargs) | |||||
| _, err = p.communicate() | |||||
| if p.returncode: | |||||
| shutil.rmtree(repo_dir, ignore_errors=True) | |||||
| raise GitCheckoutError( | |||||
| "Git checkout error, please check the commit id.\n" | |||||
| + err.decode() | |||||
| ) | |||||
| with cd(repo_dir): | |||||
| shutil.rmtree(".git") | |||||
| return repo_dir | |||||
| @classmethod | |||||
| def _check_clone_pipe(cls, p): | |||||
| _, err = p.communicate() | |||||
| if p.returncode: | |||||
| raise GitPullError( | |||||
| "Repo pull error, please check repo info.\n" + err.decode() | |||||
| ) | |||||
| class GitHTTPSFetcher(RepoFetcherBase): | |||||
| @classmethod | |||||
| @synchronized | |||||
| def fetch( | |||||
| cls, | |||||
| git_host: str, | |||||
| repo_info: str, | |||||
| use_cache: bool = False, | |||||
| commit: str = None, | |||||
| silent: bool = True, | |||||
| ) -> str: | |||||
| """ | |||||
| Fetches git repo by HTTPS protocol | |||||
| :param git_host: | |||||
| host address of git repo | |||||
| example: github.com | |||||
| :param repo_info: | |||||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||||
| tag/branch. The default branch is ``master`` if not specified. | |||||
| example: ``"brain_sdk/MegBrain[:hub]"`` | |||||
| :param use_cache: | |||||
| whether to use locally cached code or completely re-fetch | |||||
| :param commit: | |||||
| commit id on github or gitlab | |||||
| :param silent: | |||||
| whether to accept the stdout and stderr of the subprocess with PIPE, instead of | |||||
| displaying on the screen | |||||
| :return: | |||||
| directory where the repo code is stored | |||||
| """ | |||||
| if not cls._check_git_host(git_host): | |||||
| raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) | |||||
| repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info) | |||||
| normalized_branch_info = branch_info.replace("/", "_") | |||||
| repo_dir_raw = "{}_{}_{}".format( | |||||
| repo_owner, repo_name, normalized_branch_info | |||||
| ) + ("_{}".format(commit) if commit else "") | |||||
| repo_dir = cls._gen_repo_dir(repo_dir_raw) | |||||
| archive_url = cls._git_archive_link( | |||||
| git_host, repo_owner, repo_name, branch_info, commit | |||||
| ) | |||||
| if use_cache and os.path.exists(repo_dir): # use cache | |||||
| logger.debug("Cache Found in %s", repo_dir) | |||||
| return repo_dir | |||||
| if is_distributed(): | |||||
| logger.warning( | |||||
| "When using `hub.load` or `hub.list` to fetch git repositories " | |||||
| "in DISTRIBUTED mode for the first time, processes are synchronized to " | |||||
| "ensure that target repository is ready to use for each process.\n" | |||||
| "Users are expected to see this warning no more than ONCE, otherwise" | |||||
| "(very little chance) you may need to remove corrupt hub cache %s and fetch again." | |||||
| ) | |||||
| shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache | |||||
| logger.debug("Downloading from %s to %s", archive_url, repo_dir) | |||||
| cls._download_zip_and_extract(archive_url, repo_dir) | |||||
| return repo_dir | |||||
| @classmethod | |||||
| def _download_zip_and_extract(cls, url, target_dir): | |||||
| resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True) | |||||
| if resp.status_code != 200: | |||||
| raise HTTPDownloadError( | |||||
| "An error occured when downloading from {}".format(url) | |||||
| ) | |||||
| total_size = int(resp.headers.get("Content-Length", 0)) | |||||
| _bar = tqdm(total=total_size, unit="iB", unit_scale=True) | |||||
| with NamedTemporaryFile("w+b") as f: | |||||
| for chunk in resp.iter_content(CHUNK_SIZE): | |||||
| if not chunk: | |||||
| break | |||||
| _bar.update(len(chunk)) | |||||
| f.write(chunk) | |||||
| _bar.close() | |||||
| f.seek(0) | |||||
| with ZipFile(f) as temp_zip_f: | |||||
| zip_dir_name = temp_zip_f.namelist()[0].split("/")[0] | |||||
| temp_zip_f.extractall(".") | |||||
| shutil.move(zip_dir_name, target_dir) | |||||
| @classmethod | |||||
| def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit): | |||||
| archive_link = "https://{}/{}/{}/archive/{}.zip".format( | |||||
| git_host, repo_owner, repo_name, commit or branch_info | |||||
| ) | |||||
| return archive_link | |||||
| @@ -0,0 +1,333 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import functools | |||||
| import hashlib | |||||
| import os | |||||
| import sys | |||||
| import types | |||||
| from typing import Any, List | |||||
| from urllib.parse import urlparse | |||||
| from megengine.utils.http_download import download_from_url | |||||
| from ..distributed import is_distributed | |||||
| from ..logger import get_logger | |||||
| from ..serialization import load as _mge_load_serialized | |||||
| from .const import ( | |||||
| DEFAULT_CACHE_DIR, | |||||
| DEFAULT_GIT_HOST, | |||||
| DEFAULT_PROTOCOL, | |||||
| ENV_MGE_HOME, | |||||
| ENV_XDG_CACHE_HOME, | |||||
| HTTP_READ_TIMEOUT, | |||||
| HUBCONF, | |||||
| HUBDEPENDENCY, | |||||
| ) | |||||
| from .exceptions import InvalidProtocol | |||||
| from .fetcher import GitHTTPSFetcher, GitSSHFetcher | |||||
| from .tools import cd, check_module_exists, load_module | |||||
| logger = get_logger(__name__) | |||||
| PROTOCOLS = { | |||||
| "HTTPS": GitHTTPSFetcher, | |||||
| "SSH": GitSSHFetcher, | |||||
| } | |||||
| def _get_megengine_home() -> str: | |||||
| """MGE_HOME setting complies with the XDG Base Directory Specification | |||||
| """ | |||||
| megengine_home = os.path.expanduser( | |||||
| os.getenv( | |||||
| ENV_MGE_HOME, | |||||
| os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "megengine"), | |||||
| ) | |||||
| ) | |||||
| return megengine_home | |||||
| def _get_repo( | |||||
| git_host: str, | |||||
| repo_info: str, | |||||
| use_cache: bool = False, | |||||
| commit: str = None, | |||||
| protocol: str = DEFAULT_PROTOCOL, | |||||
| ) -> str: | |||||
| if protocol not in PROTOCOLS: | |||||
| raise InvalidProtocol( | |||||
| "Invalid protocol, the value should be one of {}.".format( | |||||
| ", ".join(PROTOCOLS.keys()) | |||||
| ) | |||||
| ) | |||||
| cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub")) | |||||
| with cd(cache_dir): | |||||
| fetcher = PROTOCOLS[protocol] | |||||
| repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit) | |||||
| return os.path.join(cache_dir, repo_dir) | |||||
| def _check_dependencies(module: types.ModuleType) -> None: | |||||
| if not hasattr(module, HUBDEPENDENCY): | |||||
| return | |||||
| dependencies = getattr(module, HUBDEPENDENCY) | |||||
| if not dependencies: | |||||
| return | |||||
| missing_deps = [m for m in dependencies if not check_module_exists(m)] | |||||
| if len(missing_deps): | |||||
| raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) | |||||
| def _init_hub( | |||||
| repo_info: str, | |||||
| git_host: str, | |||||
| use_cache: bool = True, | |||||
| commit: str = None, | |||||
| protocol: str = DEFAULT_PROTOCOL, | |||||
| ): | |||||
| """Imports hubmodule like python import | |||||
| :param repo_info: | |||||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||||
| tag/branch. The default branch is ``master`` if not specified. | |||||
| Example: ``"brain_sdk/MegBrain[:hub]"`` | |||||
| :param git_host: | |||||
| host address of git repo | |||||
| Example: github.com | |||||
| :param use_cache: | |||||
| whether to use locally cached code or completely re-fetch | |||||
| :param commit: | |||||
| commit id on github or gitlab | |||||
| :param protocol: | |||||
| which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||||
| The value should be one of HTTPS, SSH. | |||||
| :return: | |||||
| hubconf.py as a python module | |||||
| """ | |||||
| cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub")) | |||||
| os.makedirs(cache_dir, exist_ok=True) | |||||
| absolute_repo_dir = _get_repo( | |||||
| git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol | |||||
| ) | |||||
| sys.path.insert(0, absolute_repo_dir) | |||||
| hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF)) | |||||
| sys.path.remove(absolute_repo_dir) | |||||
| return hubmodule | |||||
| @functools.wraps(_init_hub) | |||||
| def import_module(*args, **kwargs): | |||||
| return _init_hub(*args, **kwargs) | |||||
| def list( | |||||
| repo_info: str, | |||||
| git_host: str = DEFAULT_GIT_HOST, | |||||
| use_cache: bool = True, | |||||
| commit: str = None, | |||||
| protocol: str = DEFAULT_PROTOCOL, | |||||
| ) -> List[str]: | |||||
| """Lists all entrypoints available in repo hubconf | |||||
| :param repo_info: | |||||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||||
| tag/branch. The default branch is ``master`` if not specified. | |||||
| Example: ``"brain_sdk/MegBrain[:hub]"`` | |||||
| :param git_host: | |||||
| host address of git repo | |||||
| Example: github.com | |||||
| :param use_cache: | |||||
| whether to use locally cached code or completely re-fetch | |||||
| :param commit: | |||||
| commit id on github or gitlab | |||||
| :param protocol: | |||||
| which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||||
| The value should be one of HTTPS, SSH. | |||||
| :return: | |||||
| all entrypoint names of the model | |||||
| """ | |||||
| hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||||
| return [ | |||||
| _ | |||||
| for _ in dir(hubmodule) | |||||
| if not _.startswith("__") and callable(getattr(hubmodule, _)) | |||||
| ] | |||||
| def load( | |||||
| repo_info: str, | |||||
| entry: str, | |||||
| *args, | |||||
| git_host: str = DEFAULT_GIT_HOST, | |||||
| use_cache: bool = True, | |||||
| commit: str = None, | |||||
| protocol: str = DEFAULT_PROTOCOL, | |||||
| **kwargs | |||||
| ) -> Any: | |||||
| """Loads model from github or gitlab repo, with pretrained weights. | |||||
| :param repo_info: | |||||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||||
| tag/branch. The default branch is ``master`` if not specified. | |||||
| Example: ``"brain_sdk/MegBrain[:hub]"`` | |||||
| :param entry: | |||||
| an entrypoint defined in hubconf | |||||
| :param git_host: | |||||
| host address of git repo | |||||
| Example: github.com | |||||
| :param use_cache: | |||||
| whether to use locally cached code or completely re-fetch | |||||
| :param commit: | |||||
| commit id on github or gitlab | |||||
| :param protocol: | |||||
| which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||||
| The value should be one of HTTPS, SSH. | |||||
| :return: | |||||
| a single model with corresponding pretrained weights. | |||||
| """ | |||||
| hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||||
| if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)): | |||||
| raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry)) | |||||
| _check_dependencies(hubmodule) | |||||
| module = getattr(hubmodule, entry)(*args, **kwargs) | |||||
| return module | |||||
| def help( | |||||
| repo_info: str, | |||||
| entry: str, | |||||
| git_host: str = DEFAULT_GIT_HOST, | |||||
| use_cache: bool = True, | |||||
| commit: str = None, | |||||
| protocol: str = DEFAULT_PROTOCOL, | |||||
| ) -> str: | |||||
| """This function returns docstring of entrypoint ``entry`` by following steps: | |||||
| 1. Pull the repo code specified by git and repo_info | |||||
| 2. Load the entry defined in repo's hubconf.py | |||||
| 3. Return docstring of function entry | |||||
| :param repo_info: | |||||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||||
| tag/branch. The default branch is ``master`` if not specified. | |||||
| Example: ``"brain_sdk/MegBrain[:hub]"`` | |||||
| :param entry: | |||||
| an entrypoint defined in hubconf.py | |||||
| :param git_host: | |||||
| host address of git repo | |||||
| Example: github.com | |||||
| :param use_cache: | |||||
| whether to use locally cached code or completely re-fetch | |||||
| :param commit: | |||||
| commit id on github or gitlab | |||||
| :param protocol: | |||||
| which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||||
| The value should be one of HTTPS, SSH. | |||||
| :return: | |||||
| docstring of entrypoint ``entry`` | |||||
| """ | |||||
| hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||||
| if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)): | |||||
| raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry)) | |||||
| doc = getattr(hubmodule, entry).__doc__ | |||||
| return doc | |||||
| def load_serialized_obj_from_url(url: str, model_dir=None) -> Any: | |||||
| """Loads MegEngine serialized object from the given URL. | |||||
| If the object is already present in ``model_dir``, it's deserialized and | |||||
| returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``. | |||||
| :param url: url to serialized object | |||||
| :param model_dir: dir to cache target serialized file | |||||
| :return: loaded object | |||||
| """ | |||||
| if model_dir is None: | |||||
| model_dir = os.path.join(_get_megengine_home(), "serialized") | |||||
| os.makedirs(model_dir, exist_ok=True) | |||||
| parts = urlparse(url) | |||||
| filename = os.path.basename(parts.path) | |||||
| # use hash as prefix to avoid filename conflict from different urls | |||||
| sha256 = hashlib.sha256() | |||||
| sha256.update(url.encode()) | |||||
| digest = sha256.hexdigest()[:6] | |||||
| filename = digest + "_" + filename | |||||
| cached_file = os.path.join(model_dir, filename) | |||||
| logger.info( | |||||
| "load_serialized_obj_from_url: download to or using cached %s", cached_file | |||||
| ) | |||||
| if not os.path.exists(cached_file): | |||||
| if is_distributed(): | |||||
| logger.warning( | |||||
| "Downloading serialized object in DISTRIBUTED mode\n" | |||||
| " File may be downloaded multiple times. We recommend\n" | |||||
| " users to download in single process first." | |||||
| ) | |||||
| download_from_url(url, cached_file, HTTP_READ_TIMEOUT) | |||||
| state_dict = _mge_load_serialized(cached_file) | |||||
| return state_dict | |||||
| class pretrained: | |||||
| r""" | |||||
| Decorator which helps to download pretrained weights from the given url. | |||||
| For example, we can decorate a resnet18 function as follows | |||||
| .. code-block:: | |||||
| @hub.pretrained("https://url/to/pretrained_resnet18.pkl") | |||||
| def resnet18(**kwargs): | |||||
| return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) | |||||
| When decorated function is called with ``pretrained=True``, MegEngine will automatically | |||||
| download and fill the returned model with pretrained weights. | |||||
| """ | |||||
| def __init__(self, url): | |||||
| self.url = url | |||||
| def __call__(self, func): | |||||
| @functools.wraps(func) | |||||
| def pretrained_model_func( | |||||
| pretrained=False, **kwargs | |||||
| ): # pylint: disable=redefined-outer-name | |||||
| model = func(**kwargs) | |||||
| if pretrained: | |||||
| weights = load_serialized_obj_from_url(self.url) | |||||
| model.load_state_dict(weights) | |||||
| return model | |||||
| return pretrained_model_func | |||||
| __all__ = [ | |||||
| "list", | |||||
| "load", | |||||
| "help", | |||||
| "load_serialized_obj_from_url", | |||||
| "pretrained", | |||||
| "import_module", | |||||
| ] | |||||
| @@ -0,0 +1,48 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import importlib.util | |||||
| import os | |||||
| import types | |||||
| from contextlib import contextmanager | |||||
| from typing import Iterator | |||||
| def load_module(name: str, path: str) -> types.ModuleType: | |||||
| """ | |||||
| Loads module specified by name and path | |||||
| :param name: module name | |||||
| :param path: module path | |||||
| """ | |||||
| spec = importlib.util.spec_from_file_location(name, path) | |||||
| module = importlib.util.module_from_spec(spec) | |||||
| spec.loader.exec_module(module) | |||||
| return module | |||||
| def check_module_exists(module: str) -> bool: | |||||
| """Checks whether python module exists or not | |||||
| :param module: name of module | |||||
| """ | |||||
| return importlib.util.find_spec(module) is not None | |||||
| @contextmanager | |||||
| def cd(target: str) -> Iterator[None]: | |||||
| """Changes current directory to target | |||||
| :param target: target directory | |||||
| """ | |||||
| prev = os.getcwd() | |||||
| os.chdir(os.path.expanduser(target)) | |||||
| try: | |||||
| yield | |||||
| finally: | |||||
| os.chdir(prev) | |||||
| @@ -0,0 +1,237 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import contextlib | |||||
| import logging | |||||
| import os | |||||
| import sys | |||||
| _all_loggers = [] | |||||
| _default_level_name = os.getenv("MEGENGINE_LOGGING_LEVEL", "ERROR") | |||||
| _default_level = logging.getLevelName(_default_level_name.upper()) | |||||
| def set_log_file(fout, mode="a"): | |||||
| r"""Sets log output file. | |||||
| :type fout: str or file-like | |||||
| :param fout: file-like object that supports write and flush, or string for | |||||
| the filename | |||||
| :type mode: str | |||||
| :param mode: specify the mode to open log file if *fout* is a string | |||||
| """ | |||||
| if isinstance(fout, str): | |||||
| fout = open(fout, mode) | |||||
| MegEngineLogFormatter.log_fout = fout | |||||
| class MegEngineLogFormatter(logging.Formatter): | |||||
| log_fout = None | |||||
| date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] " | |||||
| date = "%(asctime)s " | |||||
| msg = "%(message)s" | |||||
| max_lines = 256 | |||||
| def _color_exc(self, msg): | |||||
| r"""Sets the color of message as the execution type. | |||||
| """ | |||||
| return "\x1b[34m{}\x1b[0m".format(msg) | |||||
| def _color_dbg(self, msg): | |||||
| r"""Sets the color of message as the debugging type. | |||||
| """ | |||||
| return "\x1b[36m{}\x1b[0m".format(msg) | |||||
| def _color_warn(self, msg): | |||||
| r"""Sets the color of message as the warning type. | |||||
| """ | |||||
| return "\x1b[1;31m{}\x1b[0m".format(msg) | |||||
| def _color_err(self, msg): | |||||
| r"""Sets the color of message as the error type. | |||||
| """ | |||||
| return "\x1b[1;4;31m{}\x1b[0m".format(msg) | |||||
| def _color_omitted(self, msg): | |||||
| r"""Sets the color of message as the omitted type. | |||||
| """ | |||||
| return "\x1b[35m{}\x1b[0m".format(msg) | |||||
| def _color_normal(self, msg): | |||||
| r"""Sets the color of message as the normal type. | |||||
| """ | |||||
| return msg | |||||
| def _color_date(self, msg): | |||||
| r"""Sets the color of message the same as date. | |||||
| """ | |||||
| return "\x1b[32m{}\x1b[0m".format(msg) | |||||
| def format(self, record): | |||||
| if record.levelno == logging.DEBUG: | |||||
| mcl, mtxt = self._color_dbg, "DBG" | |||||
| elif record.levelno == logging.WARNING: | |||||
| mcl, mtxt = self._color_warn, "WRN" | |||||
| elif record.levelno == logging.ERROR: | |||||
| mcl, mtxt = self._color_err, "ERR" | |||||
| else: | |||||
| mcl, mtxt = self._color_normal, "" | |||||
| if mtxt: | |||||
| mtxt += " " | |||||
| if self.log_fout: | |||||
| self.__set_fmt(self.date_full + mtxt + self.msg) | |||||
| formatted = super(MegEngineLogFormatter, self).format(record) | |||||
| nr_line = formatted.count("\n") + 1 | |||||
| if nr_line >= self.max_lines: | |||||
| head, body = formatted.split("\n", 1) | |||||
| formatted = "\n".join( | |||||
| [ | |||||
| head, | |||||
| "BEGIN_LONG_LOG_{}_LINES{{".format(nr_line - 1), | |||||
| body, | |||||
| "}}END_LONG_LOG_{}_LINES".format(nr_line - 1), | |||||
| ] | |||||
| ) | |||||
| self.log_fout.write(formatted) | |||||
| self.log_fout.write("\n") | |||||
| self.log_fout.flush() | |||||
| self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) | |||||
| formatted = super(MegEngineLogFormatter, self).format(record) | |||||
| if record.exc_text or record.exc_info: | |||||
| # handle exception format | |||||
| b = formatted.find("Traceback ") | |||||
| if b != -1: | |||||
| s = formatted[b:] | |||||
| s = self._color_exc(" " + s.replace("\n", "\n ")) | |||||
| formatted = formatted[:b] + s | |||||
| nr_line = formatted.count("\n") + 1 | |||||
| if nr_line >= self.max_lines: | |||||
| lines = formatted.split("\n") | |||||
| remain = self.max_lines // 2 | |||||
| removed = len(lines) - remain * 2 | |||||
| if removed > 0: | |||||
| mid_msg = self._color_omitted( | |||||
| "[{} log lines omitted (would be written to output file " | |||||
| "if set_log_file() has been called;\n" | |||||
| " the threshold can be set at " | |||||
| "MegEngineLogFormatter.max_lines)]".format(removed) | |||||
| ) | |||||
| formatted = "\n".join(lines[:remain] + [mid_msg] + lines[-remain:]) | |||||
| return formatted | |||||
| if sys.version_info.major < 3: | |||||
| def __set_fmt(self, fmt): | |||||
| self._fmt = fmt | |||||
| else: | |||||
| def __set_fmt(self, fmt): | |||||
| self._style._fmt = fmt | |||||
| def get_logger(name=None, formatter=MegEngineLogFormatter): | |||||
| r"""Gets megengine logger with given name. | |||||
| """ | |||||
| logger = logging.getLogger(name) | |||||
| if getattr(logger, "_init_done__", None): | |||||
| return logger | |||||
| logger._init_done__ = True | |||||
| logger.propagate = False | |||||
| logger.setLevel(_default_level) | |||||
| handler = logging.StreamHandler() | |||||
| handler.setFormatter(formatter(datefmt="%d %H:%M:%S")) | |||||
| handler.setLevel(0) | |||||
| del logger.handlers[:] | |||||
| logger.addHandler(handler) | |||||
| _all_loggers.append(logger) | |||||
| return logger | |||||
| def set_log_level(level, update_existing=True): | |||||
| """Sets default logging level. | |||||
| :type level: int e.g. logging.INFO | |||||
| :param level: loggin level given by python :mod:`logging` module | |||||
| :param update_existing: whether to update existing loggers | |||||
| """ | |||||
| global _default_level # pylint: disable=global-statement | |||||
| _default_level = level | |||||
| if update_existing: | |||||
| for i in _all_loggers: | |||||
| i.setLevel(level) | |||||
| _logger = get_logger(__name__) | |||||
| try: | |||||
| if sys.version_info.major < 3: | |||||
| raise ImportError() | |||||
| from .core._imperative_rt.utils import Logger as _imperative_rt_logger | |||||
| class MegBrainLogFormatter(MegEngineLogFormatter): | |||||
| date = "%(asctime)s[mgb] " | |||||
| def _color_date(self, msg): | |||||
| return "\x1b[33m{}\x1b[0m".format(msg) | |||||
| _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter) | |||||
| _imperative_rt_logger.set_log_handler(_megbrain_logger) | |||||
| if _default_level == logging.getLevelName("ERROR"): | |||||
| _imperative_rt_logger.set_log_level(_imperative_rt_logger.LogLevel.Error) | |||||
| elif _default_level == logging.getLevelName("INFO"): | |||||
| _imperative_rt_logger.set_log_level(_imperative_rt_logger.LogLevel.Info) | |||||
| else: | |||||
| _imperative_rt_logger.set_log_level(_imperative_rt_logger.LogLevel.Debug) | |||||
| def set_mgb_log_level(level): | |||||
| r"""Sets megbrain log level | |||||
| :type level: int e.g. logging.INFO | |||||
| :param level: new log level | |||||
| :return: original log level | |||||
| """ | |||||
| logger = _megbrain_logger | |||||
| rst = logger.getEffectiveLevel() | |||||
| logger.setLevel(level) | |||||
| return rst | |||||
| except ImportError as exc: | |||||
| def set_mgb_log_level(level): | |||||
| raise NotImplementedError("imperative_rt has not been imported") | |||||
| @contextlib.contextmanager | |||||
| def replace_mgb_log_level(level): | |||||
| r"""Replaces megbrain log level in a block and restore after exiting. | |||||
| :type level: int e.g. logging.INFO | |||||
| :param level: new log level | |||||
| """ | |||||
| old = set_mgb_log_level(level) | |||||
| try: | |||||
| yield | |||||
| finally: | |||||
| set_mgb_log_level(old) | |||||
| def enable_debug_log(): | |||||
| r"""Sets logging level to debug for all components. | |||||
| """ | |||||
| set_log_level(logging.DEBUG) | |||||
| set_mgb_log_level(logging.DEBUG) | |||||
| @@ -0,0 +1,24 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||||
| from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||||
| from .concat import Concat | |||||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||||
| from .dropout import Dropout | |||||
| from .elemwise import Elemwise | |||||
| from .embedding import Embedding | |||||
| from .identity import Identity | |||||
| from .linear import Linear | |||||
| from .module import Module | |||||
| from .parampack import ParamPack | |||||
| from .pooling import AvgPool2d, MaxPool2d | |||||
| from .quant_dequant import DequantStub, QuantStub | |||||
| from .sequential import Sequential | |||||
| @@ -0,0 +1,231 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| from ..functional import leaky_relu, prelu, relu, sigmoid, softmax | |||||
| from ..tensor_nn import Parameter | |||||
| from .module import Module | |||||
| class Softmax(Module): | |||||
| r""" | |||||
| Applies a softmax function. Softmax is defined as: | |||||
| .. math:: | |||||
| \text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)} | |||||
| It is applied to an n-dimensional input Tensor and rescaling them so that the elements of the | |||||
| n-dimensional output Tensor lie in the range of `[0, 1]` and sum to 1. | |||||
| :param axis: An axis along which softmax will be applied. By default, | |||||
| softmax will apply along the highest ranked axis. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| data = mge.tensor(np.array([-2,-1,0,1,2]).astype(np.float32)) | |||||
| softmax = M.Softmax() | |||||
| output = softmax(data) | |||||
| with np.printoptions(precision=6): | |||||
| print(output.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [0.011656 0.031685 0.086129 0.234122 0.636409] | |||||
| """ | |||||
| def __init__(self, axis=None): | |||||
| super().__init__() | |||||
| self.axis = axis | |||||
| def forward(self, inputs): | |||||
| return softmax(inputs, self.axis) | |||||
| class Sigmoid(Module): | |||||
| r""" | |||||
| Applies the element-wise function: | |||||
| .. math:: | |||||
| \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)} | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32)) | |||||
| sigmoid = M.Sigmoid() | |||||
| output = sigmoid(data) | |||||
| with np.printoptions(precision=6): | |||||
| print(output.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [0.119203 0.268941 0.5 0.731059 0.880797] | |||||
| """ | |||||
| def forward(self, inputs): | |||||
| return sigmoid(inputs) | |||||
| class ReLU(Module): | |||||
| r""" | |||||
| Applies the element-wise function: | |||||
| .. math:: | |||||
| \text{ReLU}(x) = \max(x, 0) | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32)) | |||||
| relu = M.ReLU() | |||||
| output = relu(data) | |||||
| with np.printoptions(precision=6): | |||||
| print(output.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [0. 0. 0. 1. 2.] | |||||
| """ | |||||
| def forward(self, x): | |||||
| return relu(x) | |||||
| class PReLU(Module): | |||||
| r""" | |||||
| Applies the element-wise function: | |||||
| .. math:: | |||||
| \text{PReLU}(x) = \max(0,x) + a * \min(0,x) | |||||
| or | |||||
| .. math:: | |||||
| \text{PReLU}(x) = | |||||
| \begin{cases} | |||||
| x, & \text{ if } x \geq 0 \\ | |||||
| ax, & \text{ otherwise } | |||||
| \end{cases} | |||||
| Here :math:`a` is a learnable parameter. When called without arguments, `PReLU()` uses | |||||
| a single paramter :math:`a` across all input channel. If called with `PReLU(num_of_channels)`, | |||||
| a seperate :math:`a` is used for each input channle. | |||||
| :param num_parameters: number of :math:`a` to learn, there is only two | |||||
| values are legitimate: 1, or the number of channels at input. Default: 1 | |||||
| :param init: the initial value of :math:`a`. Default: 0.25 | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| data = mge.tensor(np.array([-1.2, -3.7, 2.7]).astype(np.float32)) | |||||
| prelu = M.PReLU() | |||||
| output = prelu(data) | |||||
| print(output.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [-0.3 -0.925 2.7 ] | |||||
| """ | |||||
| def __init__(self, num_parameters: int = 1, init: float = 0.25): | |||||
| super().__init__() | |||||
| self.num_parameters = num_parameters | |||||
| if num_parameters > 1: | |||||
| # Assume format is NCHW | |||||
| self.weight = Parameter( | |||||
| data=np.full((1, num_parameters, 1, 1), init, dtype=np.float32) | |||||
| ) | |||||
| else: | |||||
| self.weight = Parameter(data=[init]) | |||||
| def forward(self, inputs): | |||||
| assert self.weight.shape == (1,) or self.weight.shape == ( | |||||
| 1, | |||||
| int(inputs.shape[1]), | |||||
| 1, | |||||
| 1, | |||||
| ), "invalid weight's shape" | |||||
| return prelu(inputs, self.weight) | |||||
| class LeakyReLU(Module): | |||||
| r""" | |||||
| Applies the element-wise function: | |||||
| .. math:: | |||||
| \text{LeakyReLU}(x) = \max(0,x) + negative\_slope \times \min(0,x) | |||||
| or | |||||
| .. math:: | |||||
| \text{LeakyReLU}(x) = | |||||
| \begin{cases} | |||||
| x, & \text{ if } x \geq 0 \\ | |||||
| negative\_slope \times x, & \text{ otherwise } | |||||
| \end{cases} | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| data = mge.tensor(np.array([-8, -12, 6, 10]).astype(np.float32)) | |||||
| leakyrelu = M.LeakyReLU(0.01) | |||||
| output = leakyrelu(data) | |||||
| print(output.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [-0.08 -0.12 6. 10. ] | |||||
| """ | |||||
| def __init__(self, negative_slope: float = 0.01): | |||||
| super().__init__() | |||||
| self.negative_slope = negative_slope | |||||
| def forward(self, inputs): | |||||
| return leaky_relu(inputs, self.negative_slope) | |||||
| @@ -0,0 +1,281 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import Optional | |||||
| import numpy as np | |||||
| from ..distributed.group import WORLD, Group | |||||
| from ..functional import batch_norm2d, sync_batch_norm | |||||
| from ..tensor_nn import Buffer, Parameter | |||||
| from . import init | |||||
| from .module import Module | |||||
| class _BatchNorm(Module): | |||||
| def __init__( | |||||
| self, | |||||
| num_features, | |||||
| eps=1e-5, | |||||
| momentum=0.9, | |||||
| affine=True, | |||||
| track_running_stats=True, | |||||
| freeze=False, | |||||
| ): | |||||
| super(_BatchNorm, self).__init__() | |||||
| self.num_features = num_features | |||||
| self.eps = eps | |||||
| self.momentum = momentum | |||||
| self.affine = affine | |||||
| self.track_running_stats = track_running_stats | |||||
| self._track_running_stats_saved = track_running_stats | |||||
| self.freeze = freeze | |||||
| if self.affine: | |||||
| self.weight = Parameter(np.ones(num_features, dtype=np.float32)) | |||||
| self.bias = Parameter(np.zeros(num_features, dtype=np.float32)) | |||||
| else: | |||||
| self.weight = None | |||||
| self.bias = None | |||||
| tshape = (1, self.num_features, 1, 1) | |||||
| if self.track_running_stats: | |||||
| self.running_mean = Buffer(np.zeros(tshape, dtype=np.float32)) | |||||
| self.running_var = Buffer(np.ones(tshape, dtype=np.float32)) | |||||
| else: | |||||
| self.running_mean = None | |||||
| self.running_var = None | |||||
| def reset_running_stats(self) -> None: | |||||
| if self.track_running_stats: | |||||
| init.zeros_(self.running_mean) | |||||
| init.ones_(self.running_var) | |||||
| def reset_parameters(self) -> None: | |||||
| self.reset_running_stats() | |||||
| if self.affine: | |||||
| init.ones_(self.weight) | |||||
| init.zeros_(self.bias) | |||||
| def _check_input_ndim(self, inp): | |||||
| raise NotImplementedError | |||||
| def forward(self, inp): | |||||
| self._check_input_ndim(inp) | |||||
| if self._track_running_stats_saved == False: | |||||
| assert ( | |||||
| self.track_running_stats == False | |||||
| ), "track_running_stats can not be initilized to False and changed to True later" | |||||
| _ndims = len(inp.shape) | |||||
| if _ndims != 4: | |||||
| origin_shape = inp.shapeof() | |||||
| if _ndims == 2: | |||||
| n, c = inp.shapeof(0), inp.shapeof(1) | |||||
| new_shape = (n, c, 1, 1) | |||||
| elif _ndims == 3: | |||||
| n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2) | |||||
| new_shape = (n, c, h, 1) | |||||
| inp = inp.reshape(new_shape) | |||||
| if self.freeze and self.training and self._track_running_stats_saved: | |||||
| scale = self.weight.reshape(1, -1, 1, 1) * ( | |||||
| self.running_var + self.eps | |||||
| ) ** (-0.5) | |||||
| bias = self.bias.reshape(1, -1, 1, 1) - self.running_mean * scale | |||||
| return inp * scale.detach() + bias.detach() | |||||
| if self.training and self.track_running_stats: | |||||
| exponential_average_factor = self.momentum | |||||
| else: | |||||
| exponential_average_factor = 0.0 # useless | |||||
| output = batch_norm2d( | |||||
| inp, | |||||
| self.running_mean if self.track_running_stats else None, | |||||
| self.running_var if self.track_running_stats else None, | |||||
| self.weight, | |||||
| self.bias, | |||||
| training=self.training | |||||
| or ((self.running_mean is None) and (self.running_var is None)), | |||||
| momentum=exponential_average_factor, | |||||
| eps=self.eps, | |||||
| ) | |||||
| if _ndims != 4: | |||||
| output = output.reshape(origin_shape) | |||||
| return output | |||||
| class SyncBatchNorm(_BatchNorm): | |||||
| r""" | |||||
| Applies Synchronization Batch Normalization. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| num_features, | |||||
| eps=1e-5, | |||||
| momentum=0.9, | |||||
| affine=True, | |||||
| track_running_stats=True, | |||||
| freeze=False, | |||||
| group: Optional[Group] = None, | |||||
| ) -> None: | |||||
| super().__init__( | |||||
| num_features, eps, momentum, affine, track_running_stats, freeze | |||||
| ) | |||||
| self.group = group | |||||
| def _check_input_ndim(self, inp): | |||||
| if len(inp.shape) not in {2, 3, 4}: | |||||
| raise ValueError( | |||||
| "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape)) | |||||
| ) | |||||
| def forward(self, inp): | |||||
| self._check_input_ndim(inp) | |||||
| _ndims = len(inp.shape) | |||||
| if _ndims != 4: | |||||
| origin_shape = inp.shapeof() | |||||
| if _ndims == 2: | |||||
| n, c = inp.shapeof(0), inp.shapeof(1) | |||||
| new_shape = (n, c, 1, 1) | |||||
| elif _ndims == 3: | |||||
| n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2) | |||||
| new_shape = (n, c, h, 1) | |||||
| inp = inp.reshape(new_shape) | |||||
| if self.training and self.track_running_stats: | |||||
| exponential_average_factor = self.momentum | |||||
| else: | |||||
| exponential_average_factor = 0.0 # useless | |||||
| output = sync_batch_norm( | |||||
| inp, | |||||
| self.running_mean, | |||||
| self.running_var, | |||||
| self.weight, | |||||
| self.bias, | |||||
| self.training or not self.track_running_stats, | |||||
| exponential_average_factor, | |||||
| self.eps, | |||||
| group=self.group, | |||||
| ) | |||||
| if _ndims != 4: | |||||
| output = output.reshape(origin_shape) | |||||
| return output | |||||
| class BatchNorm1d(_BatchNorm): | |||||
| r""" | |||||
| Applies Batch Normalization over a 2D/3D tensor. | |||||
| Refer to :class:`~.BatchNorm2d` for more information. | |||||
| """ | |||||
| def _check_input_ndim(self, inp): | |||||
| if len(inp.shape) not in {2, 3}: | |||||
| raise ValueError( | |||||
| "expected 2D or 3D input (got {}D input)".format(len(inp.shape)) | |||||
| ) | |||||
| class BatchNorm2d(_BatchNorm): | |||||
| r""" | |||||
| Applies Batch Normalization over a 4D tensor. | |||||
| .. math:: | |||||
| y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta | |||||
| The mean and standard-deviation are calculated per-dimension over | |||||
| the mini-batches and :math:`\gamma` and :math:`\beta` are learnable | |||||
| parameter vectors. | |||||
| By default, during training this layer keeps running estimates of its | |||||
| computed mean and variance, which are then used for normalization during | |||||
| evaluation. The running estimates are kept with a default :attr:`momentum` | |||||
| of 0.9. | |||||
| If :attr:`track_running_stats` is set to ``False``, this layer will not | |||||
| keep running estimates, and batch statistics are instead used during | |||||
| evaluation time. | |||||
| .. note:: | |||||
| This :attr:`momentum` argument is different from one used in optimizer | |||||
| classes and the conventional notion of momentum. Mathematically, the | |||||
| update rule for running statistics here is | |||||
| :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t`, | |||||
| where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the | |||||
| new observed value. | |||||
| Because the Batch Normalization is done over the `C` dimension, computing | |||||
| statistics on `(N, H, W)` slices, it's common terminology to call this | |||||
| Spatial Batch Normalization. | |||||
| :type num_features: int | |||||
| :param num_features: usually the :math:`C` from an input of size | |||||
| :math:`(N, C, H, W)` or the highest ranked dimension of an input with | |||||
| less than 4D. | |||||
| :type eps: float | |||||
| :param eps: a value added to the denominator for numerical stability. | |||||
| Default: 1e-5. | |||||
| :type momentum: float | |||||
| :param momentum: the value used for the `running_mean` and `running_var` | |||||
| computation. | |||||
| Default: 0.9 | |||||
| :type affine: bool | |||||
| :param affine: a boolean value that when set to ``True``, this module has | |||||
| learnable affine parameters. Default: ``True`` | |||||
| :type track_running_stats: bool | |||||
| :param track_running_stats: when set to ``True``, this module tracks the | |||||
| running mean and variance. When set to ``False``, this module does not | |||||
| track such statistics and always uses batch statistics in both training | |||||
| and eval modes. Default: ``True``. | |||||
| :type freeze: bool | |||||
| :param freeze: when set to ``True``, this module does not update the | |||||
| running mean and variance, and uses the running mean and variance instead of | |||||
| the batch mean and batch variance to normalize the input. The parameter takes effect | |||||
| only when the module is initilized with ``track_running_stats`` as ``True`` and | |||||
| the module is in training mode. | |||||
| Default: ``False``. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| # With Learnable Parameters | |||||
| m = M.BatchNorm2d(4) | |||||
| inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32")) | |||||
| oup = m(inp) | |||||
| print(m.weight, m.bias) | |||||
| # Without Learnable Parameters | |||||
| m = M.BatchNorm2d(4, affine=False) | |||||
| oup = m(inp) | |||||
| print(m.weight, m.bias) | |||||
| .. testoutput:: | |||||
| Tensor([1. 1. 1. 1.]) Tensor([0. 0. 0. 0.]) | |||||
| None None | |||||
| """ | |||||
| def _check_input_ndim(self, inp): | |||||
| if len(inp.shape) != 4: | |||||
| raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape))) | |||||
| @@ -0,0 +1,22 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import Iterable | |||||
| from ..functional import concat | |||||
| from ..tensor import Tensor | |||||
| from .module import Module | |||||
| class Concat(Module): | |||||
| r""" | |||||
| A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule` | |||||
| version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`. | |||||
| """ | |||||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||||
| return concat(inps, axis) | |||||
| @@ -0,0 +1,391 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from abc import abstractmethod | |||||
| from typing import Tuple, Union | |||||
| import numpy as np | |||||
| from ..core.ops._internal import param_defs as P | |||||
| from ..functional import conv2d, conv_transpose2d, local_conv2d, relu | |||||
| from ..functional.types import _pair, _pair_nonzero | |||||
| from ..tensor_nn import Parameter | |||||
| from . import init | |||||
| from .module import Module | |||||
| class _ConvNd(Module): | |||||
| """base class for convolution modules, including transposed conv""" | |||||
| def __init__( | |||||
| self, | |||||
| in_channels: int, | |||||
| out_channels: int, | |||||
| kernel_size: Union[int, Tuple[int, int]], | |||||
| stride: Union[int, Tuple[int, int]], | |||||
| padding: Union[int, Tuple[int, int]], | |||||
| dilation: Union[int, Tuple[int, int]], | |||||
| groups: int, | |||||
| bias: bool = True, | |||||
| ): | |||||
| super().__init__() | |||||
| if in_channels % groups != 0: | |||||
| raise ValueError("in_channels must be divisible by groups") | |||||
| if out_channels % groups != 0: | |||||
| raise ValueError("out_channels must be divisible by groups") | |||||
| self.in_channels = in_channels | |||||
| self.out_channels = out_channels | |||||
| self.kernel_size = kernel_size | |||||
| self.stride = stride | |||||
| self.padding = padding | |||||
| self.dilation = dilation | |||||
| self.groups = groups | |||||
| self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32)) | |||||
| self.bias = None | |||||
| if bias: | |||||
| self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32)) | |||||
| self.reset_parameters() | |||||
| @abstractmethod | |||||
| def _get_fanin(self): | |||||
| pass | |||||
| def reset_parameters(self) -> None: | |||||
| fanin = self._get_fanin() | |||||
| std = np.sqrt(1 / fanin) | |||||
| init.normal_(self.weight, 0.0, std) | |||||
| if self.bias is not None: | |||||
| init.zeros_(self.bias) | |||||
| @abstractmethod | |||||
| def _infer_weight_shape(self): | |||||
| pass | |||||
| @abstractmethod | |||||
| def _infer_bias_shape(self): | |||||
| pass | |||||
| class Conv2d(_ConvNd): | |||||
| r"""Applies a 2D convolution over an input tensor. | |||||
| For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`, | |||||
| this layer generates an output of the size | |||||
| :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the | |||||
| process described as below: | |||||
| .. math:: | |||||
| \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + | |||||
| \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) | |||||
| where :math:`\star` is the valid 2D cross-correlation operator, | |||||
| :math:`N` is a batch size, :math:`C` denotes a number of channels, | |||||
| :math:`H` is a height of input planes in pixels, and :math:`W` is | |||||
| width in pixels. | |||||
| When ``groups == in_channels`` and ``out_channels == K * in_channels``, | |||||
| where `K` is a positive integer, this operation is also known as depthwise | |||||
| convolution. | |||||
| In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`, | |||||
| a depthwise convolution with a depthwise multiplier `K`, can be constructed | |||||
| by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. | |||||
| :param in_channels: number of input channels. | |||||
| :param out_channels: number of output channels. | |||||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||||
| an :class:`int`, the actual kernel size would be | |||||
| ``(kernel_size, kernel_size)``. Default: 1 | |||||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||||
| :param padding: size of the paddings added to the input on both sides of its | |||||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
| :param dilation: dilation of the 2D convolution operation. Default: 1 | |||||
| :param groups: number of groups to divide input and output channels into, | |||||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||||
| and there would be an extra dimension at the beginning of the weight's | |||||
| shape. Specifically, the shape of weight would be ``(groups, | |||||
| out_channel // groups, in_channels // groups, *kernel_size)``. | |||||
| :param bias: whether to add a bias onto the result of convolution. Default: | |||||
| True | |||||
| :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||||
| `CROSS_CORRELATION`. | |||||
| :param compute_mode: When set to `DEFAULT`, no special requirements will be | |||||
| placed on the precision of intermediate results. When set to `FLOAT32`, | |||||
| float32 would be used for accumulator and intermediate result, but only | |||||
| effective when input and output are of float16 dtype. | |||||
| """ | |||||
| _conv_mode_type = P.Convolution.Mode | |||||
| _compute_mode_type = P.Convolution.ComputeMode | |||||
| def __init__( | |||||
| self, | |||||
| in_channels: int, | |||||
| out_channels: int, | |||||
| kernel_size: Union[int, Tuple[int, int]], | |||||
| stride: Union[int, Tuple[int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | |||||
| groups: int = 1, | |||||
| bias: bool = True, | |||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| ): | |||||
| kernel_size = _pair_nonzero(kernel_size) | |||||
| stride = _pair_nonzero(stride) | |||||
| padding = _pair(padding) | |||||
| dilation = _pair_nonzero(dilation) | |||||
| self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||||
| self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||||
| super().__init__( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias, | |||||
| ) | |||||
| def _get_fanin(self): | |||||
| kh, kw = self.kernel_size | |||||
| ic = self.in_channels | |||||
| return kh * kw * ic | |||||
| def _infer_weight_shape(self): | |||||
| group = self.groups | |||||
| ichl = self.in_channels | |||||
| ochl = self.out_channels | |||||
| kh, kw = self.kernel_size | |||||
| if group == 1: | |||||
| # Assume format is NCHW | |||||
| return (ochl, ichl, kh, kw) | |||||
| assert ( | |||||
| ichl % group == 0 and ochl % group == 0 | |||||
| ), "invalid config: input_channels={} output_channels={} group={}".format( | |||||
| ichl, ochl, group | |||||
| ) | |||||
| # Assume format is NCHW | |||||
| return (group, ochl // group, ichl // group, kh, kw) | |||||
| def _infer_bias_shape(self): | |||||
| # Assume format is NCHW | |||||
| return (1, self.out_channels, 1, 1) | |||||
| def calc_conv(self, inp, weight, bias): | |||||
| return conv2d( | |||||
| inp, | |||||
| weight, | |||||
| bias, | |||||
| self.stride, | |||||
| self.padding, | |||||
| self.dilation, | |||||
| self.groups, | |||||
| self.conv_mode, | |||||
| self.compute_mode, | |||||
| ) | |||||
| def forward(self, inp): | |||||
| return self.calc_conv(inp, self.weight, self.bias) | |||||
| class ConvTranspose2d(_ConvNd): | |||||
| r"""Applies a 2D transposed convolution over an input tensor. | |||||
| This module is also known as a deconvolution or a fractionally-strided convolution. | |||||
| :class:`ConvTranspose2d` can ben seen as the gradient of :class:`Conv2d` operation | |||||
| with respect to its input. | |||||
| Convolution usually reduces the size of input, while transposed convolution works | |||||
| the opposite way, transforming a smaller input to a larger output while preserving the | |||||
| connectivity pattern. | |||||
| :param in_channels: number of input channels. | |||||
| :param out_channels: number of output channels. | |||||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||||
| an :class:`int`, the actual kernel size would be | |||||
| ``(kernel_size, kernel_size)``. Default: 1 | |||||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||||
| :param padding: size of the paddings added to the input on both sides of its | |||||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
| :param dilation: dilation of the 2D convolution operation. Default: 1 | |||||
| :param groups: number of groups to divide input and output channels into, | |||||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||||
| and there would be an extra dimension at the beginning of the weight's | |||||
| shape. Specifically, the shape of weight would be ``(groups, | |||||
| out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||||
| :param bias: wether to add a bias onto the result of convolution. Default: | |||||
| True | |||||
| :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||||
| `CROSS_CORRELATION`. | |||||
| :param compute_mode: When set to `DEFAULT`, no special requirements will be | |||||
| placed on the precision of intermediate results. When set to `FLOAT32`, | |||||
| float32 would be used for accumulator and intermediate result, but only | |||||
| effective when input and output are of float16 dtype. | |||||
| """ | |||||
| _conv_mode_type = P.Convolution.Mode | |||||
| _compute_mode_type = P.Convolution.ComputeMode | |||||
| def __init__( | |||||
| self, | |||||
| in_channels: int, | |||||
| out_channels: int, | |||||
| kernel_size: Union[int, Tuple[int, int]], | |||||
| stride: Union[int, Tuple[int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | |||||
| groups: int = 1, | |||||
| bias: bool = True, | |||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| ): | |||||
| kernel_size = _pair_nonzero(kernel_size) | |||||
| stride = _pair_nonzero(stride) | |||||
| padding = _pair(padding) | |||||
| dilation = _pair_nonzero(dilation) | |||||
| self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||||
| self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||||
| super().__init__( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias, | |||||
| ) | |||||
| def _get_fanin(self): | |||||
| kh, kw = self.kernel_size | |||||
| oc = self.out_channels | |||||
| return kh * kw * oc | |||||
| def _infer_weight_shape(self): | |||||
| group = self.groups | |||||
| ichl = self.in_channels | |||||
| ochl = self.out_channels | |||||
| kh, kw = self.kernel_size | |||||
| if group == 1: | |||||
| # Assume format is NCHW | |||||
| return (ichl, ochl, kh, kw) | |||||
| assert ( | |||||
| ichl % group == 0 and ochl % group == 0 | |||||
| ), "invalid config: input_channels={} output_channels={} group={}".format( | |||||
| ichl, ochl, group | |||||
| ) | |||||
| # Assume format is NCHW | |||||
| return (group, ichl // group, ochl // group, kh, kw) | |||||
| def _infer_bias_shape(self): | |||||
| # Assume format is NCHW | |||||
| return (1, self.out_channels, 1, 1) | |||||
| def forward(self, inp): | |||||
| return conv_transpose2d( | |||||
| inp, | |||||
| self.weight, | |||||
| self.bias, | |||||
| self.stride, | |||||
| self.padding, | |||||
| self.dilation, | |||||
| self.groups, | |||||
| self.conv_mode, | |||||
| self.compute_mode, | |||||
| ) | |||||
| class LocalConv2d(Conv2d): | |||||
| r"""Applies a spatial convolution with untied kernels over an input 4D tensor. | |||||
| It is also known as the locally connected layer. | |||||
| :param in_channels: number of input channels. | |||||
| :param out_channels: number of output channels. | |||||
| :param input_height: the height of the input images. | |||||
| :param input_width: the width of the input images. | |||||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||||
| an :class:`int`, the actual kernel size would be | |||||
| ``(kernel_size, kernel_size)``. Default: 1 | |||||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||||
| :param padding: size of the paddings added to the input on both sides of its | |||||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||||
| :param groups: number of groups to divide input and output channels into, | |||||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``. | |||||
| The shape of weight is ``(groups, output_height, output_width, | |||||
| in_channels // groups, *kernel_size, out_channels // groups)``. | |||||
| """ | |||||
| _conv_mode_type = P.Convolution.Mode | |||||
| def __init__( | |||||
| self, | |||||
| in_channels: int, | |||||
| out_channels: int, | |||||
| input_height: int, | |||||
| input_width: int, | |||||
| kernel_size: Union[int, Tuple[int, int]], | |||||
| stride: Union[int, Tuple[int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | |||||
| groups: int = 1, | |||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| ): | |||||
| self.input_height = input_height | |||||
| self.input_width = input_width | |||||
| super().__init__( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias=False, | |||||
| ) | |||||
| def _infer_weight_shape(self): | |||||
| group = self.groups | |||||
| output_height = ( | |||||
| self.input_height + self.padding[0] * 2 - self.kernel_size[0] | |||||
| ) // self.stride[0] + 1 | |||||
| output_width = ( | |||||
| self.input_width + self.padding[1] * 2 - self.kernel_size[1] | |||||
| ) // self.stride[1] + 1 | |||||
| # Assume format is NCHW | |||||
| return ( | |||||
| group, | |||||
| output_height, | |||||
| output_width, | |||||
| self.in_channels // group, | |||||
| self.kernel_size[0], | |||||
| self.kernel_size[1], | |||||
| self.out_channels // group, | |||||
| ) | |||||
| def forward(self, inp): | |||||
| return local_conv2d( | |||||
| inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode | |||||
| ) | |||||
| class ConvRelu2d(Conv2d): | |||||
| r""" | |||||
| A fused :class:`~.Module` including Conv2d and relu. Could be replaced | |||||
| with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using | |||||
| :func:`~.quantize.quantize_qat`. | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return relu(self.calc_conv(inp, self.weight, self.bias)) | |||||
| @@ -0,0 +1,69 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import Tuple, Union | |||||
| from ..functional import relu | |||||
| from .batchnorm import BatchNorm2d | |||||
| from .conv import Conv2d | |||||
| from .module import Module | |||||
| class _ConvBnActivation2d(Module): | |||||
| def __init__( | |||||
| self, | |||||
| in_channels: int, | |||||
| out_channels: int, | |||||
| kernel_size: Union[int, Tuple[int, int]], | |||||
| stride: Union[int, Tuple[int, int]] = 1, | |||||
| padding: Union[int, Tuple[int, int]] = 0, | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | |||||
| groups: int = 1, | |||||
| bias: bool = True, | |||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| eps=1e-5, | |||||
| momentum=0.9, | |||||
| affine=True, | |||||
| track_running_stats=True, | |||||
| ): | |||||
| super().__init__() | |||||
| self.conv = Conv2d( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias, | |||||
| conv_mode, | |||||
| compute_mode, | |||||
| ) | |||||
| self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) | |||||
| class ConvBn2d(_ConvBnActivation2d): | |||||
| r""" | |||||
| A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced | |||||
| with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using | |||||
| :func:`~.quantize.quantize_qat`. | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return self.bn(self.conv(inp)) | |||||
| class ConvBnRelu2d(_ConvBnActivation2d): | |||||
| r""" | |||||
| A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced | |||||
| with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using | |||||
| :func:`~.quantize.quantize_qat`. | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return relu(self.bn(self.conv(inp))) | |||||
| @@ -0,0 +1,29 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ..functional import dropout | |||||
| from .module import Module | |||||
| class Dropout(Module): | |||||
| r"""Randomly set input elements to zeros with the probability :math:`drop\_prob` during training. Commonly used in large networks to prevent overfitting. | |||||
| Note that we perform dropout only during training, we also rescale(multiply) the output tensor | |||||
| by :math:`\frac{1}{1 - drop\_prob}`. During inference :class:`~.Dropout` is equal to :class:`~.Identity`. | |||||
| :param drop_prob: The probability to drop (set to zero) each single element | |||||
| """ | |||||
| def __init__(self, drop_prob=0.0): | |||||
| super().__init__() | |||||
| self.drop_prob = drop_prob | |||||
| def forward(self, inputs): | |||||
| if self.training: | |||||
| return dropout(inputs, self.drop_prob, rescale=True) | |||||
| else: | |||||
| return inputs | |||||
| @@ -0,0 +1,79 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ..core.ops._internal import param_defs as P | |||||
| from ..functional.elemwise import _elwise | |||||
| from ..tensor import Tensor | |||||
| from .module import Module | |||||
| class Elemwise(Module): | |||||
| r""" | |||||
| A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule` | |||||
| version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`. | |||||
| :param method: the elemwise method, support the following string. | |||||
| It will do the normal elemwise operator for float. | |||||
| * "ADD": a + b | |||||
| * "FUSE_ADD_RELU": max(x+y, 0) | |||||
| * "MUL": x * y | |||||
| * "MIN": min(x, y) | |||||
| * "MAX": max(x, y) | |||||
| * "SUB": x - y | |||||
| * "TRUE_DIV": x / y | |||||
| * "FUSE_ADD_SIGMOID": sigmoid(x + y) | |||||
| * "FUSE_ADD_TANH": tanh(x + y) | |||||
| * "RELU": x > 0 ? x : 0 | |||||
| * "ABS": x > 0 ? x : -x | |||||
| * "SIGMOID": sigmoid(x) | |||||
| * "EXP": exp(x) | |||||
| * "TANH": tanh(x) | |||||
| * "FUSE_MUL_ADD3": x * y + z | |||||
| * "FAST_TANH": fast_tanh(x) | |||||
| * "NEGATE": -x | |||||
| * "ACOS": acos(x) | |||||
| * "ASIN": asin(x) | |||||
| * "CEIL": ceil(x) | |||||
| * "COS": cos(x) | |||||
| * "EXPM1": expm1(x) | |||||
| * "FLOOR": floor(x) | |||||
| * "LOG": log(x) | |||||
| * "LOG1P": log1p(x) | |||||
| * "SIN": sin(x) | |||||
| * "ROUND": round(x) | |||||
| * "ERF": erf(x) | |||||
| * "ERFINV": erfinv(x) | |||||
| * "ERFC": erfc(x) | |||||
| * "ERFCINV": erfcinv(x) | |||||
| * "ABS_GRAD": abs_grad | |||||
| * "FLOOR_DIV": floor_div | |||||
| * "MOD": mod | |||||
| * "SIGMOID_GRAD": sigmoid_grad | |||||
| * "SWITCH_GT0": switch_gt0 | |||||
| * "TANH_GRAD": tanh_grad | |||||
| * "LT": lt | |||||
| * "LEQ": leq | |||||
| * "EQ": eq | |||||
| * "POW": pow | |||||
| * "LOG_SUM_EXP": log_sum_exp | |||||
| * "FAST_TANH_GRAD": fast_tanh_grad | |||||
| * "ATAN2": atan2 | |||||
| * "COND_LEQ_MOV": cond_leq_mov | |||||
| * "H_SWISH": h_swish | |||||
| * "FUSE_ADD_H_SWISH": h_swish(x+y) | |||||
| * "H_SWISH_GRAD": h_swish_grad | |||||
| """ | |||||
| _elemwise_mode_type = P.Elemwise.Mode | |||||
| def __init__(self, method): | |||||
| super().__init__() | |||||
| self.method = self._elemwise_mode_type.convert(method) | |||||
| def forward(self, *inps): | |||||
| return _elwise(*inps, mode=self.method) | |||||
| @@ -0,0 +1,171 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import Optional | |||||
| import numpy as np | |||||
| from ..functional import embedding as embedding_func | |||||
| from ..tensor_nn import Parameter | |||||
| from . import init | |||||
| from .module import Module | |||||
| class Embedding(Module): | |||||
| r""" | |||||
| A simple lookup table that stores embeddings of a fixed dictionary and size. | |||||
| This module is often used to store word embeddings and retrieve them using indices. | |||||
| The input to the module is a list of indices, and the output is the corresponding word embeddings. | |||||
| The indices should less than num_embeddings. | |||||
| :param num_embeddings: size of embedding dictionary. | |||||
| :param embedding_dim: size of each embedding vector. | |||||
| :param padding_idx: should be set to None, not support now. | |||||
| :param max_norm: should be set to None, not support now. | |||||
| :param norm_type: should be set to None, not support now. | |||||
| :param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim). | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32)) | |||||
| data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32)) | |||||
| embedding = M.Embedding(2, 5, initial_weight=weight) | |||||
| output = embedding(data) | |||||
| with np.printoptions(precision=6): | |||||
| print(output.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[1.2 2.3 3.4 4.5 5.6] | |||||
| [0.1 1.1 2.1 3.1 4.1] | |||||
| [0.1 1.1 2.1 3.1 4.1]] | |||||
| [[0.1 1.1 2.1 3.1 4.1] | |||||
| [1.2 2.3 3.4 4.5 5.6] | |||||
| [0.1 1.1 2.1 3.1 4.1]] | |||||
| [[1.2 2.3 3.4 4.5 5.6] | |||||
| [1.2 2.3 3.4 4.5 5.6] | |||||
| [0.1 1.1 2.1 3.1 4.1]]] | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| num_embeddings: int, | |||||
| embedding_dim: int, | |||||
| padding_idx: Optional[int] = None, | |||||
| max_norm: Optional[float] = None, | |||||
| norm_type: Optional[float] = None, | |||||
| initial_weight: Parameter = None, | |||||
| ): | |||||
| super().__init__() | |||||
| if padding_idx is not None: | |||||
| raise ValueError("Not support padding index now.") | |||||
| if max_norm is not None or norm_type is not None: | |||||
| raise ValueError("Not support weight normalize now.") | |||||
| self.padding_idx = padding_idx | |||||
| self.max_norm = max_norm | |||||
| self.norm_type = norm_type | |||||
| self.num_embeddings = num_embeddings | |||||
| self.embedding_dim = embedding_dim | |||||
| if initial_weight is None: | |||||
| self.weight = Parameter( | |||||
| np.random.uniform( | |||||
| size=(self.num_embeddings, self.embedding_dim) | |||||
| ).astype(np.float32) | |||||
| ) | |||||
| self.reset_parameters() | |||||
| else: | |||||
| if initial_weight.shape != (num_embeddings, embedding_dim): | |||||
| raise ValueError( | |||||
| "The weight shape should match num_embeddings and embedding_dim" | |||||
| ) | |||||
| self.weight = Parameter(initial_weight.numpy()) | |||||
| def reset_parameters(self) -> None: | |||||
| init.normal_(self.weight) | |||||
| def forward(self, inputs): | |||||
| return embedding_func(inputs, self.weight) | |||||
| @classmethod | |||||
| def from_pretrained( | |||||
| cls, | |||||
| embeddings: Parameter, | |||||
| freeze: Optional[bool] = True, | |||||
| padding_idx: Optional[int] = None, | |||||
| max_norm: Optional[float] = None, | |||||
| norm_type: Optional[float] = None, | |||||
| ): | |||||
| r""" | |||||
| Creates Embedding instance from given 2-dimensional FloatTensor. | |||||
| :param embeddings: Tensor contained weight for the embedding. | |||||
| :param freeze: If ``True``, the weight does not get updated during the learning process. Default: ``True``. | |||||
| :param padding_idx: should be set to None, not support Now. | |||||
| :param max_norm: should be set to None, not support Now. | |||||
| :param norm_type: should be set to None, not support Now. | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.module as M | |||||
| weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32)) | |||||
| data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32)) | |||||
| embedding = M.Embedding.from_pretrained(weight, freeze=False) | |||||
| output = embedding(data) | |||||
| print(output.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[1.2 2.3 3.4 4.5 5.6] | |||||
| [0.1 1.1 2.1 3.1 4.1] | |||||
| [0.1 1.1 2.1 3.1 4.1]] | |||||
| [[0.1 1.1 2.1 3.1 4.1] | |||||
| [1.2 2.3 3.4 4.5 5.6] | |||||
| [0.1 1.1 2.1 3.1 4.1]] | |||||
| [[1.2 2.3 3.4 4.5 5.6] | |||||
| [1.2 2.3 3.4 4.5 5.6] | |||||
| [0.1 1.1 2.1 3.1 4.1]]] | |||||
| """ | |||||
| embeddings_shape = embeddings.shape | |||||
| embeddings_dim = len(embeddings_shape) | |||||
| if embeddings_dim != 2: | |||||
| raise ValueError("Embeddings parameter is expected to be 2-dimensional") | |||||
| rows = embeddings_shape[0] | |||||
| cols = embeddings_shape[1] | |||||
| embedding = cls( | |||||
| num_embeddings=rows, | |||||
| embedding_dim=cols, | |||||
| initial_weight=embeddings, | |||||
| padding_idx=padding_idx, | |||||
| max_norm=max_norm, | |||||
| norm_type=norm_type, | |||||
| ) | |||||
| embedding.weight.requires_grad = not freeze | |||||
| return embedding | |||||
| @@ -0,0 +1,56 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| from ..functional import cambricon_subgraph, extern_opr_subgraph | |||||
| from .module import Module | |||||
| class CambriconSubgraph(Module): | |||||
| r"""Load a serialized Cambricon subgraph. | |||||
| See :func:`~.cambricon_subgraph` for more details. | |||||
| """ | |||||
| def __init__( | |||||
| self, data, symbol, tensor_dim_mutable, | |||||
| ): | |||||
| super(CambriconSubgraph, self).__init__() | |||||
| self._data = data | |||||
| self.symbol = symbol | |||||
| self.tensor_dim_mutable = tensor_dim_mutable | |||||
| @property | |||||
| def data(self): | |||||
| return self._data.tobytes() | |||||
| @data.setter | |||||
| def data(self, val): | |||||
| self._data = np.frombuffer(val, dtype=np.uint8) | |||||
| def forward(self, inputs): | |||||
| outputs = cambricon_subgraph( | |||||
| inputs, self._data, self.symbol, self.tensor_dim_mutable, | |||||
| ) | |||||
| return outputs | |||||
| class ExternOprSubgraph(Module): | |||||
| r"""Load a serialized extern opr subgraph. | |||||
| """ | |||||
| def __init__(self, data, name, output_shapes): | |||||
| super(ExternOprSubgraph, self).__init__() | |||||
| self.data = data | |||||
| self.name = name | |||||
| self.output_shapes = output_shapes | |||||
| def forward(self, inputs): | |||||
| outputs = extern_opr_subgraph(inputs, self.output_shapes, self.name, self.data,) | |||||
| return outputs | |||||
| @@ -0,0 +1,17 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ..functional import identity | |||||
| from .module import Module | |||||
| class Identity(Module): | |||||
| r"""A placeholder identity operator that will ignore any argument.""" | |||||
| def forward(self, x): | |||||
| return identity(x) | |||||
| @@ -0,0 +1,261 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import math | |||||
| from functools import reduce | |||||
| from typing import Optional, Tuple, Union | |||||
| import numpy as np | |||||
| from ..tensor import Tensor | |||||
| def fill_(tensor: Tensor, val: Union[float, int]) -> None: | |||||
| """Fill the given ``tensor`` with value ``val``. | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| :param val: The value to be filled throughout the tensor | |||||
| """ | |||||
| tensor.set_value(np.full(tensor.shape, val, tensor.dtype)) | |||||
| def zeros_(tensor: Tensor) -> None: | |||||
| """Fill the given ``tensor`` with scalar value `0`. | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| """ | |||||
| fill_(tensor, 0) | |||||
| def ones_(tensor: Tensor) -> None: | |||||
| """Fill the given ``tensor`` with the scalar value `1`. | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| """ | |||||
| fill_(tensor, 1) | |||||
| def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: | |||||
| r"""Fill the given ``tensor`` with random value sampled from uniform distribution | |||||
| :math:`\mathcal{U}(\text{a}, \text{b})`. | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| :param a: Lower bound of the sampling interval | |||||
| :param b: Upper bound of the sampling interval | |||||
| """ | |||||
| tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype)) | |||||
| def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: | |||||
| r"""Fill the given ``tensor`` with random value sampled from normal distribution | |||||
| :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| :param mean: The mean of the normal distribution | |||||
| :param std: The standard deviation of the normal distribution | |||||
| """ | |||||
| tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(np.float32)) | |||||
| def calculate_gain( | |||||
| nonlinearity: str, param: Optional[Union[int, float]] = None | |||||
| ) -> float: | |||||
| r"""Return a recommended gain value (see the table below) for the given nonlinearity | |||||
| function. | |||||
| ================= ==================================================== | |||||
| nonlinearity gain | |||||
| ================= ==================================================== | |||||
| Linear / Identity :math:`1` | |||||
| Conv{1,2,3}D :math:`1` | |||||
| Sigmoid :math:`1` | |||||
| Tanh :math:`\frac{5}{3}` | |||||
| ReLU :math:`\sqrt{2}` | |||||
| Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative_{slope}}^2}}` | |||||
| ================= ==================================================== | |||||
| :param nonlinearity: Name of the non-linear function | |||||
| :param param: Optional parameter for leaky_relu. Only effective when | |||||
| ``nonlinearity`` is "leaky_relu". | |||||
| """ | |||||
| linear_fns = [ | |||||
| "linear", | |||||
| "conv1d", | |||||
| "conv2d", | |||||
| "conv3d", | |||||
| "conv_transpose1d", | |||||
| "conv_transpose2d", | |||||
| "conv_transpose3d", | |||||
| ] | |||||
| if nonlinearity in linear_fns or nonlinearity == "sigmoid": | |||||
| return 1 | |||||
| if nonlinearity == "tanh": | |||||
| return 5.0 / 3 | |||||
| if nonlinearity == "relu": | |||||
| return math.sqrt(2.0) | |||||
| if nonlinearity == "leaky_relu": | |||||
| if param is None: | |||||
| negative_slope = 0.01 | |||||
| elif ( | |||||
| not isinstance(param, bool) | |||||
| and isinstance(param, int) | |||||
| or isinstance(param, float) | |||||
| ): | |||||
| # True/False are instances of int, hence check above | |||||
| negative_slope = param | |||||
| else: | |||||
| raise ValueError("negative_slope {} not a valid number".format(param)) | |||||
| return math.sqrt(2.0 / (1 + negative_slope ** 2)) | |||||
| raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) | |||||
| def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: | |||||
| """ | |||||
| Calculate fan_in / fan_out value for given weight tensor. This function assumes | |||||
| input tensor is stored in NCHW format. | |||||
| :param tensor: Weight tensor in NCHW format | |||||
| """ | |||||
| shape = tensor.shape | |||||
| ndim = len(shape) | |||||
| if ndim < 2: | |||||
| raise ValueError( | |||||
| "fan_in and fan_out can not be computed for tensor with fewer than 2 " | |||||
| "dimensions" | |||||
| ) | |||||
| if ndim == 2: # Linear | |||||
| fan_in = shape[1] | |||||
| fan_out = shape[0] | |||||
| else: | |||||
| num_input_fmaps = shape[1] | |||||
| num_output_fmaps = shape[0] | |||||
| receptive_field_size = 1 | |||||
| if ndim > 2: | |||||
| receptive_field_size = reduce(lambda x, y: x * y, shape[2:], 1) | |||||
| fan_in = num_input_fmaps * receptive_field_size | |||||
| fan_out = num_output_fmaps * receptive_field_size | |||||
| return fan_in, fan_out | |||||
| def calculate_correct_fan(tensor: Tensor, mode: str) -> float: | |||||
| """ | |||||
| Calculate fan_in or fan_out value for given weight tensor, depending on given | |||||
| ``mode``. | |||||
| See :func:`calculate_fan_in_and_fan_out` for details. | |||||
| :param tensor: Weight tensor in NCHW format | |||||
| :param mode: ``'fan_in'`` or ``'fan_out'`` | |||||
| """ | |||||
| mode = mode.lower() | |||||
| valid_modes = ["fan_in", "fan_out"] | |||||
| if mode not in valid_modes: | |||||
| raise ValueError( | |||||
| "Mode {} not supported, please use one of {}".format(mode, valid_modes) | |||||
| ) | |||||
| fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||||
| return fan_in if mode == "fan_in" else fan_out | |||||
| def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: | |||||
| r"""Fill ``tensor`` with random values sampled from :math:`\mathcal{U}(-a, a)` | |||||
| where | |||||
| .. math:: | |||||
| a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} | |||||
| Also known as Glorot initialization. Detailed information can be retrieved from | |||||
| `Understanding the difficulty of training deep feedforward neural networks` - | |||||
| Glorot, X. & Bengio, Y. (2010). | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| :param gain: Scaling factor for :math:`a`. | |||||
| """ | |||||
| fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||||
| std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) | |||||
| a = math.sqrt(3.0) * std | |||||
| uniform_(tensor, -a, a) | |||||
| def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: | |||||
| r"""Fill ``tensor`` with random values sampled from | |||||
| :math:`\mathcal{N}(0, \text{std}^2)` where | |||||
| .. math:: | |||||
| \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} | |||||
| Also known as Glorot initialization. Detailed information can be retrieved from | |||||
| `Understanding the difficulty of training deep feedforward neural networks` - | |||||
| Glorot, X. & Bengio, Y. (2010). | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| :param gain: Scaling factor for :math:`std`. | |||||
| """ | |||||
| fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||||
| std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) | |||||
| normal_(tensor, 0.0, std) | |||||
| def msra_uniform_( | |||||
| tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" | |||||
| ) -> None: | |||||
| r"""Fill ``tensor`` wilth random values sampled from | |||||
| :math:`\mathcal{U}(-\text{bound}, \text{bound})` where | |||||
| .. math:: | |||||
| \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}} | |||||
| Detailed information can be retrieved from | |||||
| `Delving deep into rectifiers: Surpassing human-level performance on ImageNet | |||||
| classification` | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| :param a: Optional parameter for calculating gain for leaky_relu. See | |||||
| :func:`calculate_gain` for details. | |||||
| :param mode: ``'fan_in'`` or ``'fan_out'``, used to calculate :math:`gain`, the | |||||
| scaling factor for :math:`bound`. See :func:`calculate_fan_in_and_fan_out` for | |||||
| details. | |||||
| :param nonlinearity: Name of the non-linear function used to calculate :math:`gain`. | |||||
| See :func:`calculate_gain` for details. | |||||
| """ | |||||
| fan = calculate_correct_fan(tensor, mode) | |||||
| gain = calculate_gain(nonlinearity, a) | |||||
| std = gain / math.sqrt(fan) | |||||
| bound = math.sqrt(3.0) * std | |||||
| uniform_(tensor, -bound, bound) | |||||
| def msra_normal_( | |||||
| tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" | |||||
| ) -> None: | |||||
| r"""Fill ``tensor`` wilth random values sampled from | |||||
| :math:`\mathcal{N}(0, \text{std}^2)` where | |||||
| .. math:: | |||||
| \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}} | |||||
| Detailed information can be retrieved from | |||||
| `Delving deep into rectifiers: Surpassing human-level performance on ImageNet | |||||
| classification` | |||||
| :param tensor: An n-dimentional tensor to be initialized | |||||
| :param a: Optional parameter for calculating gain for leaky_relu. See | |||||
| :func:`calculate_gain` for details. | |||||
| :param mode: ``'fan_in'`` or ``'fan_out'``, used to calculate :math:`gain`, the | |||||
| scaling factor for :math:`gain`. See :func:`calculate_fan_in_and_fan_out` for | |||||
| details. | |||||
| :param nonlinearity: Name of the non-linear function used to calculate :math:`gain`. | |||||
| See :func:`calculate_gain` for details. | |||||
| """ | |||||
| fan = calculate_correct_fan(tensor, mode) | |||||
| gain = calculate_gain(nonlinearity, a) | |||||
| std = gain / math.sqrt(fan) | |||||
| normal_(tensor, 0, std) | |||||
| @@ -0,0 +1,61 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| from ..functional import linear | |||||
| from ..tensor_nn import Parameter | |||||
| from . import init | |||||
| from .module import Module | |||||
| class Linear(Module): | |||||
| r"""Applies a linear transformation to the input. For instance, if input | |||||
| is x, then output y is: | |||||
| .. math:: | |||||
| y = xW^T + b | |||||
| where :math:`y_i= \sum_j W_{ij} x_j + b_i` | |||||
| :param in_features: size of each input sample. | |||||
| :param out_features: size of each output sample. | |||||
| :param bias: If set to ``False``, the layer will not learn an additive bias. | |||||
| Default: ``True`` | |||||
| """ | |||||
| def __init__( | |||||
| self, in_features: int, out_features: int, bias: bool = True, **kwargs | |||||
| ): | |||||
| super().__init__(**kwargs) | |||||
| self.out_features = out_features | |||||
| self.in_features = in_features | |||||
| w_shape = (out_features, in_features) | |||||
| self.weight = Parameter(np.zeros(w_shape, dtype=np.float32)) | |||||
| self.bias = None | |||||
| if bias: | |||||
| b_shape = (out_features,) | |||||
| self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) | |||||
| self.reset_parameters() | |||||
| def _get_fanin(self): | |||||
| return self.in_features | |||||
| def reset_parameters(self) -> None: | |||||
| fanin = self._get_fanin() | |||||
| std = np.sqrt(1 / fanin) | |||||
| init.normal_(self.weight, 0.0, std) | |||||
| if self.bias is not None: | |||||
| init.zeros_(self.bias) | |||||
| def _calc_linear(self, x, weight, bias): | |||||
| return linear(x, weight, bias) | |||||
| def forward(self, x): | |||||
| return self._calc_linear(x, self.weight, self.bias) | |||||
| @@ -0,0 +1,508 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from abc import ABCMeta, abstractmethod | |||||
| from collections import OrderedDict | |||||
| from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||||
| import numpy as np | |||||
| from ..core.tensor.dtype import is_quantize | |||||
| from ..logger import get_logger | |||||
| from ..tensor import Tensor | |||||
| from ..tensor_nn import Buffer, Parameter | |||||
| from ..utils.hook import HookHandler | |||||
| logger = get_logger(__name__) | |||||
| def _expand_structure(key, obj): | |||||
| if isinstance(obj, (Tensor, Module)): | |||||
| return [(key, obj)] | |||||
| elif isinstance(obj, (list, tuple, dict)): | |||||
| ret = [] | |||||
| if isinstance(obj, dict): | |||||
| targets = ((k, obj[k]) for k in sorted(obj)) | |||||
| else: | |||||
| targets = ((str(k), v) for k, v in enumerate(obj)) | |||||
| for k, o in targets: | |||||
| sub_ret = _expand_structure(k, o) | |||||
| if sub_ret and not isinstance(k, str): | |||||
| raise AssertionError( | |||||
| "keys for Tensor and Module must be str, error key: {}".format(k) | |||||
| ) | |||||
| for kt, vt in sub_ret: | |||||
| ret.extend([(key + "." + kt, vt)]) | |||||
| return ret | |||||
| else: | |||||
| return [] | |||||
| def _is_parameter(obj): | |||||
| return isinstance(obj, Parameter) | |||||
| def _is_buffer(obj): | |||||
| return isinstance(obj, Buffer) | |||||
| def _is_module(obj): | |||||
| return isinstance(obj, Module) | |||||
| class Module(metaclass=ABCMeta): | |||||
| """Base Module class. | |||||
| """ | |||||
| def __init__(self): | |||||
| # runtime attributes | |||||
| self.training = True | |||||
| self.quantize_disabled = False | |||||
| # hooks | |||||
| self._forward_pre_hooks = OrderedDict() | |||||
| self._forward_hooks = OrderedDict() | |||||
| @abstractmethod | |||||
| def forward(self, inputs): | |||||
| pass | |||||
| def register_forward_pre_hook(self, hook: Callable) -> HookHandler: | |||||
| """Register a hook to handle forward inputs. `hook` should be a function | |||||
| Note that `inputs` keyword inputs | |||||
| :param hook: a function that receive `module` and `inputs`, then return | |||||
| a modified `inputs` or `None`. | |||||
| :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook. | |||||
| """ | |||||
| return HookHandler(self._forward_pre_hooks, hook) | |||||
| def register_forward_hook(self, hook: Callable) -> HookHandler: | |||||
| """Register a hook to handle forward results. `hook` should be a function that | |||||
| receive `module`, `inputs` and `outputs`, then return a modified `outputs` or `None`. | |||||
| This method return a handler with :meth:`~.HookHandler.remove` interface to delete the hook. | |||||
| """ | |||||
| return HookHandler(self._forward_hooks, hook) | |||||
| def __call__(self, *inputs, **kwargs): | |||||
| for hook in self._forward_pre_hooks.values(): | |||||
| modified_inputs = hook(self, inputs) | |||||
| if modified_inputs is not None: | |||||
| if not isinstance(modified_inputs, tuple): | |||||
| modified_inputs = (modified_inputs,) | |||||
| inputs = modified_inputs | |||||
| outputs = self.forward(*inputs, **kwargs) | |||||
| for hook in self._forward_hooks.values(): | |||||
| modified_outputs = hook(self, inputs, outputs) | |||||
| if modified_outputs is not None: | |||||
| outputs = modified_outputs | |||||
| return outputs | |||||
| def _flatten( | |||||
| self, | |||||
| *, | |||||
| recursive: bool = True, | |||||
| with_key: bool = False, | |||||
| with_parent: bool = False, | |||||
| prefix: Optional[str] = None, | |||||
| predicate: Callable[[Any], bool] = lambda _: True, | |||||
| seen: Optional[Set[int]] = None | |||||
| ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: | |||||
| """Scans the module object and returns an iterable for the :class:`~.Tensor` | |||||
| and :class:`~.Module` attributes that agree with the ``predicate``. For multiple | |||||
| calls of this function with same arguments, the order of objects within the | |||||
| returned iterable is guaranteed to be identical, as long as all the involved | |||||
| module objects' ``__dict__`` does not change thoughout those calls. | |||||
| :param recursive: Whether to recursively scan all the submodules. | |||||
| :param with_key: Whether to yield keys along with yielded objects. | |||||
| :param with_parent: Whether to yield ``self`` along with yielded objects. | |||||
| :param prefix: The prefix appended to the yielded keys. | |||||
| :param predicate: The predicate function applied to scanned objects. | |||||
| :param seen: A dict that records whether a module has been traversed yet. | |||||
| """ | |||||
| if seen is None: | |||||
| seen = set([id(self)]) | |||||
| module_dict = vars(self) | |||||
| _prefix = "" if prefix is None else prefix + "." | |||||
| for key in sorted(module_dict): | |||||
| for expanded_key, leaf in _expand_structure(key, module_dict[key]): | |||||
| leaf_id = id(leaf) | |||||
| if leaf_id in seen: | |||||
| continue | |||||
| seen.add(leaf_id) | |||||
| if predicate(leaf): | |||||
| if with_key and with_parent: | |||||
| yield _prefix + expanded_key, leaf, self | |||||
| elif with_key: | |||||
| yield _prefix + expanded_key, leaf | |||||
| elif with_parent: | |||||
| yield leaf, self | |||||
| else: | |||||
| yield leaf | |||||
| if recursive and isinstance(leaf, Module): | |||||
| yield from leaf._flatten( | |||||
| recursive=recursive, | |||||
| with_key=with_key, | |||||
| with_parent=with_parent, | |||||
| prefix=_prefix + expanded_key if with_key else None, | |||||
| predicate=predicate, | |||||
| seen=seen, | |||||
| ) | |||||
| def parameters( | |||||
| self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs | |||||
| ) -> Iterable[Parameter]: | |||||
| r"""Returns an iterable for the :class:`~.Parameter` of the module. | |||||
| :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||||
| attribute of returned :class:`.Parameter`. ``None`` for no limitation. | |||||
| :param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||||
| module, else only returns :class:`~.Parameter` that are direct attributes | |||||
| of this module. | |||||
| """ | |||||
| def predicate(obj) -> bool: | |||||
| return _is_parameter(obj) and ( | |||||
| requires_grad is None or obj.requires_grad == requires_grad | |||||
| ) | |||||
| yield from self._flatten( | |||||
| with_key=False, predicate=predicate, recursive=recursive, **kwargs | |||||
| ) | |||||
| def named_parameters( | |||||
| self, | |||||
| requires_grad: Optional[bool] = None, | |||||
| prefix: Optional[str] = None, | |||||
| recursive: bool = True, | |||||
| **kwargs | |||||
| ) -> Iterable[Tuple[str, Parameter]]: | |||||
| """Returns an iterable for key :class:`~.Parameter` pairs of the module, where | |||||
| ``key`` is the dotted path from this module to the :class:`~.Parameter` . | |||||
| :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||||
| attribute of returned :class:`~.Parameter` . ``None`` for no limitation. | |||||
| :param prefix: The prefix prepended to the keys. | |||||
| :param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||||
| module, else only returns :class:`~.Parameter` that are direct attributes | |||||
| of this module. | |||||
| """ | |||||
| def predicate(obj) -> bool: | |||||
| return _is_parameter(obj) and ( | |||||
| requires_grad is None or obj.requires_grad == requires_grad | |||||
| ) | |||||
| yield from self._flatten( | |||||
| with_key=True, | |||||
| prefix=prefix, | |||||
| predicate=predicate, | |||||
| recursive=recursive, | |||||
| **kwargs, | |||||
| ) | |||||
| def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]: | |||||
| """Returns an iterable for the :class:`~.Buffer` of the module. | |||||
| :param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||||
| module, else only returns :class:`~.Buffer` that are direct attributes | |||||
| of this module. | |||||
| """ | |||||
| yield from self._flatten( | |||||
| with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs | |||||
| ) | |||||
| def named_buffers( | |||||
| self, prefix: Optional[str] = None, recursive: bool = True, **kwargs | |||||
| ) -> Iterable[Tuple[str, Buffer]]: | |||||
| """Returns an iterable for key :class:`~.Buffer` pairs of the module, where | |||||
| ``key`` is the dotted path from this module to the :class:`~.Buffer` . | |||||
| :param prefix: The prefix prepended to the keys. | |||||
| :param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||||
| module, else only returns :class:`~.Buffer` that are direct attributes | |||||
| of this module. | |||||
| """ | |||||
| yield from self._flatten( | |||||
| with_key=True, | |||||
| prefix=prefix, | |||||
| predicate=_is_buffer, | |||||
| recursive=recursive, | |||||
| **kwargs, | |||||
| ) | |||||
| def children(self, **kwargs) -> "Iterable[Module]": | |||||
| """Returns an iterable for all the submodules that are direct attributes of this | |||||
| module. | |||||
| """ | |||||
| yield from self._flatten( | |||||
| with_key=False, predicate=_is_module, recursive=False, **kwargs | |||||
| ) | |||||
| def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]": | |||||
| """Returns an iterable of key-submodule pairs for all the submodules that are | |||||
| direct attributes of this module, where 'key' is the attribute name of | |||||
| submodules. | |||||
| """ | |||||
| yield from self._flatten( | |||||
| with_key=True, predicate=_is_module, recursive=False, **kwargs | |||||
| ) | |||||
| def modules(self, **kwargs) -> "Iterable[Module]": | |||||
| """Returns an iterable for all the modules within this module, including itself. | |||||
| """ | |||||
| if "with_parent" in kwargs and kwargs["with_parent"]: | |||||
| yield self, None | |||||
| else: | |||||
| yield self | |||||
| yield from self._flatten(with_key=False, predicate=_is_module, **kwargs) | |||||
| def named_modules( | |||||
| self, prefix: Optional[str] = None, **kwargs | |||||
| ) -> "Iterable[Tuple[str, Module]]": | |||||
| """Returns an iterable of key-module pairs for all the modules within this | |||||
| module, including itself, where 'key' is the dotted path from this module to the | |||||
| submodules. | |||||
| :param prefix: The prefix prepended to the path. | |||||
| """ | |||||
| if "with_parent" in kwargs and kwargs["with_parent"]: | |||||
| yield ("" if prefix is None else prefix), self, None | |||||
| else: | |||||
| yield ("" if prefix is None else prefix), self | |||||
| yield from self._flatten( | |||||
| with_key=True, prefix=prefix, predicate=_is_module, **kwargs | |||||
| ) | |||||
| def apply(self, fn: "Callable[[Module], Any]") -> None: | |||||
| """Apply function ``fn`` to all the modules within this module, including | |||||
| itself. | |||||
| :param fn: The function to be applied on modules. | |||||
| """ | |||||
| for it in self.modules(): | |||||
| fn(it) | |||||
| def zero_grad(self) -> None: | |||||
| """Set all parameters' grads to zero | |||||
| """ | |||||
| for param in self.parameters(): | |||||
| if param.grad is not None: | |||||
| param.grad.reset_zero() | |||||
| def train(self, mode: bool = True, recursive: bool = True) -> None: | |||||
| """Set training mode of all the modules within this module (including itself) to | |||||
| ``mode``. This effectively sets the ``training`` attributes of those modules | |||||
| to ``mode``, but only has effect on certain modules (e.g. | |||||
| :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`) | |||||
| :param mode: the training mode to be set on modules. | |||||
| :param recursive: whether to recursively call submodules' ``train()``. | |||||
| """ | |||||
| if not recursive: | |||||
| self.training = mode | |||||
| return | |||||
| def fn(module: Module) -> None: | |||||
| module.train(mode, recursive=False) | |||||
| self.apply(fn) | |||||
| def eval(self) -> None: | |||||
| """Set training mode of all the modules within this module (including itself) to | |||||
| ``False``. See :meth:`~.Module.train` for details. | |||||
| """ | |||||
| self.train(False) | |||||
| def disable_quantize(self, value=True): | |||||
| r""" | |||||
| Set ``module``'s ``quantize_disabled`` attribute and return ``module``. | |||||
| Could be used as a decorator. | |||||
| """ | |||||
| def fn(module: Module) -> None: | |||||
| module.quantize_disabled = value | |||||
| self.apply(fn) | |||||
| def replace_param( | |||||
| self, params: dict, start_pos: int, seen: Optional[Set[int]] = None | |||||
| ): | |||||
| """Replace module's parameters with `params`, used by :class:`~.ParamPack` to | |||||
| speedup multimachine training. | |||||
| """ | |||||
| offset = 0 | |||||
| if seen is None: | |||||
| seen = set([id(self)]) | |||||
| module_dict = vars(self) | |||||
| for key in sorted(module_dict): | |||||
| hash_id = id(module_dict[key]) | |||||
| if hash_id in seen: | |||||
| continue | |||||
| seen.add(hash_id) | |||||
| if isinstance(module_dict[key], Parameter): | |||||
| if start_pos + offset in params: | |||||
| assert module_dict[key].shape == params[start_pos + offset].shape | |||||
| module_dict[key] = params[start_pos + offset] | |||||
| offset += 1 | |||||
| if isinstance(module_dict[key], Module): | |||||
| offset += module_dict[key].replace_param( | |||||
| params, start_pos + offset, seen | |||||
| ) | |||||
| return offset | |||||
| def state_dict(self, rst=None, prefix="", keep_var=False): | |||||
| r"""Returns a dictionary containing whole states of the module. | |||||
| """ | |||||
| def is_state(obj): | |||||
| return _is_parameter(obj) or _is_buffer(obj) | |||||
| if rst is None: | |||||
| rst = OrderedDict() | |||||
| for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state): | |||||
| assert prefix + k not in rst, "duplicated state: {}".format(k) | |||||
| if keep_var: | |||||
| rst[prefix + k] = v | |||||
| else: | |||||
| rst[prefix + k] = v.numpy() | |||||
| for k, submodule in self._flatten( | |||||
| recursive=False, | |||||
| with_key=True, | |||||
| predicate=lambda obj: isinstance(obj, Module), | |||||
| ): | |||||
| submodule.state_dict(rst, prefix + k + ".", keep_var) | |||||
| return rst | |||||
| def load_state_dict( | |||||
| self, | |||||
| state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]], | |||||
| strict=True, | |||||
| ): | |||||
| r"""Load a given dictionary created by :func:`state_dict` into this module. | |||||
| If ``strict`` is ``True``, the keys of :func:`state_dict` must exactly match the keys | |||||
| returned by :func:`state_dict`. | |||||
| Users can also pass a closure: `Function[key: str, var: Tensor] -> Optional[np.ndarray]` | |||||
| as a `state_dict`, in order to handle complex situations. For example, load everything | |||||
| except for the final linear classifier: | |||||
| .. code-block:: | |||||
| state_dict = {...} # Dict[str, np.ndarray] | |||||
| model.load_state_dict({ | |||||
| k: None if k.startswith('fc') else v | |||||
| for k, v in state_dict.items() | |||||
| }, strict=False) | |||||
| Here returning `None` means skipping parameter `k`. | |||||
| To prevent shape mismatch (e.g. load PyTorch weights), we can reshape before loading: | |||||
| .. code-block:: | |||||
| state_dict = {...} | |||||
| def reshape_accordingly(k, v): | |||||
| return state_dict[k].reshape(v.shape) | |||||
| model.load_state_dict(reshape_accordingly) | |||||
| We can also perform inplace re-initialization or pruning: | |||||
| .. code-block:: | |||||
| def reinit_and_pruning(k, v): | |||||
| if 'bias' in k: | |||||
| M.init.zero_(v) | |||||
| if 'conv' in k: | |||||
| return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32) | |||||
| model.load_state_dict(reinit_and_pruning, strict=False) | |||||
| """ | |||||
| unused = [] | |||||
| if isinstance(state_dict, dict): | |||||
| unused = state_dict.keys() | |||||
| def closure(k, _): # var unused | |||||
| return state_dict[k] if k in state_dict else None | |||||
| elif callable(state_dict): | |||||
| closure = state_dict | |||||
| else: | |||||
| raise ValueError( | |||||
| "`state_dict` must load a dict or callable, got {}".format( | |||||
| type(state_dict) | |||||
| ) | |||||
| ) | |||||
| loaded, skipped = self._load_state_dict_with_closure(closure) | |||||
| unused = set(unused) - loaded | |||||
| if len(unused) != 0: | |||||
| if strict: | |||||
| raise KeyError( | |||||
| "Unused params violate `strict=True`, unused={}".format(unused) | |||||
| ) | |||||
| else: | |||||
| logger.warning( | |||||
| "Unused params in `strict=False` mode, unused={}".format(unused) | |||||
| ) | |||||
| if len(skipped) != 0: | |||||
| if strict: | |||||
| raise KeyError( | |||||
| "Missing params violate `strict=True`, missing={}".format(skipped) | |||||
| ) | |||||
| else: | |||||
| logger.warning( | |||||
| "Missing params in `strict=False` mode, missing={}".format(skipped) | |||||
| ) | |||||
| def _load_state_dict_with_closure(self, closure): | |||||
| """Advance state_dict load through callable `closure` whose signature is | |||||
| `closure(key: str, var: Tensor) -> Union[np.ndarry, None]` | |||||
| """ | |||||
| assert callable(closure), "closure must be a function" | |||||
| loaded = [] | |||||
| skipped = [] | |||||
| local_state_dict = self.state_dict(keep_var=True) | |||||
| for k, var in local_state_dict.items(): | |||||
| to_be_load = closure(k, var) | |||||
| if to_be_load is None: | |||||
| skipped.append(k) | |||||
| continue | |||||
| assert isinstance( | |||||
| to_be_load, np.ndarray | |||||
| ), "closure should return a `np.ndarray`, now `{}` get {}".format( | |||||
| k, to_be_load | |||||
| ) | |||||
| assert ( | |||||
| var.shape == to_be_load.shape | |||||
| ), "param `{}` shape mismatch, should be {}, get {}".format( | |||||
| k, var.shape, to_be_load.shape | |||||
| ) | |||||
| # For quantized dtype, the initialized dtype | |||||
| # scale/zero_points maybe invalid, use pretrained dtype instead. | |||||
| if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||||
| var = var.astype(to_be_load.dtype) | |||||
| var.set_value(to_be_load) | |||||
| loaded.append(k) | |||||
| return set(loaded), set(skipped) | |||||
| @@ -0,0 +1,156 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| from typing import Callable, Iterable, Optional, Tuple | |||||
| import numpy as np | |||||
| from ..tensor_nn import Parameter, Tensor | |||||
| from .module import Module | |||||
| class ParamPack(Module): | |||||
| r"""Pack module's parameters by gathering their memory to continuous address. | |||||
| Using (device, dtype, requires_grad) as key, for example ('gpu0', float32, True), | |||||
| parameters with same key will be packed togather. | |||||
| It helps a lot for multimachine training by speeding up allreduce gradients. | |||||
| :param model: the module you want to pack parameters. | |||||
| :param nr_ignore_first: how many parameters will be unpacked at first. | |||||
| :param max_size_per_group: upper bound of packed parameters' size in MB. | |||||
| :param max_nr_params_per_group: upper bound of the number of parameters of each group. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| model: Module, | |||||
| nr_ignore_first: int = 8, | |||||
| max_size_per_group: int = 10, | |||||
| max_nr_params_per_group: int = 100, | |||||
| group_func: Callable = lambda name, param: 0, | |||||
| ): | |||||
| super().__init__() | |||||
| self._model = model | |||||
| self._nr_ignore_first = nr_ignore_first | |||||
| self._max_size_per_group = max_size_per_group | |||||
| self._max_nr_params_per_group = max_nr_params_per_group | |||||
| self._group_func = group_func | |||||
| self._grouped_params = [] | |||||
| self._packed_params = [] | |||||
| params = model.named_parameters() | |||||
| self._pack_params(params) | |||||
| def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]: | |||||
| for param in self._packed_params: | |||||
| if requires_grad is None or param.requires_grad == requires_grad: | |||||
| yield param | |||||
| def named_parameters( | |||||
| self, requires_grad: Optional[bool] = None | |||||
| ) -> Iterable[Tuple[str, Parameter]]: | |||||
| for idx, param in enumerate(self._packed_params): | |||||
| if requires_grad is None or param.requires_grad == requires_grad: | |||||
| yield "packed_param_" + str(idx), param | |||||
| def _pack_params(self, params: Iterable[Tuple[str, Parameter]]): | |||||
| groups = collections.defaultdict(list) | |||||
| ignored = 0 | |||||
| param_id = 0 | |||||
| for name, param in params: | |||||
| if self._nr_ignore_first > ignored: | |||||
| ignored += 1 | |||||
| self._grouped_params.append([{"shape": param.shape, "id": param_id}]) | |||||
| param.pack_group_key = self._group_func(name, param) | |||||
| self._packed_params.append(param) | |||||
| else: | |||||
| key = ( | |||||
| param.dtype, | |||||
| param.device, | |||||
| param.requires_grad, | |||||
| self._group_func(name, param), | |||||
| ) | |||||
| groups[key].append({"tensor": param, "id": param_id}) | |||||
| param_id += 1 | |||||
| for (dtype, device, requires_grad, group_key) in groups.keys(): | |||||
| dtype_sz = np.dtype(dtype).itemsize | |||||
| align = device.mem_align | |||||
| if align < dtype_sz: | |||||
| align = 1 | |||||
| else: | |||||
| assert align % dtype_sz == 0 | |||||
| align //= dtype_sz | |||||
| group = groups[(dtype, device, requires_grad, group_key)] | |||||
| while group: | |||||
| aligned_pos = [] | |||||
| offset = 0 | |||||
| params = [] | |||||
| idx = 0 | |||||
| while idx < len(group): | |||||
| param = group[idx] | |||||
| assert param["tensor"].device == device | |||||
| padding = (align - (offset & (align - 1))) & (align - 1) | |||||
| offset += padding | |||||
| aligned_pos.append(offset) | |||||
| params.append(param) | |||||
| offset += int(np.prod(param["tensor"].shape)) | |||||
| idx += 1 | |||||
| if ( | |||||
| offset * dtype_sz >= self._max_size_per_group * 1024 * 1024 | |||||
| or idx >= self._max_nr_params_per_group | |||||
| ): | |||||
| break | |||||
| group = group[idx:] | |||||
| if idx == 1: | |||||
| # ignore param packs with only one item | |||||
| params[0]["tensor"].pack_group_key = group_key | |||||
| self._packed_params.append(params[0]["tensor"]) | |||||
| self._grouped_params.append( | |||||
| [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}] | |||||
| ) | |||||
| continue | |||||
| packed_value = np.zeros((offset,), dtype=dtype) | |||||
| for param, pos in zip(params, aligned_pos): | |||||
| val = param["tensor"].numpy() | |||||
| packed_value[pos : pos + val.size] = val.flatten() | |||||
| new_param = Parameter( | |||||
| value=packed_value, | |||||
| device=device, | |||||
| dtype=dtype, | |||||
| requires_grad=requires_grad, | |||||
| ) | |||||
| new_param.pack_group_key = group_key | |||||
| self._packed_params.append(new_param) | |||||
| self._grouped_params.append( | |||||
| [{"shape": i["tensor"].shape, "id": i["id"]} for i in params] | |||||
| ) | |||||
| def forward(self, *args, **kwargs): | |||||
| replace_param = dict() | |||||
| for i in range(len(self._packed_params)): | |||||
| packed_param = self._packed_params[i] | |||||
| grouped_params = self._grouped_params[i] | |||||
| if len(grouped_params) == 1: | |||||
| continue | |||||
| split = param_pack_split( | |||||
| packed_param._symvar, [i["shape"] for i in grouped_params] | |||||
| ) | |||||
| split = [ | |||||
| Parameter(Tensor(i, requires_grad=packed_param.requires_grad)) | |||||
| for i in split | |||||
| ] | |||||
| for j in range(len(split)): | |||||
| replace_param[grouped_params[j]["id"]] = split[j] | |||||
| self._model.replace_param(replace_param, 0) | |||||
| return self._model.forward(*args, **kwargs) | |||||
| @@ -0,0 +1,80 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from abc import abstractmethod | |||||
| from typing import Tuple, Union | |||||
| from ..functional import avg_pool2d, max_pool2d | |||||
| from .module import Module | |||||
| class _PoolNd(Module): | |||||
| def __init__( | |||||
| self, | |||||
| kernel_size: Union[int, Tuple[int, int]], | |||||
| stride: Union[int, Tuple[int, int]] = None, | |||||
| padding: Union[int, Tuple[int, int]] = 0, | |||||
| ): | |||||
| super(_PoolNd, self).__init__() | |||||
| self.kernel_size = kernel_size | |||||
| self.stride = stride or kernel_size | |||||
| self.padding = padding | |||||
| @abstractmethod | |||||
| def forward(self, inp): | |||||
| pass | |||||
| class MaxPool2d(_PoolNd): | |||||
| r"""Applies a 2D max pooling over an input. | |||||
| For instance, given an input of the size :math:`(N, C, H, W)` and | |||||
| :attr:`kernel_size` :math:`(kH, kW)`, this layer generates the output of | |||||
| the size :math:`(N, C, H_{out}, W_{out})` through a process described as: | |||||
| .. math:: | |||||
| \begin{aligned} | |||||
| out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} | |||||
| \text{input}(N_i, C_j, \text{stride[0]} \times h + m, | |||||
| \text{stride[1]} \times w + n) | |||||
| \end{aligned} | |||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on | |||||
| both sides for :attr:`padding` number of points. | |||||
| :param kernel_size: the size of the window to take a max over. | |||||
| :param stride: the stride of the window. Default value is ``kernel_size``. | |||||
| :param padding: implicit zero padding to be added on both sides. | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return max_pool2d(inp, self.kernel_size, self.stride, self.padding) | |||||
| class AvgPool2d(_PoolNd): | |||||
| r"""Applies a 2D average pooling over an input. | |||||
| For instance, given an input of the size :math:`(N, C, H, W)` and | |||||
| :attr:`kernel_size` :math:`(kH, kW)`, this layer generates the output of | |||||
| the size :math:`(N, C, H_{out}, W_{out})` through a process described as: | |||||
| .. math:: | |||||
| out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} | |||||
| input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) | |||||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on | |||||
| both sides for :attr:`padding` number of points. | |||||
| :param kernel_size: the size of the window. | |||||
| :param stride: the stride of the window. Default value is ``kernel_size``. | |||||
| :param padding: implicit zero padding to be added on both sides. | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return avg_pool2d(inp, self.kernel_size, self.stride, self.padding) | |||||
| @@ -0,0 +1,14 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from .concat import Concat | |||||
| from .conv import Conv2d, ConvRelu2d | |||||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||||
| from .elemwise import Elemwise | |||||
| from .linear import Linear | |||||
| from .module import QATModule | |||||
| from .quant_dequant import DequantStub, QuantStub | |||||
| @@ -0,0 +1,30 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import Iterable | |||||
| from ...tensor import Tensor | |||||
| from .. import concat as Float | |||||
| from .module import QATModule | |||||
| class Concat(Float.Concat, QATModule): | |||||
| r""" | |||||
| A :class:`~.QATModule` to do functional concat with QAT support. | |||||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||||
| """ | |||||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||||
| return self.apply_quant_activation(super().forward(inps, axis)) | |||||
| @classmethod | |||||
| def from_float_module(cls, float_module): | |||||
| r""" | |||||
| Return a :class:`~.QATModule` instance converted from | |||||
| a float :class:`~.Module` instance. | |||||
| """ | |||||
| return cls() | |||||
| @@ -0,0 +1,59 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ... import functional as F | |||||
| from ...quantization.utils import fake_quant_bias | |||||
| from .. import conv as Float | |||||
| from .module import QATModule | |||||
| class Conv2d(Float.Conv2d, QATModule): | |||||
| r""" | |||||
| A :class:`~.QATModule` Conv2d with QAT support. | |||||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||||
| """ | |||||
| def calc_conv_qat(self, inp): | |||||
| w_qat = self.apply_quant_weight(self.weight) | |||||
| b_qat = fake_quant_bias(self.bias, inp, w_qat) | |||||
| conv = self.calc_conv(inp, w_qat, b_qat) | |||||
| return conv | |||||
| @classmethod | |||||
| def from_float_module(cls, float_module: Float.Conv2d): | |||||
| r""" | |||||
| Return a :class:`~.QATModule` instance converted from | |||||
| a float :class:`~.Module` instance. | |||||
| """ | |||||
| qat_module = cls( | |||||
| float_module.in_channels, | |||||
| float_module.out_channels, | |||||
| float_module.kernel_size, | |||||
| float_module.stride, | |||||
| float_module.padding, | |||||
| float_module.dilation, | |||||
| float_module.groups, | |||||
| float_module.bias is not None, | |||||
| float_module.conv_mode.name, | |||||
| float_module.compute_mode.name, | |||||
| ) | |||||
| qat_module.weight = float_module.weight | |||||
| qat_module.bias = float_module.bias | |||||
| return qat_module | |||||
| def forward(self, inp): | |||||
| return self.apply_quant_activation(self.calc_conv_qat(inp)) | |||||
| class ConvRelu2d(Conv2d): | |||||
| r""" | |||||
| A :class:`~.QATModule` include Conv2d and Relu with QAT support. | |||||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp))) | |||||
| @@ -0,0 +1,193 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ...functional import add_update, ones, relu, sqrt, sum, zeros | |||||
| from ...quantization.utils import fake_quant_bias | |||||
| from .. import conv_bn as Float | |||||
| from .module import QATModule | |||||
| class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): | |||||
| def get_batch_mean_var(self, inp): | |||||
| def _sum_channel(inp, axis=0, keepdims=True): | |||||
| if isinstance(axis, int): | |||||
| out = sum(inp, axis=axis, keepdims=keepdims) | |||||
| elif isinstance(axis, tuple): | |||||
| for idx, elem in enumerate(axis): | |||||
| out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) | |||||
| return out | |||||
| sum1 = _sum_channel(inp, (0, 2, 3)) | |||||
| sum2 = _sum_channel(inp ** 2, (0, 2, 3)) | |||||
| reduce_size = inp.size / inp.shape[1] | |||||
| batch_mean = sum1 / reduce_size | |||||
| batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size | |||||
| return batch_mean, batch_var | |||||
| def fold_weight_bias(self, bn_mean, bn_var): | |||||
| # get fold bn conv param | |||||
| # bn_istd = 1 / bn_std | |||||
| # w_fold = gamma / bn_std * W | |||||
| # b_fold = gamma * (b - bn_mean) / bn_std + beta | |||||
| gamma = self.bn.weight | |||||
| if gamma is None: | |||||
| gamma = ones((self.bn.num_features), dtype="float32") | |||||
| gamma = gamma.reshape(1, -1, 1, 1) | |||||
| beta = self.bn.bias | |||||
| if beta is None: | |||||
| beta = zeros((self.bn.num_features), dtype="float32") | |||||
| beta = beta.reshape(1, -1, 1, 1) | |||||
| if bn_mean is None: | |||||
| bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||||
| if bn_var is None: | |||||
| bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") | |||||
| conv_bias = self.conv.bias | |||||
| if conv_bias is None: | |||||
| conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||||
| # bn_istd = 1 / bn_std | |||||
| # w_fold = gamma / bn_std * W | |||||
| scale_factor = gamma * bn_istd | |||||
| if self.conv.groups == 1: | |||||
| w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) | |||||
| else: | |||||
| w_fold = self.conv.weight * scale_factor.reshape( | |||||
| self.conv.groups, -1, 1, 1, 1 | |||||
| ) | |||||
| w_fold = self.apply_quant_weight(w_fold) | |||||
| # b_fold = gamma * (b - bn_mean) / bn_std + beta | |||||
| b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd | |||||
| return w_fold, b_fold | |||||
| def update_running_mean_and_running_var( | |||||
| self, bn_mean, bn_var, num_elements_per_channel | |||||
| ): | |||||
| # update running mean and running var. no grad, use unbiased bn var | |||||
| bn_mean = bn_mean.detach() | |||||
| bn_var = ( | |||||
| bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1) | |||||
| ) | |||||
| exponential_average_factor = 1 - self.bn.momentum | |||||
| add_update( | |||||
| self.bn.running_mean, | |||||
| delta=bn_mean, | |||||
| alpha=1 - exponential_average_factor, | |||||
| beta=exponential_average_factor, | |||||
| ) | |||||
| add_update( | |||||
| self.bn.running_var, | |||||
| delta=bn_var, | |||||
| alpha=1 - exponential_average_factor, | |||||
| beta=exponential_average_factor, | |||||
| ) | |||||
| def calc_conv_bn_qat(self, inp, approx=True): | |||||
| if self.training and not approx: | |||||
| conv = self.conv(inp) | |||||
| bn_mean, bn_var = self.get_batch_mean_var(conv) | |||||
| num_elements_per_channel = conv.size / conv.shape[1] | |||||
| self.update_running_mean_and_running_var( | |||||
| bn_mean, bn_var, num_elements_per_channel | |||||
| ) | |||||
| else: | |||||
| bn_mean, bn_var = self.bn.running_mean, self.bn.running_var | |||||
| # get gamma and beta in BatchNorm | |||||
| gamma = self.bn.weight | |||||
| if gamma is None: | |||||
| gamma = ones((self.bn.num_features), dtype="float32") | |||||
| gamma = gamma.reshape(1, -1, 1, 1) | |||||
| beta = self.bn.bias | |||||
| if beta is None: | |||||
| beta = zeros((self.bn.num_features), dtype="float32") | |||||
| beta = beta.reshape(1, -1, 1, 1) | |||||
| # conv_bias | |||||
| conv_bias = self.conv.bias | |||||
| if conv_bias is None: | |||||
| conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||||
| # bn_istd = 1 / bn_std | |||||
| # w_fold = gamma / bn_std * W | |||||
| scale_factor = gamma * bn_istd | |||||
| if self.conv.groups == 1: | |||||
| w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) | |||||
| else: | |||||
| w_fold = self.conv.weight * scale_factor.reshape( | |||||
| self.conv.groups, -1, 1, 1, 1 | |||||
| ) | |||||
| b_fold = None | |||||
| if not (self.training and approx): | |||||
| # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta | |||||
| b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd | |||||
| w_qat = self.apply_quant_weight(w_fold) | |||||
| b_qat = fake_quant_bias(b_fold, inp, w_qat) | |||||
| conv = self.conv.calc_conv(inp, w_qat, b_qat) | |||||
| if not (self.training and approx): | |||||
| return conv | |||||
| # rescale conv to get original conv output | |||||
| orig_conv = conv / scale_factor.reshape(1, -1, 1, 1) | |||||
| if self.conv.bias is not None: | |||||
| orig_conv = orig_conv + self.conv.bias | |||||
| # calculate batch norm | |||||
| bn_mean, bn_var = self.get_batch_mean_var(orig_conv) | |||||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||||
| conv = gamma * bn_istd * (orig_conv - bn_mean) + beta | |||||
| num_elements_per_channel = conv.size / conv.shape[1] | |||||
| self.update_running_mean_and_running_var( | |||||
| bn_mean, bn_var, num_elements_per_channel | |||||
| ) | |||||
| return conv | |||||
| @classmethod | |||||
| def from_float_module(cls, float_module: Float._ConvBnActivation2d): | |||||
| r""" | |||||
| Return a :class:`~.QATModule` instance converted from | |||||
| a float :class:`~.Module` instance. | |||||
| """ | |||||
| qat_module = cls( | |||||
| float_module.conv.in_channels, | |||||
| float_module.conv.out_channels, | |||||
| float_module.conv.kernel_size, | |||||
| float_module.conv.stride, | |||||
| float_module.conv.padding, | |||||
| float_module.conv.dilation, | |||||
| float_module.conv.groups, | |||||
| float_module.conv.bias is not None, | |||||
| float_module.conv.conv_mode.name, | |||||
| float_module.conv.compute_mode.name, | |||||
| ) | |||||
| qat_module.conv.weight = float_module.conv.weight | |||||
| qat_module.conv.bias = float_module.conv.bias | |||||
| qat_module.bn = float_module.bn | |||||
| return qat_module | |||||
| class ConvBn2d(_ConvBnActivation2d): | |||||
| r""" | |||||
| A fused :class:`~.QATModule` including Conv2d, BatchNorm2d with QAT support. | |||||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return self.apply_quant_activation(self.calc_conv_bn_qat(inp)) | |||||
| class ConvBnRelu2d(_ConvBnActivation2d): | |||||
| r""" | |||||
| A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support. | |||||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||||
| """ | |||||
| def forward(self, inp): | |||||
| return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp))) | |||||