GitOrigin-RevId: f3b6e492d7
tags/v1.0.0-rc1
@@ -47,10 +47,9 @@ option(MGE_DEBUG_UTIL "Enable debug utility" ON)
option(MGE_ENABLE_EXCEPTIONS "Build with exceptions" ON)
option(MGE_WITH_TEST "Enable test for MegEngine." OFF)
option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON)
option(MGE_BUILD_IMPERATIVE_RT "Build _imperative_rt.so instead of _mgb.so " OFF)
option(MGE_BUILD_IMPERATIVE_RT "Build _imperative_rt Python Module " ON)
option(MGE_BUILD_SDK "Build load_and_run" ON)
option(MGE_INFERENCE_ONLY "Build inference only library." OFF)
option(MGE_WITH_PYTHON_MODULE "Build MegEngine Python Module." ON)
option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON)
option(MGE_WITH_ROCM "Enable ROCM support" OFF)
@@ -256,8 +255,8 @@ endif()
if(MGE_INFERENCE_ONLY)
  message("-- Disable distributed support for inference only build.")
  set(MGE_WITH_DISTRIBUTED OFF)
  message("-- Disable python module for inference only build.")
  set(MGE_WITH_PYTHON_MODULE OFF)
  message("-- Disable imperative_rt python module for inference only build.")
  set(MGE_BUILD_IMPERATIVE_RT OFF)
endif()
if(MGE_WITH_DISTRIBUTED)
@@ -694,43 +693,18 @@ if(MGE_BUILD_SDK)
  add_subdirectory(sdk/load-and-run)
endif()
if(MGE_WITH_PYTHON_MODULE)
  if(MGE_BUILD_IMPERATIVE_RT)
    add_subdirectory(imperative)
    message("-- Enable imperative python wrapper runtime")
  else()
    add_subdirectory(python_module)
    message("-- Enable legacy python wrapper runtime")
  endif()
if(MGE_BUILD_IMPERATIVE_RT)
  add_subdirectory(imperative)
  message("-- Enable imperative python wrapper runtime")
endif()
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI)
  add_subdirectory(test)
endif()
if(TARGET mgb)
  add_custom_target(
    develop
    COMMAND ${CMAKE_COMMAND} -E create_symlink
      ${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/$<TARGET_FILE_NAME:mgb>
      ${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/$<TARGET_FILE_NAME:mgb>
    COMMAND ${CMAKE_COMMAND} -E create_symlink
      ${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/mgb.py
      ${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/mgb.py
    COMMAND ${CMAKE_COMMAND} -E create_symlink
      ${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/opr.py
      ${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/opr.py
    COMMAND ${CMAKE_COMMAND} -E create_symlink
      ${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/opr_param_defs.py
      ${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/opr_param_defs.py
    COMMAND ${CMAKE_COMMAND} -E create_symlink
      ${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/include
      ${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/include
    DEPENDS mgb
    VERBATIM
  )
elseif(TARGET _imperative_rt)
if(TARGET _imperative_rt)
  add_custom_target(
    develop
    COMMAND ${CMAKE_COMMAND} -E create_symlink
@@ -183,25 +183,16 @@ def typename(type):
# parse typing.Union
if sys.version_info < (3, 6):
    def parse_union(ann):
def parse_union(ann):
    if hasattr(typing, "UnionMeta"):
        if type(ann) is not typing.UnionMeta:
            return
        return ann.__union_params__
elif sys.version_info < (3, 7):
    def parse_union(ann):
    elif hasattr(typing, "_Union"):
        if type(ann) is not typing._Union:
            return
        return ann.__args__
elif sys.version_info < (3, 8):
    def parse_union(ann):
    elif hasattr(typing, "_GenericAlias"):
        if type(ann) is not typing._GenericAlias:
            if type(ann) is not typing.Union:
                return
@@ -209,11 +200,9 @@ elif sys.version_info < (3, 8):
        if ann.__origin__ is not typing.Union:
            return
        return ann.__args__
else:
    def parse_union(ann):
    elif hasattr(typing, "Union"):
        if typing.get_origin(ann) is not typing.Union:
            return
        return typing.get_args(ann)
    else:
        raise NotImplementedError("unsupported Python version")
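A quick sanity check of the new hasattr-based dispatch above (an editorial sketch, not part of the diff; assumes one of the branches matched at import time):

    import typing

    assert parse_union(typing.Union[int, str]) == (int, str)
    assert parse_union(int) is None  # non-Union annotations yield None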
@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import re
import pathlib
@@ -55,11 +56,13 @@ package_data = [
    str(f.relative_to('megengine'))
    for f in pathlib.Path('megengine', 'core', 'include').glob('**/*')
]
package_data += [
    str(f.relative_to('megengine'))
    for f in pathlib.Path('megengine', 'core', 'lib').glob('**/*')
]
with open('requires.txt') as f:
    requires = f.read().splitlines()
with open('requires-style.txt') as f:
@@ -67,6 +70,7 @@ with open('requires-style.txt') as f:
with open('requires-test.txt') as f:
    requires_test = f.read().splitlines()
prebuild_modules=[PrecompiledExtesion('megengine.core._imperative_rt')]
setup_kwargs = dict(
    name=package_name,
    version=__version__,
@@ -78,7 +82,7 @@ setup_kwargs = dict(
    package_data={
        'megengine': package_data,
    },
    ext_modules=[PrecompiledExtesion('megengine.core._imperative_rt')],
    ext_modules=prebuild_modules,
    install_requires=requires,
    extras_require={
        'dev': requires_style + requires_test,
@@ -87,6 +91,7 @@ setup_kwargs = dict(
    cmdclass={'build_ext': build_ext},
)
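``PrecompiledExtesion`` (spelling as in the source) is defined earlier in this setup.py. For orientation, a minimal sketch of the pattern it names -- an extension with no sources, so setuptools compiles nothing and the matching ``build_ext`` only has to place the CMake-built shared object (hypothetical class name, not the actual implementation):

    import setuptools

    class PrebuiltExtensionSketch(setuptools.Extension):
        def __init__(self, name):
            # no sources: the .so comes prebuilt from the CMake build
            super().__init__(name, sources=[])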
setup_kwargs.update(dict(
    classifiers=[
        'Development Status :: 3 - Alpha',
@@ -0,0 +1,21 @@
#!/bin/bash -e
test_dirs="test"
TEST_PLAT=$1
if [[ "$TEST_PLAT" == cpu ]]; then
    echo "only test cpu pytest"
elif [[ "$TEST_PLAT" == cuda ]]; then
    echo "test both cpu and gpu pytest"
else
    echo "Argument must be cpu or cuda"
    exit 1
fi
pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'not isolated_distributed'
if [[ "$TEST_PLAT" == cuda ]]; then
    echo "test GPU pytest now"
    PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'isolated_distributed'
fi
popd >/dev/null
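The ``-m`` filters above select tests by pytest marker; a distributed test opts into the cuda-only pass roughly like this (hypothetical test; assumes the ``isolated_distributed`` marker is registered in the pytest config):

    import pytest

    @pytest.mark.isolated_distributed
    def test_allreduce_sum():
        ...  # needs real GPUs, so it only runs in the cuda pass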
@@ -1,8 +0,0 @@
/megbrain/_mgb.so
/megbrain/_mgb.*.so
/MegBrain.egg-info/
/dist
/dist_cuda
/dist_nocuda
/wheel_dist
.cache
@@ -1,113 +0,0 @@
cmake_policy(SET CMP0086 NEW)
find_package(PythonLibs ${PYTHON_VERSION_STRING} EXACT REQUIRED)
find_package(Git)
if(GIT_FOUND)
  message("git found: ${GIT_EXECUTABLE}")
endif()
find_package(NumPy REQUIRED)
find_package(SWIG REQUIRED)
set(SWIG_SRC src/swig/mgb.i)
if(MSVC OR WIN32)
  set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -DSWIGWORDSIZE64)
  message("WARN: swig has some define issues in the windows(64) env")
  message("Please refer to scripts/whl/BUILD_PYTHON_WHL_README.md to init the windows build env")
else()
  set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl")
file(GLOB_RECURSE PYTHON_SRCS setup.py
  src/python/*.py
  test/*.py
  megengine/*.py)
list(REMOVE_ITEM PYTHON_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/megengine/_internal/mgb.py
  ${CMAKE_CURRENT_SOURCE_DIR}/megengine/_internal/opr.py
  ${CMAKE_CURRENT_SOURCE_DIR}/megengine/_internal/opr_param_defs.py
)
list(APPEND PYTHON_SRCS ${MGB_SRCS})
file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h
  ${PROJECT_SOURCE_DIR}/src/core/include/*
  ${PROJECT_SOURCE_DIR}/src/opr/include/*
  ${PROJECT_SOURCE_DIR}/src/serialization/include/*
  ${PROJECT_SOURCE_DIR}/src/plugin/include/*
  ${PROJECT_SOURCE_DIR}/dnn/include/*)
file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS)
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS})
add_custom_command(
  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr_param_defs.py
  COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/src/python ${CMAKE_CURRENT_BINARY_DIR}/src/python
  COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal
  COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/src/python/genopr.py ${OPR_DECL_SRCS}
  COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr_param_defs.py
  DEPENDS ${OPR_DECL_SRCS}
  VERBATIM
)
add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py)
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
include(UseSWIG)
set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON)
# cmake < 3.12 does not honor the INCLUDE_DIRECTORIES property, so add the include directories into SWIG_FLAGS.
# Add -I${PROJECT_BINARY_DIR}/genfiles in order to include megbrain_build_config.h so that we don't need to pass cmake flags by -D.
set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/serialization/include -I${PROJECT_BINARY_DIR}/genfiles)
set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR})
set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal)
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${SRCS})
set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld)
add_custom_target(version_ld SOURCES ${VERSION_SCRIPT})
set_target_properties(mgb PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal)
if (APPLE)
  target_link_libraries(mgb megbrain megdnn)
  set_target_properties(mgb PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
elseif (MSVC OR WIN32)
  target_link_libraries(mgb megbrain megdnn)
else()
  target_link_libraries(mgb megbrain megdnn -Wl,--version-script=${VERSION_SCRIPT})
endif()
target_include_directories(mgb PRIVATE ${PYTHON_INCLUDE_DIRS} src/cpp ${CMAKE_CURRENT_BINARY_DIR} ${NUMPY_INCLUDE_DIR})
# only Windows needs to link PYTHON_LIBRARIES
if(MSVC OR WIN32)
  target_link_libraries(mgb ${PYTHON_LIBRARIES})
endif()
if (MGE_WITH_DISTRIBUTED)
  target_link_libraries(mgb megray)
endif()
add_dependencies(mgb mgb_opr_py version_ld)
add_custom_command(
  TARGET mgb POST_BUILD
  COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/LICENSE ${PROJECT_SOURCE_DIR}/ACKNOWLEDGMENTS ${PROJECT_BINARY_DIR}
  COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/megengine ${CMAKE_CURRENT_BINARY_DIR}/megengine
  COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/test ${CMAKE_CURRENT_BINARY_DIR}/test
  COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/setup.py ${CMAKE_CURRENT_BINARY_DIR}/setup.py
  COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/requires.txt ${CMAKE_CURRENT_BINARY_DIR}/requires.txt
  COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/requires-style.txt ${CMAKE_CURRENT_BINARY_DIR}/requires-style.txt
  COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/requires-test.txt ${CMAKE_CURRENT_BINARY_DIR}/requires-test.txt
  COMMAND ${CMAKE_COMMAND} -E remove_directory ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include
  COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/src/cpp/megbrain_pubapi.h ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include/megbrain_pubapi.h
  COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/core/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include
  COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/opr/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include
  COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/serialization/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include
  COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/plugin/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include
  COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/dnn/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include
  COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_BINARY_DIR}/genfiles/megbrain_build_config.h ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include/megbrain_build_config.h
)
@@ -1,11 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .core import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .version import __version__
@@ -1,729 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""the megbrain python package

Note that all the submodules are automatically imported, so you usually only
need to ``import megengine._internal as mgb``.
"""
import collections
import json
import os
import sys
import platform
import ctypes

if sys.platform == "win32":
    lib_path = os.path.join(os.path.dirname(__file__), "lib")
    Lib_path = os.path.join(os.path.dirname(__file__), "Lib")
    dll_paths = list(filter(os.path.exists, [lib_path, Lib_path]))
    assert len(dll_paths) > 0
    kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
    has_load_library_attr = hasattr(kernel32, "AddDllDirectory")
    old_error_mode = kernel32.SetErrorMode(0x0001)
    kernel32.LoadLibraryW.restype = ctypes.c_void_p
    if has_load_library_attr:
        kernel32.AddDllDirectory.restype = ctypes.c_void_p
        kernel32.LoadLibraryExW.restype = ctypes.c_void_p
    for dll_path in dll_paths:
        if sys.version_info >= (3, 8):
            os.add_dll_directory(dll_path)
        elif has_load_library_attr:
            res = kernel32.AddDllDirectory(dll_path)
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += ' Error adding "{}" to the DLL search PATH.'.format(
                    dll_path
                )
                raise err
        else:
            print("WARN: python or OS env has some issue; loading DLLs may fail!!!")
    import glob
    dlls = glob.glob(os.path.join(lib_path, "*.dll"))
    path_patched = False
    for dll in dlls:
        is_loaded = False
        if has_load_library_attr:
            res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
            last_error = ctypes.get_last_error()
            if res is None and last_error != 126:
                err = ctypes.WinError(last_error)
                err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                    dll
                )
                raise err
            elif res is not None:
                is_loaded = True
        if not is_loaded:
            if not path_patched:
                os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
                path_patched = True
            res = kernel32.LoadLibraryW(dll)
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += ' Error loading "{}" or one of its dependencies.'.format(
                    dll
                )
                raise err
    kernel32.SetErrorMode(old_error_mode)

import numpy as np

from . import comp_graph_tools as cgtools
from . import config, craniotome, dtype
from . import global_init as _global_init
from . import helper as _helper
from . import mgb as _detail
from . import opr, opr_extra, opr_param_defs, plugin
from .exc import MegBrainError
from .logconf import get_logger
from .mgb import (
    CompGraph,
    CompNode,
    SharedND,
    SharedScalar,
    SymbolVar,
    TensorValueDumperContext,
    TensorValueLoaderContext,
)
from .mgb import as_comp_node as comp_node
from .mgb_helper import SharedNDLazyInitializer, callback_lazycopy, copy_output
from .plugin import CompGraphProfiler
from .plugin import GlobalInfkernFinder as _GlobalInfkernFinder
from .plugin import NumRangeChecker
from .version import __version__, version_info

if sys.version_info.major < 3:
    raise ImportError("megbrain requires python 3")

class ProxySharedNDAndSymbolVar(_detail.SymbolVar):
    """this is a :class:`.SymbolVar` with a corresponding :class:`.SharedND`.
    It can participate in graph computing and also provides :meth:`set_value`
    and :meth:`get_value`. It should be constructed by :func:`make_shared`.
    """

    __shared_nd = None
    __kwargs = None

    def __init__(self, snd, comp_graph, name, **kwargs):
        self.__shared_nd = snd
        self.__kwargs = kwargs
        self.this = snd.symvar(comp_graph=comp_graph, name=name, **kwargs).this

    def set_value(self, v, **kwargs):
        ret = self.__shared_nd.set_value(v, **kwargs)
        self._reeval_if_eager_eval()
        return ret

    def get_value(self):
        return self.__shared_nd.get_value()

    def reset_zero(self):
        self.__shared_nd.reset_zero()

def make_shared(
    comp_node,
    *,
    dtype=None,
    shape=None,
    value=None,
    comp_graph=None,
    name=None,
    volatile=None
):
    """make a shared tensor which is stored on device and could be modified
    later, either as a :class:`.SymbolVar` or a :class:`.SharedND` object

    :param comp_node: computing node
    :type comp_node: :class:`.CompNode`
    :param dtype: data type; if it is None, then dtype of value would be used
        if value is not None, and float32 would be used as default dtype if
        value is None
    :type dtype: :class:`numpy.dtype` compatible
    :param value: initializing value
    :type value: None or :class:`numpy.ndarray`
    :param comp_graph: the computing graph to which this shared value should
        belong; if provided, the returned object could be used as a
        :class:`.SymbolVar`
    :type comp_graph: None or :class:`.CompGraph`
    :param name: node name to be used in computing graph; only meaningful if
        *comp_graph* is provided
    :param volatile: if *comp_graph* is given then *volatile* indicates whether
        shape or mem ptr of this SharedND can be changed
    :rtype: :class:`.SharedND` if *comp_graph* is not given; or
        :class:`ProxySharedNDAndSymbolVar` otherwise
    """
    if dtype is None:
        if value is not None:
            value = np.ascontiguousarray(value)
            dtype = to_mgb_supported_dtype(value.dtype)
        else:
            dtype = np.float32
    comp_node = _detail.as_comp_node(comp_node)
    rst = _detail.SharedND(comp_node, dtype)
    if value is not None:
        assert shape is None, "could not provide both value and shape"
        rst.set_value(value)
    elif shape is not None:
        rst._set_init_shape(shape)
    if comp_graph is None:
        assert name is None and volatile is None
        return rst
    assert isinstance(comp_graph, CompGraph), "expect CompGraph but got {}".format(
        comp_graph
    )
    if volatile is None:
        volatile = False
    else:
        assert isinstance(volatile, bool)
    return ProxySharedNDAndSymbolVar(rst, comp_graph, name, volatile=volatile)
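Illustrative use of ``make_shared`` per its docstring (an editorial example with hypothetical values, not part of the diff):

    import numpy as np
    import megengine._internal as mgb

    cg = mgb.comp_graph()
    # without comp_graph: a plain SharedND living on the device
    w = mgb.make_shared("xpu0", value=np.ones((2, 3), dtype="float32"))
    # with comp_graph: also usable as a SymbolVar in that graph
    b = mgb.make_shared("xpu0", value=np.zeros((3,), dtype="float32"),
                        comp_graph=cg, name="b")
    w.set_value(np.zeros((2, 3), dtype="float32"))  # update later on device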
def make_immutable(comp_node, comp_graph, value, *, dtype=None, name=None):
    """make a graph node containing an immutable tensor from host tensor value

    :param dtype: required data type; if not None, the data would be converted
        to that type; otherwise the dtype of *value* would be used
    """
    comp_node = _detail.as_comp_node(comp_node)
    assert isinstance(
        comp_graph, _detail.CompGraph
    ), "expect CompGraph but got {!r}".format(comp_graph)
    config = _detail.make_opr_config(name, comp_node)
    return _helper.cvt_opr_result(
        _detail._make_immutable(comp_graph, value, dtype, config)
    )

def make_arg(
    comp_node,
    comp_graph,
    *,
    dtype=np.float32,
    shape=None,
    name=None,
    value=None,
    enable_static_infer=True
):
    """make an argument to be passed to compiled function during runtime

    :type shape: None or tuple of int
    :param shape: expected tensor shape to be used for shape inferring; actual
        tensor shape could be different
    :type name: str
    :param name: name of the generated var node
    :type value: None or ndarray-compatible
    :param value: initial value used for static inference; if not given, static
        infer would be deferred to first graph execution
    :param enable_static_infer: whether to enable static inference for this var
    """
    comp_node = _detail.as_comp_node(comp_node)
    host_val = mgb._HostSharedND(comp_node, dtype)
    if value is not None:
        value = np.ascontiguousarray(value, dtype=dtype)
        if shape is None:
            shape = value.shape
        else:
            assert shape == value.shape
    if shape is not None:
        host_val._resize(shape)
    if value is not None:
        host_val.set_value(value)
    return _helper.cvt_opr_result(
        ProxySharedNDAndSymbolVar(
            host_val, comp_graph, name, enable_static_infer=enable_static_infer
        )
    )

def comp_graph(*, extra_opts=None, check_env_var=True):
    """allocate a new computing graph

    :param extra_opts: extra options to be set; would be updated (modified
        inplace) from ``MGB_COMP_GRAPH_OPT`` environment var. See
        :func:`.set_comp_graph_option` for list of supported options.
    :type extra_opts: dict
    :param check_env_var: whether to check environment vars
    :type check_env_var: bool
    :return: the comp graph object
    :rtype: :class:`.CompGraph`
    """
    cg = _detail.CompGraph()
    if extra_opts is None:
        extra_opts = {}
    if check_env_var:
        setting = os.getenv("MGB_COMP_GRAPH_OPT")
        if setting:
            for item in setting.split(";"):
                k, v = item.split("=", 1)
                extra_opts.setdefault(k, v)
            get_logger().warning(
                "set comp graph option from env: {}".format(extra_opts)
            )
        user_data = os.getenv("MGB_COMP_GRAPH_USER_DATA")
        if user_data:
            storage = cg.user_data
            for ud in user_data.split(";"):
                k, v = ud.split("=", 1)
                storage[k] = eval(v)
        _GlobalInfkernFinder.add_graph(cg)
    for k, v in extra_opts.items():
        cg.set_option(k, v)
    return cg
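The environment hook above makes these two calls roughly equivalent (editorial example; the option names are only illustrative):

    # with MGB_COMP_GRAPH_OPT="graph_opt_level=0;log_level=1" in the environment:
    cg = comp_graph()
    # behaves like:
    cg = comp_graph(extra_opts={"graph_opt_level": "0", "log_level": "1"})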
def grad(
    target, wrt, warn_mid_wrt=True, use_virtual_grad=None, return_zero_for_nodep=True
):
    r"""compute symbolic grad

    :param target: grad target var
    :type target: :class:`.SymbolVar`
    :param wrt: with respect to which to compute the grad
    :type wrt: :class:`.SymbolVar` or Iterable[SymbolVar]
    :param warn_mid_wrt: whether to give warning if *wrt* is not endpoint
    :type warn_mid_wrt: bool
    :param use_virtual_grad: whether to use virtual grad opr, so fwd graph can
        be optimized before applying grad; if ``None`` is given, then virtual
        grad would be used if ``graph_opt_level >= 2``
    :type use_virtual_grad: :class:`bool` or ``None``
    :param return_zero_for_nodep: if *target* does not depend on *wrt*, set to True to return
        a zero-valued :class:`.SymbolVar` rather than ``None``; can't be set to False when using
        virtual grad opr.
    :type return_zero_for_nodep: bool
    :rtype: :class:`.SymbolVar` or None
    :return: :math:`\frac{\partial\text{target}}{\partial\text{wrt}}`
    """
    if use_virtual_grad is None:
        use_virtual_grad = -1
    else:
        use_virtual_grad = 1 if use_virtual_grad else 0
    if isinstance(wrt, SymbolVar):
        wrts = [
            wrt,
        ]
    else:
        wrts = wrt
    assert isinstance(wrts, collections.Iterable)
    # return an invalid SymbolVar (with nullptr VarNode*) when return_zero_for_nodep is False
    # and target doesn't depend on wrt
    grads = _detail._grad(
        target, wrts, bool(warn_mid_wrt), use_virtual_grad, return_zero_for_nodep
    )
    grads = list(grads)
    for i in range(len(grads)):
        if not grads[i].valid:
            assert (
                not return_zero_for_nodep
            ), "invalid grad SymbolVar: target={}, wrt={}".format(target, wrts[i])
            grads[i] = None
    if len(grads) == 1:
        grads = grads[0]
    return grads
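Usage sketch for ``grad`` (editorial example; ``loss``, ``x`` and ``b`` are hypothetical vars in one graph):

    gx = grad(loss, x)           # single wrt: returns a single SymbolVar
    gx, gb = grad(loss, [x, b])  # iterable wrt: returns a list of grads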
def current_grad_target(comp_graph):
    """get current target var to compute grad, used for implementing custom
    gradient"""
    return _detail._current_grad_target(comp_graph)

def add_device_map(map_location):
    """add map location while loading models"""
    _detail.CompNode.cn_thread_local.__setattr__("map_location", map_location)

def del_device_map():
    """delete map location"""
    _detail.CompNode.cn_thread_local.__delattr__("map_location")

def inter_graph_trans_var(dest_graph, src):
    """get the corresponding var of *src* in *dest_graph*; assuming
    *dest_graph* is a copy of owner graph of *src*; usually used in callback of
    set_grad to get grad of vars in loop

    :param dest_graph: target computing graph
    :type dest_graph: :class:`.CompGraph`
    :param src: source var node
    :type src: :class:`.SymbolVar`
    :return: corresponding var in *dest_graph*
    :rtype: :class:`.SymbolVar`
    """
    return _detail._inter_graph_trans_var(dest_graph, src)

def get_graph_optimizer_replaced_var(src):
    """get optimized var corresponding to given var; usually used in callback
    of set_grad to get grad w.r.t. some var

    :param src: source var node
    :type src: :class:`.SymbolVar`
    :rtype: :class:`.SymbolVar`
    """
    return _detail._get_graph_optimizer_replaced_var(src)

CompGraphSerializationResult = collections.namedtuple(
    "CompGraphSerializationResult",
    [
        "nr_opr",
        "tot_bytes",
        "tensor_value_bytes",
        "content_hash",
        "inputs",
        "outputs",
        "params",
    ],
)

def serialize_comp_graph_to_file(
    fpath,
    output_vars,
    *,
    keep_var_name=1,
    keep_param_name=False,
    keep_opr_priority=False,
    tensor_value_dumper=None,
    output_strip_info=False,
    append=False,
    format=None,
    **kwargs
):
    """serialize this computing graph and write result to a file. Note:
    ``kwargs`` exists for backward compatibility; there are no additional
    arguments.

    :param fpath: path for the output file
    :type fpath: ``str``
    :param output_vars: output variables that need to be retrieved when
        deserializing

        .. note::

            The underlying C++ API only accepts a var list. If a dict is given,
            the vars would be renamed to given names.

    :type output_vars: dict(name => :class:`.SymbolVar`), or a list of vars
    :param keep_var_name: level for keeping variable names:

        * 0: none of the names are kept
        * 1: keep names of output vars
        * 2: keep names of all (output and internal) vars

    :param keep_param_name: whether to keep param names, so param values can be
        easily manipulated after loading model
    :param keep_opr_priority: whether to keep priority setting for operators
    :param tensor_value_dumper: a callable to dump tensor values; it should
        only write the tensor value without layout information. It would be
        given a :class:`.TensorValueDumperContext` object as its sole argument.
    :param output_strip_info: if set to True, then a json file containing
        information for code strip would be written to ``fpath+'.json'``
    :param append: whether to open output file in append mode
    :return: an instance of namedtuple :class:`CompGraphSerializationResult`,
        whose fields are:

        * ``nr_opr`` number of operators dumped
        * ``tot_bytes`` total bytes for the whole graph
        * ``tensor_value_bytes`` bytes consumed for dumping tensor values
        * ``inputs`` names of input tensors
        * ``params`` list of names of dumped params
        * ``outputs`` names of output vars

    :param format: serialization format of the resulting model, should be either
        "mdl" or "fbs"; none means default.
    :type format: ``str``
    """
    assert isinstance(fpath, str), "bad file path: {!r}".format(fpath)
    ov = _detail._VectorSymbolVar()
    SUPPORTED_FORMATS = {
        # default
        None: _detail.GraphDumpFormat_FLATBUFFERS,
        "fbs": _detail.GraphDumpFormat_FLATBUFFERS,
    }
    resolved_fmt = SUPPORTED_FORMATS.get(format, None)
    if resolved_fmt is None:
        raise ValueError(
            "unknown format {} requested, supported ones are {}".format(
                format, list(filter(None, SUPPORTED_FORMATS.keys()))
            )
        )
    if isinstance(output_vars, dict):
        used_vars = set()
        for name, var in output_vars.items():
            assert isinstance(var, _detail.SymbolVar), "bad output var: {!r}".format(
                var
            )
            assert var.id not in used_vars, (
                "var name is associated with a var object, so we can not have "
                "two names given to the same var: {}".format(var)
            )
            used_vars.add(var.id)
            var.rename(name)
            ov.push_back(var)
    else:
        for i in output_vars:
            assert isinstance(i, _detail.SymbolVar), "bad output var: {!r}".format(i)
            ov.push_back(i)
    if tensor_value_dumper is not None:
        assert isinstance(tensor_value_dumper, collections.Callable)

        class Callback(_detail._TensorValueDumperCallback):
            def call(self, ctx, *, _f=tensor_value_dumper):
                _f(ctx)

        tensor_value_dumper = Callback()
    # for backward compatibility
    mangle_opr_name = kwargs.pop("mangle_opr_name", ov)
    if mangle_opr_name is not ov:
        get_logger().warning("mangle_opr_name is deprecated; use keep_var_name instead")
        keep_var_name = 1 if mangle_opr_name else 2
    mangle_param_name = kwargs.pop("mangle_param_name", ov)
    assert (
        not kwargs
    ), "extra kwargs provided to serialize_comp_graph_to_file: {}".format(kwargs)
    if mangle_param_name is not ov:
        get_logger().warning(
            "mangle_param_name is deprecated; use keep_param_name instead"
        )
        keep_param_name = not mangle_param_name
    inputs = _detail._VectorString()
    outputs = _detail._VectorString()
    params = _detail._VectorString()
    stat = _detail._VectorSizeT()
    _detail._serialize_comp_graph_to_file(
        fpath,
        append,
        resolved_fmt,
        ov,
        keep_var_name,
        keep_param_name,
        keep_opr_priority,
        tensor_value_dumper,
        stat,
        inputs,
        outputs,
        params,
    )
    dump_ret = CompGraphSerializationResult(
        *stat, list(inputs), list(outputs), list(params)
    )
    if output_strip_info:
        with open(fpath + ".json", "w") as fout:
            strip_info = _detail._get_info_for_strip(ov)
            strip_info_dict = json.loads(strip_info)
            strip_info_dict["hash"] = dump_ret.content_hash
            json.dump(strip_info_dict, fout)
    return dump_ret
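A typical call per the docstring above (editorial example; ``pred`` is a hypothetical output var):

    ret = serialize_comp_graph_to_file(
        "model.mgb", {"pred": pred}, keep_var_name=2, output_strip_info=True
    )
    print(ret.nr_opr, ret.tot_bytes)  # strip info lands in model.mgb.json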
CompGraphLoadResult = collections.namedtuple(
    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
)

def load_comp_graph_from_file(
    fpath, *, comp_node_mapper=None, tensor_value_loader=None
):
    """Load a serialized computing graph from file.

    :param fpath: Path for the input file
    :type fpath: ``str``
    :param comp_node_mapper: A callable to modify comp node locator, takes old
        locator as argument and returns new locator.
    :type comp_node_mapper: Callable[[str], str]
    :param tensor_value_loader: A callable to load tensor values. It should
        read the tensor value with the given shape and dtype and return it as
        NumPy ndarray. It would be given a :class:`.TensorValueLoaderContext`
        object as its sole argument.
    :type tensor_value_loader: Callable[[TensorValueLoaderContext], numpy.ndarray]
    :return: An instance of namedtuple :class:`CompGraphLoadResult`,
        whose fields are:

        * ``graph`` loaded CompGraph
        * ``output_vars_dict`` A Python dict, mapping name to output SymbolVar
        * ``output_vars_list`` A Python list, containing output vars in the
          order passed to serialize_comp_graph_to_file
    """
    assert isinstance(fpath, str), "bad file path: {!r}".format(fpath)
    if comp_node_mapper is not None:
        assert isinstance(comp_node_mapper, collections.Callable)

        class Callback(_detail._CompNodeMapperCallback):
            def call(self, desc, *, _f=comp_node_mapper):
                return _f(desc)

        comp_node_mapper = Callback()
    if tensor_value_loader is not None:
        assert isinstance(tensor_value_loader, collections.Callable)

        class Callback(_detail._TensorValueLoaderCallback):
            def call(self, ctx, *, _f=tensor_value_loader):
                return _f(ctx)

        tensor_value_loader = Callback()
    output_vars_map = _detail._VectorPairStringSymbolVar()
    output_vars_list = _detail._VectorSymbolVar()
    cg = _detail._load_comp_graph_from_file(
        fpath, comp_node_mapper, tensor_value_loader, output_vars_map, output_vars_list
    )
    return CompGraphLoadResult(cg, dict(list(output_vars_map)), list(output_vars_list))
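Round-trip sketch (editorial example): remap every comp node onto CPU while loading, then fetch outputs by the names given at dump time:

    res = load_comp_graph_from_file(
        "model.mgb", comp_node_mapper=lambda loc: loc.replace("gpu", "cpu")
    )
    pred = res.output_vars_dict["pred"]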
def optimize_for_inference(
    output_vars,
    *,
    f16_io_f32_comp=False,
    f16_io_comp=False,
    use_nhwcd4=False,
    fuse_conv_bias_nonlinearity=False,
    use_nchw32=False,
    fuse_conv_bias_with_z=False,
    use_nchw4=False,
    use_nchw88=False,
    use_nchw44=False,
    use_nchw44_dot=False,
    use_chwn4=False
):
    """optimize computing graph for inference

    This applies a predefined set of optimization passes. Refer to the mnist
    sdk example and C++ code for fine-grained control.

    :param output_vars: output symvars
    :type output_vars: list of :class:`.SymbolVar`
    :param f16_io_f32_comp: whether to use float16 for I/O between oprs and use
        float32 as internal computation precision. Note the output var would be
        changed to float16
    :param f16_io_comp: whether to use float16 for both I/O and computation
        precision
    :param use_nhwcd4: whether to use NHWCD4 data format. This is faster on some
        OpenCL devices
    :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
        into one opr. This is supported only in NHWCD4 format.
    :param use_nchw4: whether to use NCHW4 tensor format.
    :param use_nchw88: whether to use NCHW88 tensor format. This may be faster
        sometimes.
    :param use_nchw44: whether to use NCHW44 tensor format. This may be faster
        sometimes.
    :param use_nchw44_dot: whether to use NCHW44_DOT tensor format. This format is
        optimized for inference on armv8.2.
    :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
        nvidia tensorcore.
    :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
        nvidia tensorcore.
    :return: list of transformed vars corresponding to given output vars
    """
    assert isinstance(output_vars, (list, tuple))
    opt = _detail._OptimizeForInferenceOptions()
    settings = locals()
    for i in [
        "f16_io_f32_comp",
        "f16_io_comp",
        "fuse_conv_bias_nonlinearity",
        "fuse_conv_bias_with_z",
    ]:
        if settings[i]:
            getattr(opt, "enable_{}".format(i))()
    layout_tranform = None
    for k, v in {
        "use_nchw4": "nchw4",
        "use_nhwcd4": "nhwcd4",
        "use_nchw32": "nchw32",
        "use_nchw88": "nchw88",
        "use_nchw44": "nchw44",
        "use_nchw44_dot": "nchw44_dot",
        "use_chwn4": "chwn4",
    }.items():
        if settings[k]:
            assert (
                not layout_tranform
            ), "Only one layout transform supported, both {} and {}".format(
                layout_tranform, k
            )
            getattr(opt, "enable_{}".format(v))()
            layout_tranform = k
    vec = _detail._VectorSymbolVar()
    for i in output_vars:
        assert isinstance(i, _detail.SymbolVar), "bad var: {}".format(i)
        vec.push_back(i)
    return list(_detail._optimize_for_inference(vec, opt))
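For example (editorial sketch; ``pred`` is a hypothetical output var), fusing conv+bias+nonlinearity and switching to NCHW88 for x86 inference:

    (pred_opt,) = optimize_for_inference(
        [pred], fuse_conv_bias_nonlinearity=True, use_nchw88=True
    )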
def get_opr_fp_graph_exec(comp_graph, output_vars):
    """get opr footprint and graph exec info

    This function will recompile the computing graph; any AsyncExecutable
    compiled before will become invalid.

    :param comp_graph: ComputingGraph
    :param output_vars: list of :class:`.SymbolVar`
    """
    assert isinstance(output_vars, (list, tuple))
    vec = _detail._VectorSymbolVar()
    for i in output_vars:
        assert isinstance(i, _detail.SymbolVar), "bad var: {}".format(i)
        vec.push_back(i)
    return json.loads(_detail._get_opr_fp_graph_exec(comp_graph, output_vars))

def to_mgb_supported_dtype(dtype_):
    """get the dtype supported by megbrain nearest to given dtype"""
    if (
        dtype.is_lowbit(dtype_)
        or dtype.is_quantize(dtype_)
        or dtype.is_bfloat16(dtype_)
    ):
        return dtype_
    return _detail._to_mgb_supported_dtype(dtype_)

def return_free_memory():
    """return free memory chunks on all devices.

    This function will try its best to free all consecutive free chunks back to
    the operating system; small pieces may not be returned.

    Please note that this function will not move any memory in-use.
    """
    _detail.CompNode._try_coalesce_all_free_memory()
@@ -1,37 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import os
import sys

import megengine._internal.mgb as _mgb

try:
    from setproctitle import setproctitle
except ImportError:
    setproctitle = None

def main():
    parser = argparse.ArgumentParser(
        description="entry point for fork-exec callback in TimedFuncInvoker;"
        " this file should not be used directly by normal user."
    )
    parser.add_argument("user_data")
    args = parser.parse_args()
    if setproctitle:
        setproctitle("megbrain:timed_func_exec:ppid={}".format(os.getppid()))
    _mgb._timed_func_exec_cb(args.user_data)
    raise SystemError("_timed_func_exec_cb returned")

if __name__ == "__main__":
    main()
@@ -1,274 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""tools for graph manipulation"""
import collections

from . import mgb as _mgb

def get_dep_vars(var, var_type=None):
    """return :class:`.SymbolVar` of type ``var_type`` that input ``var``
    depends on. If ``var_type`` is None, return all types.

    :type var: an instance or iterable of :class:`.SymbolVar`
    :type var_type: ``str`` or an iterable of ``str``
    :rtype: list of :class:`.SymbolVar`
    """
    outputs = []
    memo = set()
    if isinstance(var, _mgb.SymbolVar):
        var = [var]
    if isinstance(var_type, str):
        var_type = [var_type]
    q = list(var)
    while q:
        v = q.pop()
        if v in memo:
            continue
        memo.add(v)
        q.extend(get_inputs(v))
        if var_type is not None:
            if get_type(v) in var_type:
                outputs.append(v)
        else:
            outputs.append(v)
    return outputs

def get_inputs(var):
    """get the inputs of owner opr of a variable

    :type var: :class:`.SymbolVar`
    :rtype: list of :class:`.SymbolVar`
    """
    assert isinstance(var, _mgb.SymbolVar)
    return _mgb._get_owner_opr_inputs(var)

def get_type(var):
    """get the type of owner opr of a variable

    :type var: :class:`.SymbolVar`
    :rtype: ``str``
    """
    assert isinstance(var, _mgb.SymbolVar)
    return _mgb._get_owner_opr_type(var)

def get_opr_type(opr):
    """get the type of an opr

    :type opr: :class:`.Operator`
    :rtype: ``str``
    """
    assert isinstance(opr, _mgb.Operator)
    return _mgb._get_opr_type(opr)

def graph_traversal(outputs):
    """helper function to traverse the computing graph and return useful information

    :param outputs: model outputs
    :type outputs: :class:`.SymbolVar`
    :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)

        WHERE

        map_oprs is dict from opr_id to actual opr
        map_vars is dict from var_id to actual var
        var2oprs is dict from var to dest oprs along with index
        opr2receivers is dict from current opr to next opr
        indegree2opr is dict from in_degree to opr in computing graph
        opr2indegree is dict from opr in computing graph to in_degree

        (indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
    """
    # meta information for comp graph
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)
    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)
    queue = list(map(lambda x: x.owner_opr, outputs))
    visited = set(map(lambda x: x.id, queue))
    # iterate through whole comp_graph, fill in meta information
    indegree2opr = collections.defaultdict(set)
    opr2indegree = {}
    idx = 0
    while idx < len(queue):
        cur_opr = queue[idx]
        map_oprs[cur_opr.id] = cur_opr
        idx += 1
        indegree = 0
        for var_idx, var in enumerate(cur_opr.inputs):
            map_vars[var.id] = var
            var2oprs[var.id].append((cur_opr.id, var_idx))
            pre_opr = var.owner_opr
            if pre_opr.id not in visited:
                visited.add(pre_opr.id)
                queue.append(pre_opr)
            indegree += 1
            opr2receivers[pre_opr.id].append(cur_opr.id)
        indegree2opr[indegree].add(cur_opr.id)
        opr2indegree[cur_opr.id] = indegree
    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree

def get_oprs_seq(outputs, prune_reshape=False):
    """get oprs in some topological order for a dumped model

    :param outputs: model outputs
    :param prune_reshape: whether to prune the operators useless during inference
    :return: opr list with some correct execution order
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with topological sort algorithm
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            opr_id = indegree2opr[0].pop()
            opr = map_oprs[opr_id]
            nr_remain -= 1
            # skip const value generation operator
            if get_opr_type(opr) != "ImmutableTensor":
                oprs_seq.append(opr)
            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)
                indegree -= 1
                indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree
        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape opr definition: reshape(input_tensor, dest_shape) -> output_tensor
    # during inference, the shape of output_tensor is already known, so some
    # operators related to dest_shape can be pruned from the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)
            if useless:
                marked_opr_ids.append(cur_opr.id)
                for inp in cur_opr.inputs:
                    iterative_pruning(inp.owner_opr, cur_opr, marked_opr_ids)

        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner_opr for var in reshape_vars]
        marked_opr_ids = []
        for reshape_opr in reshape_oprs:
            iterative_pruning(
                reshape_opr.inputs[1].owner_opr, reshape_opr, marked_opr_ids
            )
        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    if prune_reshape is True:
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq
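For instance (editorial example), printing an inference-time execution order for a dumped model:

    for opr in get_oprs_seq(outputs, prune_reshape=True):
        print(opr.id, get_opr_type(opr))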
def replace_vars(dst, varmap):
    """replace vars in the graph

    :param dst: target vars representing the graph
    :type dst: list of :class:`.SymbolVar`
    :param varmap: the map that specifies how to replace the vars
    :type varmap: dict that maps from src var to dst var
    :return: new vars that correspond to ``dst`` with all the dependencies
        replaced
    :rtype: list of :class:`.SymbolVar`
    """
    dst_vec = _mgb._VectorSymbolVar()
    repl_src_vec = _mgb._VectorSymbolVar()
    repl_dst_vec = _mgb._VectorSymbolVar()
    for i in dst:
        assert isinstance(i, _mgb.SymbolVar)
        dst_vec.push_back(i)
    for i, j in getattr(varmap, "items", lambda: varmap)():
        assert isinstance(i, _mgb.SymbolVar)
        assert isinstance(j, _mgb.SymbolVar)
        repl_src_vec.push_back(i)
        repl_dst_vec.push_back(j)
    return _mgb._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
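For example (editorial sketch with hypothetical vars), rebuilding a graph with a preprocessed input substituted in:

    (new_out,) = replace_vars([out], {old_inp: new_inp})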
def replace_oprs(dst, oprmap):
    """Replace operators in the graph. Roughly equivalent to
    :func:`replace_vars`, but operating on operators instead of vars.

    :param dst: target vars representing the graph
    :type dst: list of :class:`.SymbolVar`
    :param oprmap: the map that specifies how to replace the operators
    :type oprmap: dict that maps from src operator to dst operator
    :return: new vars that correspond to ``dst`` with all the dependencies
        replaced
    :rtype: list of :class:`.SymbolVar`
    """
    dst_vec = _mgb._VectorSymbolVar()
    repl_src_vec = _mgb._VectorOperator()
    repl_dst_vec = _mgb._VectorOperator()
    for i in dst:
        assert isinstance(i, _mgb.SymbolVar)
        dst_vec.push_back(i)
    for i, j in getattr(oprmap, "items", lambda: oprmap)():
        assert isinstance(i, _mgb.Operator)
        assert isinstance(j, _mgb.Operator)
        repl_src_vec.push_back(i)
        repl_dst_vec.push_back(j)
    return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)

def set_priority_to_id(dest_vars):
    """For all oprs in the subgraph constructed by dest_vars, set each opr's
    priority to its id if its original priority is zero.

    :param dest_vars: target vars representing the graph
    """
    dest_vec = _mgb._VectorSymbolVar()
    for i in dest_vars:
        assert isinstance(i, _mgb.SymbolVar)
        dest_vec.push_back(i)
    _mgb._set_priority_to_id(dest_vec)
@@ -1,439 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import os

from . import mgb as _mgb

_default_device_type = "CUDA"

def set_device_map(logical_dev, physical_dev, device_type=None):
    """map from *logical_dev* to *physical_dev* for future comp node
    loading

    example::

        set_device_map(0, 2, 'CPU')     # cpu0 -> cpu2
        set_device_map('gpu3', 'gpu0')  # gpu3 -> gpu0

    :param device_type: specify the device type if devices are given by
        integers; if devices are given by integers and ``device_type`` is not
        given, the default value ``'CUDA'`` would be used. Possible values are
        ``'CUDA'`` and ``'CPU'``.
    """
    if device_type is None:
        device_type = _default_device_type
    if device_type == "CUDA":
        xpu = "gpu"
    else:
        assert device_type == "CPU"
        xpu = "cpu"

    def rmxpu(v):
        if isinstance(v, str):
            assert v.startswith(xpu) or v.startswith("xpu"), (
                "bad comp node in set_device_map: "
                "device_type={} comp_node={}".format(device_type, v)
            )
            return v[3:]
        return v

    logical_dev, physical_dev = map(rmxpu, [logical_dev, physical_dev])
    _mgb.CompNode._set_device_map(device_type, int(logical_dev), int(physical_dev))

def set_default_device(physical_dev, device_type=None):
    """set physical device for xpux

    when *device_type* is None and *physical_dev* starts with *gpu* or *cpu*,
    the default device type would be modified accordingly for future calls to
    :func:`set_device_map` when remapping device number.
    """
    global _default_device_type
    if (
        device_type is None
        and isinstance(physical_dev, str)
        and not physical_dev.isdigit()
        and not physical_dev.startswith("xpu")
    ):
        t = physical_dev[:3]
        if t == "gpu":
            _default_device_type = "CUDA"
        else:
            assert t == "cpu", "bad physical_dev: {}".format(physical_dev)
            _default_device_type = "CPU"
        set_default_device_type(_default_device_type)
        device_type = _default_device_type
    set_device_map(-1, physical_dev, device_type)

def set_default_device_type(device_type):
    """set device type for xpu"""
    global _default_device_type
    device_type = device_type.upper()
    _mgb.CompNode._set_unspec_device_type(device_type)
    _default_device_type = device_type

def set_fork_cuda_warning_flag(flag):
    """set warning to be printed at fork if cuda has been initialized

    :type flag: int
    :param flag: controls how the warning should be printed:

        * 0: disable warning
        * 1: print warning to log
        * 2: print warning to log and raise exception
    """
    _mgb._config.set_fork_cuda_warning_flag(int(flag))

def get_device_count(device_type="xpu", warn=True):
    """get number of devices installed on this system

    :param device_type: device type, one of 'xpu', 'gpu' or 'cpu'
    :type device_type: str
    """
    return _mgb.CompNode._get_device_count(device_type.upper(), warn)

def parse_locator(device_name: str) -> tuple:
    """get the tensor locator expression by device name.

    :param device_name: device name, like 'cpu0', 'gpu1' and 'xpux'
    :type device_name: str
    :return: (device_type, dev_num, stream_num)
    """
    return _mgb.CompNode._parse_locator(device_name)

def set_mem_reserve_size(size):
    """set memory reserve size:

    * If *size* is greater than 1, it is the absolute amount of memory to
      be reserved in MB;
    * If *size* is in the range (0, 1), it is the ratio of total memory;
    * If *size* is 0, memory reservation and pre-allocation would be
      disabled;
    * If *size* is -1, disable custom memory allocator and use cuda APIs
      directly.
    """
    _mgb._config.set_mem_reserve_size(float(size))
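Concretely, per the rules above (illustrative values):

    set_mem_reserve_size(4096)  # reserve 4096 MB
    set_mem_reserve_size(0.9)   # reserve 90% of total device memory
    set_mem_reserve_size(0)     # disable reservation and pre-allocation
    set_mem_reserve_size(-1)    # bypass the allocator; use cuda APIs directly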
| def set_comp_graph_option(comp_graph, name, val): | |||
| """set computing graph option and return its old value | |||
| :type comp_graph: :class:`.CompGraph` | |||
| :param comp_graph: the computing graph whose option should be modified | |||
| :type name: str | |||
| :param name: option name | |||
| Currently supported options are: | |||
| * "no_profiling_on_shape_change": bool; | |||
| When execution strategy is set to profiling, always use the | |||
| initial profile result and do not re-run profiling even if input | |||
| shape changes. | |||
| * "seq_opt.enable_mem_plan_opt": bool | |||
| * "seq_opt.enable_mem_reuse_alloc": bool | |||
| * "seq_opt.enable_seq_comp_node_opt": bool | |||
| * "force_dynamic_alloc": bool | |||
| * "var_sanity_check_first_run": bool | |||
| * "enable_sublinear_memory_opt": bool | |||
| * "enable_memory_swap": bool; whether to enable memory swap; it | |||
| usually performs worse than sublinear memory | |||
| * "enable_var_mem_defragment": bool | |||
| * "allocate_static_mem_after_graph_compile": bool | |||
| * "enable_grad_var_static_reshape": bool: | |||
| If set to ``True``, dynamically-shaped gradients whose original | |||
| shape is statically inferrable would be reshaped, so static | |||
| shape inference can continue | |||
| * "async_exec_level": int | |||
| * ``0``: do not dispatch asynchronously | |||
| * ``1``: async dispatch if there is more than one cuda comp | |||
| node | |||
| * mask ``0b10``: async for comp nodes with unlimited queue | |||
| (e.g. CPU comp nodes) | |||
| * mask ``0b100``: async for even one comp node | |||
| * "log_level": int | |||
| * ``0``: no log info for graph construction/compiling | |||
| * ``1``: static memory allocation status, | |||
| WorkspaceLimitGetter summary, and optimizer summary | |||
| * ``2``: optimizer details and duplicated operators that are | |||
| removed | |||
| * "graph_opt.jit": whether to enable JIT | |||
| * "graph_opt.tensorrt": whether to enable fine-grained automatic | |||
| replacement for TensorRT operators | |||
| * "graph_opt.android_nn": whether to enable fine-grained automatic | |||
| replacement for Android NN operators | |||
| * "graph_opt_level": int | |||
| * ``0``: disable | |||
| * ``1``: level-1: inplace arith transformations during graph | |||
| construction | |||
| * ``2``: (default) level-2: level-1, plus global optimization | |||
| before graph compiling | |||
| * ``3``: also enable JIT | |||
| :param val: new option value | |||
| :return: old option value | |||
| """ | |||
| if name == "log_static_mem_alloc": | |||
| name = "log_level" | |||
| if name == "enable_async_exec": | |||
| name = "async_exec_level" | |||
| return _mgb._config.set_comp_graph_option(comp_graph, name, int(val)) | |||
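| # Illustrative sketch, assuming *cg* is an existing :class:`.CompGraph`: | |||
| # each call returns the previous value, so an option can be restored later. | |||
| def _example_graph_options(cg): | |||
|     old_opt = set_comp_graph_option(cg, "graph_opt_level", 3)  # enable JIT | |||
|     set_comp_graph_option(cg, "log_level", 2)  # verbose optimizer details | |||
|     set_comp_graph_option(cg, "graph_opt_level", old_opt)  # restore | |||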
| def comp_graph_is_eager(comp_graph): | |||
| return _mgb._config.comp_graph_is_eager(comp_graph) | |||
| def add_extra_vardep(var, dep): | |||
| """add *dep* as an extra dependency of *var*, so if *var* is required to | |||
| compute the final output when compiling a comp graph, *dep* would also be | |||
| included in the computing sequence. Note that the computing order of these | |||
| two vars is not guaranteed. | |||
| """ | |||
| assert isinstance(var, _mgb.SymbolVar) and isinstance(dep, _mgb.SymbolVar) | |||
| assert var.owner_graph == dep.owner_graph | |||
| return _mgb._config.add_extra_vardep(var, dep) | |||
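| # Illustrative sketch, assuming *loss* and *counter_update* are | |||
| # :class:`.SymbolVar` objects in the same graph: whenever *loss* is | |||
| # compiled as an output, *counter_update* is executed as well. | |||
| def _example_extra_vardep(loss, counter_update): | |||
|     return add_extra_vardep(loss, counter_update) | |||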
| class _GraphPropertyBase: | |||
| """helper class for implementing operator property setter context managers""" | |||
| _cur_graph = None | |||
| _graph2stack = None | |||
| """class attribute that maintains mapping from graph to property stack; | |||
| should be defined by child classes""" | |||
| __prop_setup__ = None | |||
| """overwritten by subclass to setup property""" | |||
| __prop_clear__ = None | |||
| """overwritten by subclass to clear property""" | |||
| def __init__(self, comp_graph, prop): | |||
| """:param comp_graph: computing graph, or None to not set this | |||
| property""" | |||
| if comp_graph is not None: | |||
| assert isinstance( | |||
| comp_graph, _mgb.CompGraph | |||
| ), "invalid comp graph: {!r}".format(comp_graph) | |||
| self._cur_graph = comp_graph | |||
| self._graph2stack.setdefault(comp_graph, []).append(prop) | |||
| def __setup(self, prop): | |||
| self.__prop_setup__(self._cur_graph, prop) | |||
| def __clear(self): | |||
| self.__prop_clear__(self._cur_graph) | |||
| def __enter__(self): | |||
| if self._cur_graph is None: | |||
| return | |||
| stack = self._graph2stack[self._cur_graph] | |||
| if len(stack) > 1: | |||
| # clear nested property | |||
| self.__clear() | |||
| self.__setup(stack[-1]) | |||
| def __exit__(self, exc_type, exc_value, exc_traceback): | |||
| if self._cur_graph is None: | |||
| return | |||
| stack = self._graph2stack[self._cur_graph] | |||
| self.__clear() | |||
| stack.pop() | |||
| if stack: | |||
| # restore nested property | |||
| self.__setup(stack[-1]) | |||
| else: | |||
| del self._graph2stack[self._cur_graph] | |||
| class exc_opr_tracker_scope(_GraphPropertyBase): | |||
| """context manager for associating an object with all operators created | |||
| within this context; so when an exception is raised, information about the | |||
| corresponding operator could be retrieved from | |||
| :attr:`.MegBrainError.tracker` | |||
| :param comp_graph: the computing graph where the operators should be tracked | |||
| :type comp_graph: :class:`.CompGraph` | |||
| :param tracker: an arbitrary python object to track the operators | |||
| """ | |||
| _graph2stack = {} | |||
| def __init__(self, comp_graph, tracker): | |||
| assert ( | |||
| tracker is not None | |||
| ), "bad args for exc_opr_tracker_scope: {!r} {!r}".format(comp_graph, tracker) | |||
| super().__init__(comp_graph, tracker) | |||
| __prop_setup__ = staticmethod(_mgb._config.begin_set_exc_opr_tracker) | |||
| __prop_clear__ = staticmethod(_mgb._config.end_set_exc_opr_tracker) | |||
| class opr_priority_scope(_GraphPropertyBase): | |||
| """context manager for setting priority for all operators created in this | |||
| context | |||
| :param comp_graph: the computing graph for which operator priority should | |||
| be set | |||
| :type comp_graph: :class:`.CompGraph` | |||
| :param priority: operator priority. Smaller number means higher priority. | |||
| Default value is 0. Grad operator would use negative priority by | |||
| default. | |||
| """ | |||
| _graph2stack = {} | |||
| LOWEST_PRIORITY = 2 ** 31 - 1 | |||
| """lowest prority (i.e. max possible value)""" | |||
| HIGHEST_PRIORITY = -LOWEST_PRIORITY | |||
| """highest prority (i.e. min possible value)""" | |||
| def __init__(self, comp_graph, priority): | |||
| super().__init__(comp_graph, int(priority)) | |||
| __prop_setup__ = staticmethod(_mgb._config.begin_set_opr_priority) | |||
| __prop_clear__ = staticmethod(_mgb._config.end_set_opr_priority) | |||
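| # Illustrative sketch, assuming *cg* is a :class:`.CompGraph` and *x*, *y* | |||
| # are symvars on it: operators created inside the ``with`` blocks pick up | |||
| # the tracker / priority, and nesting restores the outer value on exit. | |||
| def _example_property_scopes(cg, x, y): | |||
|     with exc_opr_tracker_scope(cg, "stage-1"): | |||
|         z = x + y  # exceptions from this opr would report tracker "stage-1" | |||
|         with opr_priority_scope(cg, -10):  # smaller number runs earlier | |||
|             w = z * 2 | |||
|     return w | |||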
| OprTrackerResult = collections.namedtuple( | |||
| "OprTrackerResult", ["msg", "tracker", "grad_tracker"] | |||
| ) | |||
| def get_opr_tracker(cg, var_id): | |||
| """get the tracking object associated with the owner operator of a var | |||
| :param cg: the computing graph | |||
| :param var_id: id of the var whose owner opr tracker should be found | |||
| :return: if no var is found, ``None`` is returned; otherwise return an | |||
| :class:`OprTrackerResult` object | |||
| """ | |||
| assert isinstance(cg, _mgb.CompGraph) | |||
| ret = _mgb._config.get_opr_tracker(cg, int(var_id)) | |||
| if ret is None: | |||
| return | |||
| return OprTrackerResult(*ret) | |||
| def set_opr_sublinear_memory_endpoint(var): | |||
| """set the owner operator of a symvar to be endpoint of sublinear memory | |||
| optimizer | |||
| :type var: :class:`.SymbolVar` | |||
| """ | |||
| _mgb._config.set_opr_sublinear_memory_endpoint(var) | |||
| def max_size_t(): | |||
| """get max value of size_t type on local architecture""" | |||
| return _mgb.max_size_t() | |||
| def is_cuda_ctx_set(): | |||
| """return whether current thread has an active cuda driver context""" | |||
| return _mgb._config.is_cuda_ctx_set() | |||
| def get_include_path(): | |||
| """get include path for building megbrain extensions""" | |||
| return os.path.join(os.path.realpath(os.path.dirname(__file__)), "include") | |||
| def get_cuda_gencode(only_cap=False): | |||
| """get -gencode options to be passed to nvcc for compiling on local | |||
| machine | |||
| :param only_cap: if True, return only a list of cuda compute capability | |||
| strings (like ``['35', '52']`` ) | |||
| """ | |||
| ret = _mgb._config.get_cuda_gencode().split() | |||
| if not only_cap: | |||
| ret = " ".join(map("-gencode arch=compute_{0},code=sm_{0}".format, ret)) | |||
| return ret | |||
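| # Illustrative sketch: on a machine whose devices report capabilities | |||
| # ['61', '75'], the two calling conventions would give | |||
| #   get_cuda_gencode(only_cap=True)  -> ['61', '75'] | |||
| #   get_cuda_gencode()  -> '-gencode arch=compute_61,code=sm_61 ' | |||
| #                          '-gencode arch=compute_75,code=sm_75' | |||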
| def get_cuda_lib_path(): | |||
| """get the cuda lib64 path by locating nvcc | |||
| """ | |||
| return _mgb._config.get_cuda_lib_path() | |||
| def get_cuda_include_path(): | |||
| """get the cuda include path by locating nvcc, including | |||
| parent path and `parent path`/include | |||
| """ | |||
| return _mgb._config.get_cuda_include_path() | |||
| def get_cuda_version(): | |||
| """get runtime cuda version | |||
| """ | |||
| return _mgb._config.get_cuda_version() | |||
| def is_local_cuda_env_ok(): | |||
| """check whether local cuda environment ok by locating nvcc | |||
| """ | |||
| return _mgb._config.is_local_cuda_env_ok() | |||
| def is_compiled_with_cuda(): | |||
| """whether cuda is enabled at compile time""" | |||
| return _mgb._config.is_compiled_with_cuda() | |||
| def load_opr_library(path): | |||
| """Load an external operator library. This essentially sets megbrain | |||
| symbols as public and loads the library. | |||
| :param path: path to the shared object; if it is None, then only megbrain | |||
| symbols are made public. | |||
| """ | |||
| _mgb._config.load_opr_library( | |||
| os.path.realpath(os.path.join(os.path.dirname(__file__), "_mgb.so")), path | |||
| ) | |||
| def dump_registered_oprs(): | |||
| """ | |||
| get all registered oprs; return a dict mapping opr id to name | |||
| """ | |||
| return dict(_mgb._config.dump_registered_oprs()) | |||
| def create_mm_server(server_addr, port): | |||
| """ | |||
| create mm server with the given server address; | |||
| an exception is thrown if server_addr is already in use | |||
| """ | |||
| return _mgb._config.create_mm_server(server_addr, port) | |||
| def group_barrier(server_addr, port, size, rank): | |||
| """ | |||
| block until all ranks reach this barrier | |||
| """ | |||
| return _mgb._config.group_barrier(server_addr, port, size, rank) | |||
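| # Illustrative sketch of bringing up the distributed helpers, assuming the | |||
| # hypothetical host name "master-host" and a free port 23333: rank 0 starts | |||
| # the mm server, then every rank blocks on the same barrier. | |||
| def _example_distributed_setup(rank, world_size): | |||
|     if rank == 0: | |||
|         create_mm_server("0.0.0.0", 23333) | |||
|     group_barrier("master-host", 23333, world_size, rank) | |||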
| @@ -1,432 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| """used for creating a megbrain operator from python""" | |||
| import copy | |||
| import itertools | |||
| from abc import ABCMeta, abstractmethod, abstractproperty | |||
| from . import helper as _helper | |||
| from . import mgb as _mgb | |||
| class _CraniotomeBaseMeta(ABCMeta): | |||
| _base_created = False | |||
| def __init__(cls, name, bases, member_dict): | |||
| if _CraniotomeBaseMeta._base_created: | |||
| assert "__init__" not in member_dict, ( | |||
| "Craniotome operators should not overwrite __init__ method; " | |||
| "use setup() instead." | |||
| ) | |||
| forbidden = set( | |||
| k for k in dir(CraniotomeBase) if k[0] == "_" and k[1] != "_" | |||
| ) | |||
| forbidden.add("get_io_vars") | |||
| check_key = member_dict.get("__check_key__", True) | |||
| whitelist = ["__classcell__"] | |||
| for k in member_dict.keys(): | |||
| assert k not in forbidden, "{} could not be overwritten".format(k) | |||
| if ( | |||
| check_key | |||
| and k.startswith("__") | |||
| and k.endswith("__") | |||
| and k not in whitelist | |||
| and not hasattr(CraniotomeBase, k) | |||
| ): | |||
| raise KeyError( | |||
| "name {} in class {} does not exist in the baseclass".format( | |||
| k, name | |||
| ) | |||
| ) | |||
| else: | |||
| _CraniotomeBaseMeta._base_created = True | |||
| super().__init__(name, bases, member_dict) | |||
| class CraniotomeBase(_mgb.CraniotomeDesc, metaclass=_CraniotomeBaseMeta): | |||
| """base class used for extending megbrain core operators in python | |||
| Note: all names starting and ending with two underscores in the subclasses | |||
| would be checked and KeyError would be raised if the name does not exist in | |||
| the base class. This behavior can be disabled by setting ``__check_key__`` | |||
| to ``False`` (see the testcase for more details) | |||
| """ | |||
| # methods and attributes to be overwritten by subclasses | |||
| __expand_single_outputs__ = True | |||
| """if :attr:`__nr_outputs__` is 1, whether to return a single | |||
| :class:`.SymbolVar` instead of a tuple in :meth:`make`""" | |||
| __is_dynamic_output_shape__ = False | |||
| """whether output shape could not be inferred from input shape. If value of | |||
| this attribute is ``False``, :meth:`infer_shape` must be implemented. If | |||
| this attribute is ``True`` but the operator has no inputs, then | |||
| :meth:`infer_shape` would also be called to infer output shape before | |||
| operator execution. | |||
| """ | |||
| __disable_sys_mem_alloc__ = False | |||
| """whether to disable system memory allocator. This is used when | |||
| :attr:`__is_dynamic_output_shape__` is ``False`` but the output memory | |||
| should not be managed by megbrain system (so it can be forwarded from | |||
| external buffer)""" | |||
| __allow_duplicate__ = True | |||
| """whether this operator can be duplicated (e.g. used in sublinear | |||
| memory)""" | |||
| __allow_empty_out__ = False | |||
| """whether empty output shape is allowed; if it is set as ``False``, then | |||
| an exception would be raised if output var is empty to prevent erroneously | |||
| forgetting initializing output vars""" | |||
| @abstractproperty | |||
| def __nr_inputs__(self): | |||
| """number of input vars""" | |||
| @abstractproperty | |||
| def __nr_outputs__(self): | |||
| """number of output vars""" | |||
| @abstractmethod | |||
| def execute(self, inputs, outputs): | |||
| """execute the operator, read values from *inputs* by calling | |||
| :meth:`.CompGraphCallbackValueProxy.get_value` and write results into | |||
| *outputs* by calling :meth:`.SharedND.set_value` | |||
| :param inputs: values for each input var | |||
| :type inputs: tuple of :class:`.CompGraphCallbackValueProxy` | |||
| :param outputs: values for each output var | |||
| :type outputs: tuple of :class:`.SharedND` | |||
| """ | |||
| def setup(self): | |||
| """overwritten by subclass to accept kwargs passed to :meth:`make` to | |||
| setup the operator""" | |||
| def infer_shape(self, inp_shapes): | |||
| """infer output shape from input shapes | |||
| :type inp_shapes: tuple of tuple of ints | |||
| :param inp_shapes: input shapes for each input var | |||
| :rtype: tuple of tuple of ints | |||
| :return: output shapes for each output var | |||
| """ | |||
| raise NotImplementedError( | |||
| "{}: infer_shape() not implemented; for operators with dynamic " | |||
| "output shape, __is_dynamic_output_shape__ should be set to True".format( | |||
| self | |||
| ) | |||
| ) | |||
| def grad(self, wrt_idx, inputs, outputs, out_grad): | |||
| """compute symbolic gradient; should be overwritten by differentiable | |||
| subclasses | |||
| :type wrt_idx: int | |||
| :param wrt_idx: the input var with respect to which the gradient should | |||
| be computed; please also see the notes below | |||
| :type inputs: tuple of :class:`.SymbolVar` | |||
| :param inputs: input symbol vars | |||
| :type outputs: tuple of :class:`.SymbolVar` | |||
| :param outputs: output symbol vars | |||
| :type out_grad: tuple of (:class:`.SymbolVar` or None) | |||
| :param out_grad: gradients of loss with respect to each output var | |||
| .. note:: | |||
| In case the loss does not depend on some var (i.e. zero grad), | |||
| the corresponding value in *out_grad* would be ``None``. It is | |||
| guaranteed that at least one element in *out_grad* is not | |||
| ``None``. | |||
| .. note:: | |||
| This function can return either of the following: | |||
| 1. Gradient of the input specified by ``wrt_idx`` | |||
| 2. A list containing gradients of all inputs. In this case, | |||
| ``wrt_idx`` can be ignored. | |||
| And the so called gradient can be either one of: | |||
| 1. A :class:`.SymbolVar` representing the symbolic gradient | |||
| value | |||
| 2. ``0`` representing zero gradient | |||
| """ | |||
| raise NotImplementedError("grad for {} not implemented".format(self)) | |||
| def init_output_dtype(self, input_dtypes): | |||
| """infer output dtypes from input dtypes; return None to use default | |||
| infer function in megbrain. | |||
| .. note:: | |||
| This method must be implemented if there is no input var | |||
| :param input_dtypes: input dtypes | |||
| :type input_dtypes: list of :class:`numpy.dtype` | |||
| :rtype: None or list of :class:`numpy.dtype`-compatible | |||
| """ | |||
| def get_serialize_params(self): | |||
| """get params for megbrain graph serialization. This function should | |||
| return a list or tuple, containing one or two elements: the first | |||
| element must be a string, representing the name passed to | |||
| ``opr_loader_maker`` during deserializing; the second element, if | |||
| exists, must be convertible to ``bytes`` and is used for dumping any | |||
| extra opr params, which can be retrieved by ``load_buf_with_len`` | |||
| during deserializing. | |||
| """ | |||
| raise NotImplementedError( | |||
| "get_serialize_params() for {} not implemented".format(self) | |||
| ) | |||
| def copy(self): | |||
| """copy this craniotome descriptor; the default implementation creates | |||
| a new object, and copies object ``__dict__``""" | |||
| ret = type(self)() | |||
| d0 = self.__dict__.copy() | |||
| d0.pop("this") | |||
| ret.__dict__.update(copy.deepcopy(d0)) | |||
| return ret | |||
| def on_graph_compiled(self, used_outputs): | |||
| """a callback that would be invoked when the graph is compiled; it | |||
| would always have a matching :meth:`on_compiled_func_deleted` call | |||
| :param used_outputs: indices of outputs that are needed for the | |||
| computation | |||
| :type used_outputs: ``tuple of int`` | |||
| """ | |||
| def on_compiled_func_deleted(self): | |||
| """a callback that would be invoked when the compiled function is | |||
| destructed; it would always have a matching :meth:`on_graph_compiled` | |||
| call""" | |||
| def get_io_vars(self): | |||
| """get input vars, comp order dep vars and output vars | |||
| :return: a dict with keys ``'input'``, ``'output'`` and | |||
| ``'comp_order'`` that maps to corresponding list of vars | |||
| """ | |||
| all_vars = list(self._get_all_io_vars()) | |||
| nr_inp = self.__nr_inputs__ | |||
| nr_out = self.__nr_outputs__ | |||
| nr_comp_order = self._get_nr_dev_comp_order_deps() | |||
| s0 = nr_inp + nr_comp_order | |||
| return dict( | |||
| input=all_vars[:nr_inp], | |||
| comp_order=all_vars[nr_inp:s0], | |||
| output=all_vars[s0:], | |||
| ) | |||
| @property | |||
| def owner_opr_id(self): | |||
| """ID of the operator that owns this descriptor""" | |||
| return self._get_opr_id() | |||
| @property | |||
| def comp_node(self): | |||
| """comp node on which this operator runs""" | |||
| return self._get_comp_node() | |||
| # below are methods that should not be changed | |||
| def _hash(self): | |||
| return int(hash(self)) % (1 << 64) | |||
| def _setup_self(self, dst): | |||
| dst.append(self) | |||
| def _is_same(self, rhs): | |||
| return bool(self == rhs) | |||
| def _node_flag(self): | |||
| return ( | |||
| (int(bool(self.__is_dynamic_output_shape__)) << 0) | |||
| | (int(not self.__allow_duplicate__) << 1) | |||
| | (int(bool(self.__allow_empty_out__)) << 2) | |||
| | (int(bool(self.__disable_sys_mem_alloc__)) << 3) | |||
| ) | |||
| def _get_opr_type_name(self): | |||
| return str(self.__class__.__name__) | |||
| def _get_nr_outputs(self): | |||
| return int(self.__nr_outputs__) | |||
| def _execute(self, inputs, outputs): | |||
| inputs = tuple(inputs) | |||
| outputs = tuple(outputs) | |||
| if not self.__is_dynamic_output_shape__: | |||
| out_shapes = [i.shape for i in outputs] | |||
| self.execute(inputs, outputs) | |||
| if not self.__is_dynamic_output_shape__: | |||
| new_shapes = [i.shape for i in outputs] | |||
| assert ( | |||
| out_shapes == new_shapes | |||
| ), "output shape changed after executing {}: before={} after={}".format( | |||
| self, out_shapes, new_shapes | |||
| ) | |||
| def _infer_shape(self, inp_shapes): | |||
| inp_shapes = tuple(tuple(map(int, i)) for i in inp_shapes) | |||
| oshp_get = self.infer_shape(inp_shapes) | |||
| assert ( | |||
| len(oshp_get) == self.__nr_outputs__ | |||
| ), "{}: expect {} outputs; got {}(val: {}) from infer_shape".format( | |||
| self, self.__nr_outputs__, len(oshp_get), oshp_get | |||
| ) | |||
| return _helper.cvt_to_vector_of_shape(oshp_get) | |||
| def _grad(self, wrt_idx, inputs, outputs, out_grad): | |||
| og = [] | |||
| for i in out_grad: | |||
| if i.valid: | |||
| og.append(i) | |||
| else: | |||
| og.append(None) | |||
| rst = self.grad(int(wrt_idx), tuple(inputs), tuple(outputs), tuple(og)) | |||
| if not isinstance(rst, (list, tuple)): | |||
| rst = [rst] | |||
| else: | |||
| assert len(rst) == len( | |||
| inputs | |||
| ), "{}: opr has {} inputs but {} grads are returned".format( | |||
| self, len(inputs), len(rst) | |||
| ) | |||
| for i in range(len(rst)): | |||
| cur = rst[i] | |||
| if isinstance(cur, int) and cur == 0:  # plain 0 marks a zero gradient | |||
| rst[i] = _mgb.SymbolVar() | |||
| else: | |||
| assert isinstance(cur, _mgb.SymbolVar), ( | |||
| "{}: invalid grad result; it should be either " | |||
| "0 or a SymbolVar, got {!r} instead".format(self, cur) | |||
| ) | |||
| return rst | |||
| def _get_nr_dev_comp_order_deps(self): | |||
| return 0 | |||
| def _init_output_dtype(self, input_dtypes, ret): | |||
| get = self.init_output_dtype(input_dtypes) | |||
| if get is not None: | |||
| assert isinstance(ret, (list, tuple)) and len(get) == len(ret) | |||
| ret[:] = get | |||
| return True | |||
| assert self.__nr_inputs__, ( | |||
| "{}: init_output_dtype must be implemented " | |||
| "if there is no input var".format(self) | |||
| ) | |||
| return False | |||
| def _setup_serialize_params(self, output): | |||
| val = list(self.get_serialize_params()) | |||
| assert len(val) in [1, 2] | |||
| name = val[0] | |||
| assert isinstance(name, str) | |||
| output.append(name) | |||
| if len(val) == 2: | |||
| output.append(bytes(val[1])) | |||
| def _copy(self): | |||
| ret = self.copy() | |||
| assert type(ret) is type( | |||
| self | |||
| ), "copy() returned different type: src={} copied={}".format( | |||
| type(self), type(ret) | |||
| ) | |||
| assert ret is not self | |||
| ret.__disown__() | |||
| self._set_copy_result(ret) | |||
| def _on_graph_compile_or_func_del(self, used_outputs): | |||
| if used_outputs: | |||
| self.on_graph_compiled(used_outputs) | |||
| else: | |||
| self.on_compiled_func_deleted() | |||
| def __repr__(self): | |||
| return "cranoiotome:{}".format(self.__class__.__name__) | |||
| @classmethod | |||
| def make( | |||
| cls, | |||
| *inputs, | |||
| comp_graph=None, | |||
| name=None, | |||
| comp_node=None, | |||
| config=None, | |||
| dev_comp_order_deps=(), | |||
| **kwargs | |||
| ): | |||
| """apply this operator on some input vars and return corresponding | |||
| output vars | |||
| :type inputs: tuple of :class:`.SymbolVar` | |||
| :param inputs: input symvars; immediate values could also be accepted, | |||
| as long as there is symvar to infer comp node and comp graph | |||
| :param comp_graph: if there is no input vars, *comp_graph* must be | |||
| provided to specify which computing graph to insert this operator | |||
| :param dev_comp_order_deps: vars that must have been computed | |||
| before executing this operator | |||
| :param kwargs: extra keyword arguments to be passed to :meth:`setup` of | |||
| this class | |||
| :param name: name of the resulting operator | |||
| :rtype: tuple of :class:`.SymbolVar` | |||
| :return: output symvars | |||
| """ | |||
| if not inputs and not dev_comp_order_deps: | |||
| assert isinstance( | |||
| comp_graph, _mgb.CompGraph | |||
| ), "{}: comp_graph must be given if no inputs provided".format(self) | |||
| desc = cls() | |||
| desc.setup(**kwargs) | |||
| assert ( | |||
| len(inputs) == desc.__nr_inputs__ | |||
| ), "{}: expected {} inputs, got {}".format( | |||
| desc, desc.__nr_inputs__, len(inputs) | |||
| ) | |||
| config = _helper.gen_config(name, comp_node, config) | |||
| # get inp_vec | |||
| inp_vec = _mgb._VectorSymbolVar() | |||
| for i in _helper.canonize_input_vars( | |||
| itertools.chain(inputs, dev_comp_order_deps), | |||
| comp_graph=comp_graph, | |||
| config=config, | |||
| ): | |||
| inp_vec.push_back(i) | |||
| desc._get_nr_dev_comp_order_deps = lambda *, val=len(dev_comp_order_deps): val | |||
| if comp_graph is not None: | |||
| desc._get_comp_graph = lambda: comp_graph | |||
| expand_single_outputs = desc.__expand_single_outputs__ | |||
| desc.__disown__() | |||
| rst = _mgb.make_opr_from_craniotome_desc(desc, inp_vec, config) | |||
| if expand_single_outputs and len(rst) == 1: | |||
| return rst[0] | |||
| return tuple(rst) | |||
| def make_opr(cls): | |||
| """decorator used to wrap a :class:`.CraniotomeBase` subclass and return | |||
| its :meth:`~.CraniotomeBase.make` method | |||
| """ | |||
| assert issubclass(cls, CraniotomeBase) | |||
| return cls.make | |||
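| # A minimal end-to-end sketch (illustrative only): an operator that | |||
| # multiplies its single input by a constant. ``factor`` is forwarded from | |||
| # :meth:`make` to :meth:`setup`; shapes are static, so ``infer_shape`` is | |||
| # implemented, and ``grad`` returns the symbolic gradient for the input. | |||
| class _ScaleByConst(CraniotomeBase): | |||
|     __nr_inputs__ = 1 | |||
|     __nr_outputs__ = 1 | |||
|     def setup(self, factor): | |||
|         self._factor = float(factor) | |||
|     def infer_shape(self, inp_shapes): | |||
|         return inp_shapes  # single output with the same shape as the input | |||
|     def execute(self, inputs, outputs): | |||
|         outputs[0].set_value(inputs[0].get_value() * self._factor) | |||
|     def grad(self, wrt_idx, inputs, outputs, out_grad): | |||
|         return out_grad[0] * self._factor  # d(x*f)/dx == f | |||
| # usage: y = _ScaleByConst.make(x, factor=2.5) | |||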
| @@ -1,286 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from typing import Union | |||
| import numpy as np | |||
| from .mgb import bfloat16, intb1, intb2, intb4 | |||
| _QuantDtypeMetadata = collections.namedtuple( | |||
| "QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] | |||
| ) | |||
| _metadata_dict = { | |||
| "quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), | |||
| "qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), | |||
| "quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), | |||
| "qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), | |||
| "qint32": _QuantDtypeMetadata( | |||
| "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, | |||
| ), | |||
| # NOTE: int2 is not supported for model dump yet | |||
| "quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3), | |||
| "qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1), | |||
| } | |||
| def is_quantize(dtype): | |||
| return ( | |||
| hasattr(dtype, "metadata") | |||
| and dtype.metadata is not None | |||
| and "mgb_dtype" in dtype.metadata | |||
| ) | |||
| def is_lowbit(dtype): | |||
| return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) | |||
| def is_bfloat16(dtype): | |||
| return dtype is bfloat16 | |||
| def get_scale(dtype): | |||
| assert is_quantize(dtype) | |||
| return dtype.metadata["mgb_dtype"]["scale"] | |||
| def get_zero_point(dtype): | |||
| assert is_quantize(dtype) | |||
| metadata = dtype.metadata["mgb_dtype"] | |||
| assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") | |||
| return metadata["zero_point"] | |||
| def _check_zero_point(zp: int, dtype_str: str): | |||
| qmin = _metadata_dict[dtype_str].qmin | |||
| qmax = _metadata_dict[dtype_str].qmax | |||
| if zp < qmin or zp > qmax: | |||
| raise ValueError( | |||
| "zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) | |||
| ) | |||
| def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): | |||
| r""" | |||
| Get quantized dtype with metadata attribute according to _metadata_dict. | |||
| Note that unsigned dtype must have ``zero_point`` and signed dtype must | |||
| not have ``zero_point``, to be consistent with tensors generated by calling | |||
| compiled function from `CompGraph.compile(inputs, outspec)`. | |||
| :param dtype_str: a string indicating which dtype to return | |||
| :param scale: a number for scale to store in dtype's metadata | |||
| :param zp: a number for zero_point to store in dtype's metadata | |||
| """ | |||
| metadata = _metadata_dict[dtype_str] | |||
| np_dtype_str = metadata.np_dtype_str | |||
| is_unsigned = metadata.is_unsigned | |||
| if is_unsigned: | |||
| if zp is None or int(zp) != zp: | |||
| raise ValueError("zero_point should be an integer") | |||
| zp = int(zp) | |||
| _check_zero_point(zp, dtype_str) | |||
| return np.dtype( | |||
| np_dtype_str, | |||
| metadata={ | |||
| "mgb_dtype": { | |||
| "name": metadata.name, | |||
| "scale": float(scale), | |||
| "zero_point": zp, | |||
| } | |||
| }, | |||
| ) | |||
| else: | |||
| return np.dtype( | |||
| np_dtype_str, | |||
| metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}}, | |||
| ) | |||
| def quint8(scale, zero_point): | |||
| """ | |||
| Construct a quantized unsigned int8 data type with ``scale`` (float) and | |||
| ``zero_point`` (uint8). The real value represented by a quint8 data type is | |||
| float_val = scale * (uint8_val - zero_point) | |||
| """ | |||
| return get_quantized_dtype("quint8", scale, zero_point) | |||
| def qint8(scale): | |||
| """ | |||
| Construct a quantized int8 data type with ``scale`` (float). The real value | |||
| represented by a qint8 data type is float_val = scale * int8_val | |||
| """ | |||
| return get_quantized_dtype("qint8", scale, None) | |||
| def qint32(scale): | |||
| """ | |||
| Construct a quantized int32 data type with ``scale`` (float). The real value | |||
| represented by a qint32 data type is float_val = scale * int32_val | |||
| """ | |||
| return get_quantized_dtype("qint32", scale, None) | |||
| def quint4(scale, zero_point): | |||
| """ | |||
| Construct a quantized unsigned int4 data type with ``scale`` (float) and | |||
| ``zero_point`` (uint8). The real value represented by a quint4 data type is | |||
| float_val = scale * (uint4_val - zero_point) | |||
| """ | |||
| return get_quantized_dtype("quint4", scale, zero_point) | |||
| def qint4(scale): | |||
| """ | |||
| Construct a quantized int4 data type with ``scale`` (float). The real value | |||
| represented by a qint4 data type is float_val = scale * int4_val | |||
| """ | |||
| return get_quantized_dtype("qint4", scale, None) | |||
| def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): | |||
| metadata = _metadata_dict[dtype_str] | |||
| if not isinstance(arr, np.ndarray): | |||
| raise ValueError("arr parameter should be an instance of np.ndarray") | |||
| # validate the dtype before touching its metadata, which may be None | |||
| if not is_quantize(dtype) or dtype.metadata["mgb_dtype"]["name"] != metadata.name: | |||
| raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) | |||
| arr_metadata = dtype.metadata["mgb_dtype"] | |||
| is_unsigned = metadata.is_unsigned | |||
| if is_unsigned: | |||
| scale, zp = ( | |||
| arr_metadata["scale"], | |||
| arr_metadata["zero_point"], | |||
| ) | |||
| return ( | |||
| (np.round(arr / scale) + zp) | |||
| .clip(metadata.qmin, metadata.qmax) | |||
| .astype(dtype) | |||
| ) | |||
| else: | |||
| # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` | |||
| scale = arr_metadata["scale"] | |||
| return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype) | |||
| def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str): | |||
| metadata = _metadata_dict[dtype_str] | |||
| if not isinstance(arr, np.ndarray): | |||
| raise ValueError("arr parameter should be an instance of np.ndarray") | |||
| # validate the dtype before touching its metadata, which may be None | |||
| if not is_quantize(arr.dtype) or arr.dtype.metadata["mgb_dtype"]["name"] != metadata.name: | |||
| raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) | |||
| arr_metadata = arr.dtype.metadata["mgb_dtype"] | |||
| is_unsigned = metadata.is_unsigned | |||
| if is_unsigned: | |||
| scale, zp = ( | |||
| arr_metadata["scale"], | |||
| arr_metadata["zero_point"], | |||
| ) | |||
| return (arr.astype(np.float32) - zp) * scale | |||
| else: | |||
| # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` | |||
| scale = arr_metadata["scale"] | |||
| return (arr.astype(np.float32)) * scale | |||
| def convert_to_quint8(arr: np.ndarray, q: np.dtype): | |||
| """ | |||
| Quantize a float NumPy ndarray into a quint8 one with specified params. | |||
| :param arr: Input ndarray. | |||
| :param q: Target data type, should be a quint8. | |||
| """ | |||
| return _convert_to_quantized_dtype(arr, q, "quint8") | |||
| def convert_from_quint8(arr: np.ndarray): | |||
| """ | |||
| Dequantize a quint8 NumPy ndarray into a float one. | |||
| :param arr: Input ndarray. | |||
| """ | |||
| return _convert_from_quantized_dtype(arr, "quint8") | |||
| def convert_to_qint8(arr: np.ndarray, q: np.dtype): | |||
| """ | |||
| Quantize a float NumPy ndarray into a qint8 one with specified params. | |||
| :param arr: Input ndarray. | |||
| :param q: Target data type, should be a qint8. | |||
| """ | |||
| return _convert_to_quantized_dtype(arr, q, "qint8") | |||
| def convert_from_qint8(arr: np.ndarray): | |||
| """ | |||
| Dequantize a qint8 NumPy ndarray into a float one. | |||
| :param arr: Input ndarray. | |||
| """ | |||
| return _convert_from_quantized_dtype(arr, "qint8") | |||
| def convert_to_qint32(arr: np.ndarray, q: np.dtype): | |||
| """ | |||
| Quantize a float NumPy ndarray into a qint32 one with specified params. | |||
| :param arr: Input ndarray. | |||
| :param q: Target data type, should be a qint32. | |||
| """ | |||
| return _convert_to_quantized_dtype(arr, q, "qint32") | |||
| def convert_from_qint32(arr): | |||
| """ | |||
| Dequantize a qint32 NumPy ndarray into a float one. | |||
| :param arr: Input ndarray. | |||
| """ | |||
| return _convert_from_quantized_dtype(arr, "qint32") | |||
| def convert_to_quint4(arr: np.ndarray, q: np.dtype): | |||
| """ | |||
| Quantize a float NumPy ndarray into a quint4 one with specified params. | |||
| :param arr: Input ndarray. | |||
| :param q: Target data type, should be a quint4. | |||
| """ | |||
| return _convert_to_quantized_dtype(arr, q, "quint4") | |||
| def convert_from_quint4(arr: np.ndarray): | |||
| """ | |||
| Dequantize a quint4 NumPy ndarray into a float one. | |||
| :param arr: Input ndarray. | |||
| """ | |||
| return _convert_from_quantized_dtype(arr, "quint4") | |||
| def convert_to_qint4(arr: np.ndarray, q: np.dtype): | |||
| """ | |||
| Quantize a float NumPy ndarray into a qint4 one with specified params. | |||
| :param arr: Input ndarray. | |||
| :param q: Target data type, should be a qint4. | |||
| """ | |||
| return _convert_to_quantized_dtype(arr, q, "qint4") | |||
| def convert_from_qint4(arr: np.ndarray): | |||
| """ | |||
| Dequantize a qint4 NumPy ndarray into a float one. | |||
| :param arr: Input ndarray. | |||
| """ | |||
| return _convert_from_quantized_dtype(arr, "qint4") | |||
| @@ -1,947 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # Copyright [2001] [CPython] | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # --------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # ---------------------------------------------------------------------- | |||
| import sys | |||
| from functools import reduce | |||
| from operator import or_ as _or_ | |||
| from types import DynamicClassAttribute, MappingProxyType | |||
| # try _collections first to reduce startup cost | |||
| try: | |||
| from _collections import OrderedDict | |||
| except ImportError: | |||
| from collections import OrderedDict | |||
| __all__ = [ | |||
| "EnumMeta", | |||
| "Enum", | |||
| "IntEnum", | |||
| "Flag", | |||
| "IntFlag", | |||
| "auto", | |||
| "unique", | |||
| ] | |||
| def _is_descriptor(obj): | |||
| """Returns True if obj is a descriptor, False otherwise.""" | |||
| return ( | |||
| hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") | |||
| ) | |||
| def _is_dunder(name): | |||
| """Returns True if a __dunder__ name, False otherwise.""" | |||
| return ( | |||
| name[:2] == name[-2:] == "__" | |||
| and name[2:3] != "_" | |||
| and name[-3:-2] != "_" | |||
| and len(name) > 4 | |||
| ) | |||
| def _is_sunder(name): | |||
| """Returns True if a _sunder_ name, False otherwise.""" | |||
| return ( | |||
| name[0] == name[-1] == "_" | |||
| and name[1:2] != "_" | |||
| and name[-2:-1] != "_" | |||
| and len(name) > 2 | |||
| ) | |||
| def _make_class_unpicklable(cls): | |||
| """Make the given class un-picklable.""" | |||
| def _break_on_call_reduce(self, proto): | |||
| raise TypeError("%r cannot be pickled" % self) | |||
| cls.__reduce_ex__ = _break_on_call_reduce | |||
| cls.__module__ = "<unknown>" | |||
| _auto_null = object() | |||
| class auto: | |||
| """ | |||
| Instances are replaced with an appropriate value in Enum class suites. | |||
| """ | |||
| value = _auto_null | |||
| class _EnumDict(dict): | |||
| """Track enum member order and ensure member names are not reused. | |||
| EnumMeta will use the names found in self._member_names as the | |||
| enumeration member names. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self._member_names = [] | |||
| self._last_values = [] | |||
| def __setitem__(self, key, value): | |||
| """Changes anything not dundered or not a descriptor. | |||
| If an enum member name is used twice, an error is raised; duplicate | |||
| values are not checked for. | |||
| Single underscore (sunder) names are reserved. | |||
| """ | |||
| if _is_sunder(key): | |||
| if key not in ( | |||
| "_order_", | |||
| "_create_pseudo_member_", | |||
| "_generate_next_value_", | |||
| "_missing_", | |||
| ): | |||
| raise ValueError("_names_ are reserved for future Enum use") | |||
| if key == "_generate_next_value_": | |||
| setattr(self, "_generate_next_value", value) | |||
| elif _is_dunder(key): | |||
| if key == "__order__": | |||
| key = "_order_" | |||
| elif key in self._member_names: | |||
| # descriptor overwriting an enum? | |||
| raise TypeError("Attempted to reuse key: %r" % key) | |||
| elif not _is_descriptor(value): | |||
| if key in self: | |||
| # enum overwriting a descriptor? | |||
| raise TypeError("%r already defined as: %r" % (key, self[key])) | |||
| if isinstance(value, auto): | |||
| if value.value == _auto_null: | |||
| value.value = self._generate_next_value( | |||
| key, 1, len(self._member_names), self._last_values[:] | |||
| ) | |||
| value = value.value | |||
| self._member_names.append(key) | |||
| self._last_values.append(value) | |||
| super().__setitem__(key, value) | |||
| # Dummy value for Enum as EnumMeta explicitly checks for it, but of course | |||
| # until EnumMeta finishes running the first time the Enum class doesn't exist. | |||
| # This is also why there are checks in EnumMeta like `if Enum is not None` | |||
| Enum = None | |||
| class EnumMeta(type): | |||
| """Metaclass for Enum""" | |||
| @classmethod | |||
| def __prepare__(metacls, cls, bases): | |||
| # create the namespace dict | |||
| enum_dict = _EnumDict() | |||
| # inherit previous flags and _generate_next_value_ function | |||
| member_type, first_enum = metacls._get_mixins_(bases) | |||
| if first_enum is not None: | |||
| enum_dict["_generate_next_value_"] = getattr( | |||
| first_enum, "_generate_next_value_", None | |||
| ) | |||
| return enum_dict | |||
| def __new__(metacls, cls, bases, classdict): | |||
| # an Enum class is final once enumeration items have been defined; it | |||
| # cannot be mixed with other types (int, float, etc.) if it has an | |||
| # inherited __new__ unless a new __new__ is defined (or the resulting | |||
| # class will fail). | |||
| member_type, first_enum = metacls._get_mixins_(bases) | |||
| __new__, save_new, use_args = metacls._find_new_( | |||
| classdict, member_type, first_enum | |||
| ) | |||
| # save enum items into separate mapping so they don't get baked into | |||
| # the new class | |||
| enum_members = {k: classdict[k] for k in classdict._member_names} | |||
| for name in classdict._member_names: | |||
| del classdict[name] | |||
| # adjust the sunders | |||
| _order_ = classdict.pop("_order_", None) | |||
| # check for illegal enum names (any others?) | |||
| invalid_names = set(enum_members) & { | |||
| "mro", | |||
| } | |||
| if invalid_names: | |||
| raise ValueError( | |||
| "Invalid enum member name: {0}".format(",".join(invalid_names)) | |||
| ) | |||
| # create a default docstring if one has not been provided | |||
| if "__doc__" not in classdict: | |||
| classdict["__doc__"] = "An enumeration." | |||
| # create our new Enum type | |||
| enum_class = super().__new__(metacls, cls, bases, classdict) | |||
| enum_class._member_names_ = [] # names in definition order | |||
| enum_class._member_map_ = OrderedDict() # name->value map | |||
| enum_class._member_type_ = member_type | |||
| # save attributes from super classes so we know if we can take | |||
| # the shortcut of storing members in the class dict | |||
| base_attributes = {a for b in enum_class.mro() for a in b.__dict__} | |||
| # Reverse value->name map for hashable values. | |||
| enum_class._value2member_map_ = {} | |||
| # If a custom type is mixed into the Enum, and it does not know how | |||
| # to pickle itself, pickle.dumps will succeed but pickle.loads will | |||
| # fail. Rather than have the error show up later and possibly far | |||
| # from the source, sabotage the pickle protocol for this class so | |||
| # that pickle.dumps also fails. | |||
| # | |||
| # However, if the new class implements its own __reduce_ex__, do not | |||
| # sabotage -- it's on them to make sure it works correctly. We use | |||
| # __reduce_ex__ instead of any of the others as it is preferred by | |||
| # pickle over __reduce__, and it handles all pickle protocols. | |||
| if "__reduce_ex__" not in classdict: | |||
| if member_type is not object: | |||
| methods = ( | |||
| "__getnewargs_ex__", | |||
| "__getnewargs__", | |||
| "__reduce_ex__", | |||
| "__reduce__", | |||
| ) | |||
| if not any(m in member_type.__dict__ for m in methods): | |||
| _make_class_unpicklable(enum_class) | |||
| # instantiate them, checking for duplicates as we go | |||
| # we instantiate first instead of checking for duplicates first in case | |||
| # a custom __new__ is doing something funky with the values -- such as | |||
| # auto-numbering ;) | |||
| for member_name in classdict._member_names: | |||
| value = enum_members[member_name] | |||
| if not isinstance(value, tuple): | |||
| args = (value,) | |||
| else: | |||
| args = value | |||
| if member_type is tuple: # special case for tuple enums | |||
| args = (args,) # wrap it one more time | |||
| if not use_args: | |||
| enum_member = __new__(enum_class) | |||
| if not hasattr(enum_member, "_value_"): | |||
| enum_member._value_ = value | |||
| else: | |||
| enum_member = __new__(enum_class, *args) | |||
| if not hasattr(enum_member, "_value_"): | |||
| if member_type is object: | |||
| enum_member._value_ = value | |||
| else: | |||
| enum_member._value_ = member_type(*args) | |||
| value = enum_member._value_ | |||
| enum_member._name_ = member_name | |||
| enum_member.__objclass__ = enum_class | |||
| enum_member.__init__(*args) | |||
| # If another member with the same value was already defined, the | |||
| # new member becomes an alias to the existing one. | |||
| for name, canonical_member in enum_class._member_map_.items(): | |||
| if canonical_member._value_ == enum_member._value_: | |||
| enum_member = canonical_member | |||
| break | |||
| else: | |||
| # Aliases don't appear in member names (only in __members__). | |||
| enum_class._member_names_.append(member_name) | |||
| # performance boost for any member that would not shadow | |||
| # a DynamicClassAttribute | |||
| if member_name not in base_attributes: | |||
| setattr(enum_class, member_name, enum_member) | |||
| # now add to _member_map_ | |||
| enum_class._member_map_[member_name] = enum_member | |||
| try: | |||
| # This may fail if value is not hashable. We can't add the value | |||
| # to the map, and by-value lookups for this value will be | |||
| # linear. | |||
| enum_class._value2member_map_[value] = enum_member | |||
| except TypeError: | |||
| pass | |||
| # double check that repr and friends are not the mixin's or various | |||
| # things break (such as pickle) | |||
| for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"): | |||
| class_method = getattr(enum_class, name) | |||
| obj_method = getattr(member_type, name, None) | |||
| enum_method = getattr(first_enum, name, None) | |||
| if obj_method is not None and obj_method is class_method: | |||
| setattr(enum_class, name, enum_method) | |||
| # replace any other __new__ with our own (as long as Enum is not None, | |||
| # anyway) -- again, this is to support pickle | |||
| if Enum is not None: | |||
| # if the user defined their own __new__, save it before it gets | |||
| # clobbered in case they subclass later | |||
| if save_new: | |||
| enum_class.__new_member__ = __new__ | |||
| enum_class.__new__ = Enum.__new__ | |||
| # py3 support for definition order (helps keep py2/py3 code in sync) | |||
| if _order_ is not None: | |||
| if isinstance(_order_, str): | |||
| _order_ = _order_.replace(",", " ").split() | |||
| if _order_ != enum_class._member_names_: | |||
| raise TypeError("member order does not match _order_") | |||
| return enum_class | |||
| def __bool__(self): | |||
| """ | |||
| classes/types should always be True. | |||
| """ | |||
| return True | |||
| def __call__( | |||
| cls, value, names=None, *, module=None, qualname=None, type=None, start=1 | |||
| ): | |||
| """Either returns an existing member, or creates a new enum class. | |||
| This method is used both when an enum class is given a value to match | |||
| to an enumeration member (i.e. Color(3)) and for the functional API | |||
| (i.e. Color = Enum('Color', names='RED GREEN BLUE')). | |||
| When used for the functional API: | |||
| `value` will be the name of the new class. | |||
| `names` should be either a string of white-space/comma delimited names | |||
| (values will start at `start`), or an iterator/mapping of name, value pairs. | |||
| `module` should be set to the module this class is being created in; | |||
| if it is not set, an attempt to find that module will be made, but if | |||
| it fails the class will not be picklable. | |||
| `qualname` should be set to the actual location this class can be found | |||
| at in its module; by default it is set to the global scope. If this is | |||
| not correct, unpickling will fail in some circumstances. | |||
| `type`, if set, will be mixed in as the first base class. | |||
| """ | |||
| if names is None: # simple value lookup | |||
| return cls.__new__(cls, value) | |||
| # otherwise, functional API: we're creating a new Enum type | |||
| return cls._create_( | |||
| value, names, module=module, qualname=qualname, type=type, start=start | |||
| ) | |||
| def __contains__(cls, member): | |||
| return isinstance(member, cls) and member._name_ in cls._member_map_ | |||
| def __delattr__(cls, attr): | |||
| # nicer error message when someone tries to delete an attribute | |||
| # (see issue19025). | |||
| if attr in cls._member_map_: | |||
| raise AttributeError("%s: cannot delete Enum member." % cls.__name__) | |||
| super().__delattr__(attr) | |||
| def __dir__(self): | |||
| return [ | |||
| "__class__", | |||
| "__doc__", | |||
| "__members__", | |||
| "__module__", | |||
| ] + self._member_names_ | |||
| def __getattr__(cls, name): | |||
| """Return the enum member matching `name` | |||
| We use __getattr__ instead of descriptors or inserting into the enum | |||
| class' __dict__ in order to support `name` and `value` being both | |||
| properties for enum members (which live in the class' __dict__) and | |||
| enum members themselves. | |||
| """ | |||
| if _is_dunder(name): | |||
| raise AttributeError(name) | |||
| try: | |||
| return cls._member_map_[name] | |||
| except KeyError: | |||
| raise AttributeError(name) from None | |||
| def __getitem__(cls, name): | |||
| return cls._member_map_[name] | |||
| def __iter__(cls): | |||
| return (cls._member_map_[name] for name in cls._member_names_) | |||
| def __len__(cls): | |||
| return len(cls._member_names_) | |||
| @property | |||
| def __members__(cls): | |||
| """Returns a mapping of member name->value. | |||
| This mapping lists all enum members, including aliases. Note that this | |||
| is a read-only view of the internal mapping. | |||
| """ | |||
| return MappingProxyType(cls._member_map_) | |||
| def __repr__(cls): | |||
| return "<enum %r>" % cls.__name__ | |||
| def __reversed__(cls): | |||
| return (cls._member_map_[name] for name in reversed(cls._member_names_)) | |||
| def __setattr__(cls, name, value): | |||
| """Block attempts to reassign Enum members. | |||
| A simple assignment to the class namespace only changes one of the | |||
| several possible ways to get an Enum member from the Enum class, | |||
| resulting in an inconsistent Enumeration. | |||
| """ | |||
| member_map = cls.__dict__.get("_member_map_", {}) | |||
| if name in member_map: | |||
| raise AttributeError("Cannot reassign members.") | |||
| super().__setattr__(name, value) | |||
| def _create_( | |||
| cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1 | |||
| ): | |||
| """Convenience method to create a new Enum class. | |||
| `names` can be: | |||
| * A string containing member names, separated either with spaces or | |||
| commas. Values are incremented by 1 from `start`. | |||
| * An iterable of member names. Values are incremented by 1 from `start`. | |||
| * An iterable of (member name, value) pairs. | |||
| * A mapping of member name -> value pairs. | |||
| """ | |||
| metacls = cls.__class__ | |||
| bases = (cls,) if type is None else (type, cls) | |||
| _, first_enum = cls._get_mixins_(bases) | |||
| classdict = metacls.__prepare__(class_name, bases) | |||
| # special processing needed for names? | |||
| if isinstance(names, str): | |||
| names = names.replace(",", " ").split() | |||
| if isinstance(names, (tuple, list)) and names and isinstance(names[0], str): | |||
| original_names, names = names, [] | |||
| last_values = [] | |||
| for count, name in enumerate(original_names): | |||
| value = first_enum._generate_next_value_( | |||
| name, start, count, last_values[:] | |||
| ) | |||
| last_values.append(value) | |||
| names.append((name, value)) | |||
| # Here, names is either an iterable of (name, value) or a mapping. | |||
| for item in names: | |||
| if isinstance(item, str): | |||
| member_name, member_value = item, names[item] | |||
| else: | |||
| member_name, member_value = item | |||
| classdict[member_name] = member_value | |||
| enum_class = metacls.__new__(metacls, class_name, bases, classdict) | |||
| # TODO: replace the frame hack if a blessed way to know the calling | |||
| # module is ever developed | |||
| if module is None: | |||
| try: | |||
| module = sys._getframe(2).f_globals["__name__"] | |||
| except (AttributeError, ValueError): | |||
| pass | |||
| if module is None: | |||
| _make_class_unpicklable(enum_class) | |||
| else: | |||
| enum_class.__module__ = module | |||
| if qualname is not None: | |||
| enum_class.__qualname__ = qualname | |||
| return enum_class | |||
| @staticmethod | |||
| def _get_mixins_(bases): | |||
| """Returns the type for creating enum members, and the first inherited | |||
| enum class. | |||
| bases: the tuple of bases that was given to __new__ | |||
| """ | |||
| if not bases: | |||
| return object, Enum | |||
| # double check that we are not subclassing a class with existing | |||
| # enumeration members; while we're at it, see if any other data | |||
| # type has been mixed in so we can use the correct __new__ | |||
| member_type = first_enum = None | |||
| for base in bases: | |||
| if base is not Enum and issubclass(base, Enum) and base._member_names_: | |||
| raise TypeError("Cannot extend enumerations") | |||
| # base is now the last base in bases | |||
| if not issubclass(base, Enum): | |||
| raise TypeError( | |||
| "new enumerations must be created as " | |||
| "`ClassName([mixin_type,] enum_type)`" | |||
| ) | |||
| # get correct mix-in type (either mix-in type of Enum subclass, or | |||
| # first base if last base is Enum) | |||
| if not issubclass(bases[0], Enum): | |||
| member_type = bases[0] # first data type | |||
| first_enum = bases[-1] # enum type | |||
| else: | |||
| for base in bases[0].__mro__: | |||
| # most common: (IntEnum, int, Enum, object) | |||
| # possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>, | |||
| # <class 'int'>, <Enum 'Enum'>, | |||
| # <class 'object'>) | |||
| if issubclass(base, Enum): | |||
| if first_enum is None: | |||
| first_enum = base | |||
| else: | |||
| if member_type is None: | |||
| member_type = base | |||
| return member_type, first_enum | |||
| @staticmethod | |||
| def _find_new_(classdict, member_type, first_enum): | |||
| """Returns the __new__ to be used for creating the enum members. | |||
| classdict: the class dictionary given to __new__ | |||
| member_type: the data type whose __new__ will be used by default | |||
| first_enum: enumeration to check for an overriding __new__ | |||
| """ | |||
| # now find the correct __new__, checking to see if one was defined | |||
| # by the user; also check earlier enum classes in case a __new__ was | |||
| # saved as __new_member__ | |||
| __new__ = classdict.get("__new__", None) | |||
| # should __new__ be saved as __new_member__ later? | |||
| save_new = __new__ is not None | |||
| if __new__ is None: | |||
| # check all possibles for __new_member__ before falling back to | |||
| # __new__ | |||
| for method in ("__new_member__", "__new__"): | |||
| for possible in (member_type, first_enum): | |||
| target = getattr(possible, method, None) | |||
| if target not in { | |||
| None, | |||
| None.__new__, | |||
| object.__new__, | |||
| Enum.__new__, | |||
| }: | |||
| __new__ = target | |||
| break | |||
| if __new__ is not None: | |||
| break | |||
| else: | |||
| __new__ = object.__new__ | |||
| # if a non-object.__new__ is used then whatever value/tuple was | |||
| # assigned to the enum member name will be passed to __new__ and to the | |||
| # new enum member's __init__ | |||
| if __new__ is object.__new__: | |||
| use_args = False | |||
| else: | |||
| use_args = True | |||
| return __new__, save_new, use_args | |||
| class Enum(metaclass=EnumMeta): | |||
| """Generic enumeration. | |||
| Derive from this class to define new enumerations. | |||
| """ | |||
| def __new__(cls, value): | |||
| # all enum instances are actually created during class construction | |||
| # without calling this method; this method is called by the metaclass' | |||
| # __call__ (i.e. Color(3) ), and by pickle | |||
| if type(value) is cls: | |||
| # For lookups like Color(Color.RED) | |||
| return value | |||
| # by-value search for a matching enum member | |||
| # see if it's in the reverse mapping (for hashable values) | |||
| try: | |||
| if value in cls._value2member_map_: | |||
| return cls._value2member_map_[value] | |||
| except TypeError: | |||
| # not there, now do long search -- O(n) behavior | |||
| for member in cls._member_map_.values(): | |||
| if member._value_ == value: | |||
| return member | |||
| # still not found -- try _missing_ hook | |||
| return cls._missing_(value) | |||
| def _generate_next_value_(name, start, count, last_values): | |||
| for last_value in reversed(last_values): | |||
| try: | |||
| return last_value + 1 | |||
| except TypeError: | |||
| pass | |||
| else: | |||
| return start | |||
| @classmethod | |||
| def _missing_(cls, value): | |||
| raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
| def __repr__(self): | |||
| return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_) | |||
| def __str__(self): | |||
| return "%s.%s" % (self.__class__.__name__, self._name_) | |||
| def __dir__(self): | |||
| added_behavior = [ | |||
| m | |||
| for cls in self.__class__.mro() | |||
| for m in cls.__dict__ | |||
| if m[0] != "_" and m not in self._member_map_ | |||
| ] | |||
| return ["__class__", "__doc__", "__module__"] + added_behavior | |||
| def __format__(self, format_spec): | |||
| # mixed-in Enums should use the mixed-in type's __format__, otherwise | |||
| # we can get strange results with the Enum name showing up instead of | |||
| # the value | |||
| # pure Enum branch | |||
| if self._member_type_ is object: | |||
| cls = str | |||
| val = str(self) | |||
| # mix-in branch | |||
| else: | |||
| cls = self._member_type_ | |||
| val = self._value_ | |||
| return cls.__format__(val, format_spec) | |||
| def __hash__(self): | |||
| return hash(self._name_) | |||
| def __reduce_ex__(self, proto): | |||
| return self.__class__, (self._value_,) | |||
| # DynamicClassAttribute is used to provide access to the `name` and | |||
| # `value` properties of enum members while keeping some measure of | |||
| # protection from modification, while still allowing for an enumeration | |||
| # to have members named `name` and `value`. This works because enumeration | |||
| # members are not set directly on the enum class -- __getattr__ is | |||
| # used to look them up. | |||
| @DynamicClassAttribute | |||
| def name(self): | |||
| """The name of the Enum member.""" | |||
| return self._name_ | |||
| @DynamicClassAttribute | |||
| def value(self): | |||
| """The value of the Enum member.""" | |||
| return self._value_ | |||
| @classmethod | |||
| def _convert(cls, name, module, filter, source=None): | |||
| """ | |||
| Create a new Enum subclass that replaces a collection of global constants | |||
| """ | |||
| # convert all constants from source (or module) that pass filter() to | |||
| # a new Enum called name, and export the enum and its members back to | |||
| # module; | |||
| # also, replace the __reduce_ex__ method so unpickling works in | |||
| # previous Python versions | |||
| module_globals = vars(sys.modules[module]) | |||
| if source: | |||
| source = vars(source) | |||
| else: | |||
| source = module_globals | |||
| # We use an OrderedDict of sorted source keys so that the | |||
| # _value2member_map is populated in the same order every time | |||
| # for a consistent reverse mapping of number to name when there | |||
| # are multiple names for the same number rather than varying | |||
| # between runs due to hash randomization of the module dictionary. | |||
| members = [(name, source[name]) for name in source.keys() if filter(name)] | |||
| try: | |||
| # sort by value | |||
| members.sort(key=lambda t: (t[1], t[0])) | |||
| except TypeError: | |||
| # unless some values aren't comparable, in which case sort by name | |||
| members.sort(key=lambda t: t[0]) | |||
| cls = cls(name, members, module=module) | |||
| cls.__reduce_ex__ = _reduce_ex_by_name | |||
| module_globals.update(cls.__members__) | |||
| module_globals[name] = cls | |||
| return cls | |||
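| # A usage sketch for ``_convert`` (the module ``mymod`` and its ``ERR_*`` | |||
| # integer constants are hypothetical): matching globals are replaced by | |||
| # members of a new enum that is exported back into the module: | |||
| # | |||
| #     >>> import mymod | |||
| #     >>> IntEnum._convert( | |||
| #     ...     "ErrCode", "mymod", lambda name: name.startswith("ERR_")) | |||
| #     >>> mymod.ERR_NOTFOUND | |||
| #     <ErrCode.ERR_NOTFOUND: 2> | |||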
| class IntEnum(int, Enum): | |||
| """Enum where members are also (and must be) ints""" | |||
| def _reduce_ex_by_name(self, proto): | |||
| return self.name | |||
| class Flag(Enum): | |||
| """Support for flags""" | |||
| def _generate_next_value_(name, start, count, last_values): | |||
| """ | |||
| Generate the next value when not given. | |||
| name: the name of the member | |||
| start: the initial start value or None | |||
| count: the number of existing members | |||
| last_values: the list of values assigned so far | |||
| """ | |||
| if not count: | |||
| return start if start is not None else 1 | |||
| for last_value in reversed(last_values): | |||
| try: | |||
| high_bit = _high_bit(last_value) | |||
| break | |||
| except Exception: | |||
| raise TypeError("Invalid Flag value: %r" % last_value) from None | |||
| return 2 ** (high_bit + 1) | |||
| @classmethod | |||
| def _missing_(cls, value): | |||
| original_value = value | |||
| if value < 0: | |||
| value = ~value | |||
| possible_member = cls._create_pseudo_member_(value) | |||
| if original_value < 0: | |||
| possible_member = ~possible_member | |||
| return possible_member | |||
| @classmethod | |||
| def _create_pseudo_member_(cls, value): | |||
| """ | |||
| Create a composite member iff value contains only members. | |||
| """ | |||
| pseudo_member = cls._value2member_map_.get(value, None) | |||
| if pseudo_member is None: | |||
| # verify all bits are accounted for | |||
| _, extra_flags = _decompose(cls, value) | |||
| if extra_flags: | |||
| raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
| # construct a singleton enum pseudo-member | |||
| pseudo_member = object.__new__(cls) | |||
| pseudo_member._name_ = None | |||
| pseudo_member._value_ = value | |||
| # use setdefault in case another thread already created a composite | |||
| # with this value | |||
| pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) | |||
| return pseudo_member | |||
| def __contains__(self, other): | |||
| if not isinstance(other, self.__class__): | |||
| return NotImplemented | |||
| return other._value_ & self._value_ == other._value_ | |||
| def __repr__(self): | |||
| cls = self.__class__ | |||
| if self._name_ is not None: | |||
| return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_) | |||
| members, uncovered = _decompose(cls, self._value_) | |||
| return "<%s.%s: %r>" % ( | |||
| cls.__name__, | |||
| "|".join([str(m._name_ or m._value_) for m in members]), | |||
| self._value_, | |||
| ) | |||
| def __str__(self): | |||
| cls = self.__class__ | |||
| if self._name_ is not None: | |||
| return "%s.%s" % (cls.__name__, self._name_) | |||
| members, uncovered = _decompose(cls, self._value_) | |||
| if len(members) == 1 and members[0]._name_ is None: | |||
| return "%s.%r" % (cls.__name__, members[0]._value_) | |||
| else: | |||
| return "%s.%s" % ( | |||
| cls.__name__, | |||
| "|".join([str(m._name_ or m._value_) for m in members]), | |||
| ) | |||
| def __bool__(self): | |||
| return bool(self._value_) | |||
| def __or__(self, other): | |||
| if not isinstance(other, self.__class__): | |||
| return NotImplemented | |||
| return self.__class__(self._value_ | other._value_) | |||
| def __and__(self, other): | |||
| if not isinstance(other, self.__class__): | |||
| return NotImplemented | |||
| return self.__class__(self._value_ & other._value_) | |||
| def __xor__(self, other): | |||
| if not isinstance(other, self.__class__): | |||
| return NotImplemented | |||
| return self.__class__(self._value_ ^ other._value_) | |||
| def __invert__(self): | |||
| members, uncovered = _decompose(self.__class__, self._value_) | |||
| inverted_members = [ | |||
| m | |||
| for m in self.__class__ | |||
| if m not in members and not m._value_ & self._value_ | |||
| ] | |||
| inverted = reduce(_or_, inverted_members, self.__class__(0)) | |||
| return self.__class__(inverted) | |||
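| # A minimal usage sketch for ``Flag`` with explicit power-of-two values; | |||
| # composite members are created on demand and cached: | |||
| # | |||
| #     >>> class Perm(Flag): | |||
| #     ...     R = 4 | |||
| #     ...     W = 2 | |||
| #     ...     X = 1 | |||
| #     >>> Perm.R | Perm.W | |||
| #     <Perm.R|W: 6> | |||
| #     >>> bool(Perm.R & Perm.X) | |||
| #     False | |||
| #     >>> ~Perm.R | |||
| #     <Perm.W|X: 3> | |||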
| class IntFlag(int, Flag): | |||
| """Support for integer-based Flags""" | |||
| @classmethod | |||
| def _missing_(cls, value): | |||
| if not isinstance(value, int): | |||
| raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
| new_member = cls._create_pseudo_member_(value) | |||
| return new_member | |||
| @classmethod | |||
| def _create_pseudo_member_(cls, value): | |||
| pseudo_member = cls._value2member_map_.get(value, None) | |||
| if pseudo_member is None: | |||
| need_to_create = [value] | |||
| # get unaccounted for bits | |||
| _, extra_flags = _decompose(cls, value) | |||
| while extra_flags: | |||
| bit = _high_bit(extra_flags) | |||
| flag_value = 2 ** bit | |||
| if ( | |||
| flag_value not in cls._value2member_map_ | |||
| and flag_value not in need_to_create | |||
| ): | |||
| need_to_create.append(flag_value) | |||
| if extra_flags == -flag_value: | |||
| extra_flags = 0 | |||
| else: | |||
| extra_flags ^= flag_value | |||
| for value in reversed(need_to_create): | |||
| # construct singleton pseudo-members | |||
| pseudo_member = int.__new__(cls, value) | |||
| pseudo_member._name_ = None | |||
| pseudo_member._value_ = value | |||
| # use setdefault in case another thread already created a composite | |||
| # with this value | |||
| pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) | |||
| return pseudo_member | |||
| def __or__(self, other): | |||
| if not isinstance(other, (self.__class__, int)): | |||
| return NotImplemented | |||
| result = self.__class__(self._value_ | self.__class__(other)._value_) | |||
| return result | |||
| def __and__(self, other): | |||
| if not isinstance(other, (self.__class__, int)): | |||
| return NotImplemented | |||
| return self.__class__(self._value_ & self.__class__(other)._value_) | |||
| def __xor__(self, other): | |||
| if not isinstance(other, (self.__class__, int)): | |||
| return NotImplemented | |||
| return self.__class__(self._value_ ^ self.__class__(other)._value_) | |||
| __ror__ = __or__ | |||
| __rand__ = __and__ | |||
| __rxor__ = __xor__ | |||
| def __invert__(self): | |||
| result = self.__class__(~self._value_) | |||
| return result | |||
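| # ``IntFlag`` members additionally interoperate with plain ints; a sketch: | |||
| # | |||
| #     >>> class Mode(IntFlag): | |||
| #     ...     READ = 1 | |||
| #     ...     WRITE = 2 | |||
| #     >>> Mode.READ | 2 | |||
| #     <Mode.WRITE|READ: 3> | |||
| #     >>> (Mode.READ | Mode.WRITE) + 1      # behaves as an int | |||
| #     4 | |||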
| def _high_bit(value): | |||
| """returns index of highest bit, or -1 if value is zero or negative""" | |||
| return value.bit_length() - 1 | |||
| def unique(enumeration): | |||
| """Class decorator for enumerations ensuring unique member values.""" | |||
| duplicates = [] | |||
| for name, member in enumeration.__members__.items(): | |||
| if name != member.name: | |||
| duplicates.append((name, member.name)) | |||
| if duplicates: | |||
| alias_details = ", ".join( | |||
| ["%s -> %s" % (alias, name) for (alias, name) in duplicates] | |||
| ) | |||
| raise ValueError( | |||
| "duplicate values found in %r: %s" % (enumeration, alias_details) | |||
| ) | |||
| return enumeration | |||
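| # Usage sketch: ``@unique`` rejects aliases at class-creation time: | |||
| # | |||
| #     >>> @unique | |||
| #     ... class Mistake(Enum): | |||
| #     ...     ONE = 1 | |||
| #     ...     TWO = 2 | |||
| #     ...     THREE = 2          # alias of TWO | |||
| #     Traceback (most recent call last): | |||
| #       ... | |||
| #     ValueError: duplicate values found in <enum 'Mistake'>: THREE -> TWO | |||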
| def _decompose(flag, value): | |||
| """Extract all members from the value.""" | |||
| # _decompose is only called if the value is not named | |||
| not_covered = value | |||
| negative = value < 0 | |||
| # issue29167: wrap accesses to _value2member_map_ in a list to avoid race | |||
| # conditions between iterating over it and having more pseudo- | |||
| # members added to it | |||
| if negative: | |||
| # only check for named flags | |||
| flags_to_check = [ | |||
| (m, v) | |||
| for v, m in list(flag._value2member_map_.items()) | |||
| if m.name is not None | |||
| ] | |||
| else: | |||
| # check for named flags and powers-of-two flags | |||
| flags_to_check = [ | |||
| (m, v) | |||
| for v, m in list(flag._value2member_map_.items()) | |||
| if m.name is not None or _power_of_two(v) | |||
| ] | |||
| members = [] | |||
| for member, member_value in flags_to_check: | |||
| if member_value and member_value & value == member_value: | |||
| members.append(member) | |||
| not_covered &= ~member_value | |||
| if not members and value in flag._value2member_map_: | |||
| members.append(flag._value2member_map_[value]) | |||
| members.sort(key=lambda m: m._value_, reverse=True) | |||
| if len(members) > 1 and members[0].value == value: | |||
| # we have the breakdown, don't need the value member itself | |||
| members.pop(0) | |||
| return members, not_covered | |||
| def _power_of_two(value): | |||
| if value < 1: | |||
| return False | |||
| return value == 2 ** _high_bit(value) | |||
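| # A small sketch of ``_decompose`` on an unnamed composite value (reusing | |||
| # the ``Perm`` flag from the sketch above): | |||
| # | |||
| #     >>> _decompose(Perm, 6) | |||
| #     ([<Perm.R: 4>, <Perm.W: 2>], 0) | |||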
| @@ -1,58 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| """exception handling""" | |||
| from . import mgb as _mgb | |||
| class MegBrainError(Exception): | |||
| """exception class used by megbrain library""" | |||
| tracker = None | |||
| """the tracker setup by :func:`.set_exc_opr_tracker` when the related | |||
| operator is created""" | |||
| tracker_grad_orig = None | |||
| """if this operator is created by taking gradient, this var would be the | |||
| tracker of the operator that causes the grad.""" | |||
| def __init__(self, msg, tracker, tracker_grad_orig): | |||
| assert isinstance(msg, str) | |||
| super().__init__(msg, tracker, tracker_grad_orig) | |||
| self.tracker = tracker | |||
| self.tracker_grad_orig = tracker_grad_orig | |||
| @classmethod | |||
| def _format_tracker(cls, tracker): | |||
| return ("| " + i for i in str(tracker).split("\n")) | |||
| def __str__(self): | |||
| lines = [] | |||
| lines.extend(self.args[0].split("\n")) | |||
| if self.tracker is not None: | |||
| lines.append("Exception tracker:") | |||
| lines.extend(self._format_tracker(self.tracker)) | |||
| if self.tracker_grad_orig is not None: | |||
| lines.append( | |||
| "Exception caused by taking grad of another operator with tracker:" | |||
| ) | |||
| lines.extend(self._format_tracker(self.tracker_grad_orig)) | |||
| while not lines[-1].strip(): | |||
| lines.pop() | |||
| for idx, ct in enumerate(lines): | |||
| if ct.startswith("bt:"): | |||
| lines[idx] = "+ " + lines[idx] | |||
| for t in range(idx + 1, len(lines)): | |||
| lines[t] = "| " + lines[t] | |||
| break | |||
| return "\n".join(lines) | |||
| _mgb._reg_exception_class(MegBrainError) | |||
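| # A minimal sketch of the resulting message layout (the tracker string is | |||
| # hypothetical): | |||
| # | |||
| #     >>> print(MegBrainError("assertion failed", "created at foo.py:3", None)) | |||
| #     assertion failed | |||
| #     Exception tracker: | |||
| #     | created at foo.py:3 | |||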
| @@ -1,41 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| """global initialization work; classes/functions defined in this module should | |||
| not be used by user code""" | |||
| import atexit | |||
| import os | |||
| import sys | |||
| import traceback | |||
| from . import mgb | |||
| from .logconf import get_logger | |||
| from .persistent_cache import PersistentCacheOnServer | |||
| class PyStackExtracterImpl(mgb._PyStackExtracter): | |||
| def extract(self): | |||
| return "".join(traceback.format_stack()[:-1]) | |||
| mgb._register_logger(get_logger()) | |||
| assert sys.executable | |||
| mgb._timed_func_set_fork_exec_path( | |||
| sys.executable, | |||
| os.path.join(os.path.dirname(__file__), "_timed_func_fork_exec_entry.py"), | |||
| ) | |||
| persistent_cache_impl_ins = PersistentCacheOnServer() | |||
| mgb._PersistentCache.reg(persistent_cache_impl_ins) | |||
| PyStackExtracterImplIns = PyStackExtracterImpl() | |||
| PyStackExtracterImpl.reg(PyStackExtracterImplIns) | |||
| atexit.register(mgb._mgb_global_finalize) | |||
| @@ -1,316 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import numpy as np | |||
| from . import mgb | |||
| from .exc import MegBrainError | |||
| from .mgb import SharedND, SymbolVar | |||
| from .opr_param_defs import OptionalAxisV1 | |||
| def canonize_reshape(inputs, *, comp_graph, config): | |||
| src, tshape = inputs | |||
| tshape = cvt_to_shape_desc(tshape, src, comp_graph, config) | |||
| return src, tshape | |||
| def canonize_shape_input(inputs, *, comp_graph, config): | |||
| assert isinstance(inputs, (list, tuple)) and len(inputs) == 1 | |||
| return [cvt_to_shape_desc(inputs[0], None, comp_graph, config)] | |||
| def cvt_to_shape_desc(val, inpvar, graph, config): | |||
| """convert some python object to a :class:`SymbolVar` that describes tensor | |||
| shape | |||
| :param val: the python object to be converted from | |||
| :param inpvar, graph, config: provide graph and comp node information; can | |||
| be None if not known. Either input or (graph, config) must be provided. | |||
| :return: a new var corresponding to *val* | |||
| :rtype: :class:`.SymbolVar` | |||
| """ | |||
| if hasattr(val, "__mgb_symvar__"): | |||
| val = val.__mgb_symvar__() | |||
| elif hasattr(val, "symvar"): | |||
| val = val.symvar | |||
| if isinstance(val, SymbolVar): | |||
| return val | |||
| if not isinstance(val, collections.Iterable): | |||
| val = [val] | |||
| components = [] | |||
| has_sym = False | |||
| for i in val: | |||
| if hasattr(i, "__mgb_symvar__"): | |||
| i = i.__mgb_symvar__() | |||
| elif hasattr(i, "symvar"): | |||
| i = i.symvar | |||
| if isinstance(i, SymbolVar): | |||
| has_sym = True | |||
| components.append(i) | |||
| else: | |||
| assert isinstance(i, int), ( | |||
| "shape desc could contain either int or SymbolVar, got {}" | |||
| " actually".format(repr(i)) | |||
| ) | |||
| components.append(i) | |||
| assert components, "shape desc could not be empty" | |||
| if inpvar is not None: | |||
| assert isinstance(inpvar, SymbolVar) | |||
| if graph is None: | |||
| graph = inpvar.owner_graph | |||
| else: | |||
| assert graph == inpvar.owner_graph | |||
| config = mgb.make_opr_config(comp_node=inpvar.comp_node) | |||
| else: | |||
| assert isinstance(graph, mgb.CompGraph), "graph must be provided" | |||
| assert isinstance(config, mgb.OperatorNodeConfig) | |||
| if not has_sym: | |||
| shape = np.ascontiguousarray(components, dtype=np.int32) | |||
| assert np.all(shape == components), "failed to convert to shape: {}".format( | |||
| components | |||
| ) | |||
| return mgb._make_immutable(graph, shape, None, config) | |||
| for idx, v in enumerate(components): | |||
| if not isinstance(v, SymbolVar): | |||
| vi = int(v) | |||
| assert vi == v, "could not convert {} to int".format(v) | |||
| components[idx] = mgb._make_immutable(graph, vi, None, config) | |||
| from . import opr as O | |||
| return O.concat(components, axis=0, config=config) | |||
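| # For example (a sketch; ``n`` is assumed to be a SymbolVar holding the | |||
| # batch size and ``x`` some var on the target comp node), a mixed tuple is | |||
| # turned into a single int32 shape var by concatenating immutable scalars | |||
| # with the symbolic components: | |||
| # | |||
| #     >>> tshape = cvt_to_shape_desc((n, 4), x, None, None) | |||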
| def canonize_input_vars(inputs, *, comp_graph, config): | |||
| """convert immediate numbers and SharedND to SymbolVar in inputs; at least | |||
| one of the inputs must be SymbolVar, so comp node and comp graph can | |||
| be inferred | |||
| :return: list of converted vars | |||
| """ | |||
| from . import make_immutable | |||
| if ( | |||
| isinstance(inputs, (list, tuple)) | |||
| and len(inputs) == 1 | |||
| and isinstance(inputs[0], (list, tuple)) | |||
| ): | |||
| # handle the case when a list is passed to a function with | |||
| # variable-length argument (e.g. concat has signature concat(*inputs) | |||
| # and is called with concat([a, b])) | |||
| inputs = inputs[0] | |||
| if isinstance(inputs, SymbolVar): | |||
| return [inputs] | |||
| old_inputs = inputs | |||
| inputs = [] | |||
| get_comp_node = None | |||
| need_cvt = False | |||
| for i in old_inputs: | |||
| if isinstance(i, SymbolVar): | |||
| get_comp_node = lambda cn=i.comp_node: cn | |||
| if comp_graph is not None: | |||
| assert comp_graph == i.owner_graph | |||
| else: | |||
| comp_graph = i.owner_graph | |||
| else: | |||
| need_cvt = True | |||
| inputs.append(i) | |||
| if not need_cvt: | |||
| return inputs | |||
| if get_comp_node is None: | |||
| def get_comp_node(): | |||
| nonlocal get_comp_node | |||
| cn = config.require_comp_node() | |||
| get_comp_node = lambda: cn | |||
| return cn | |||
| for idx, var in enumerate(inputs): | |||
| if not isinstance(var, SymbolVar): | |||
| if isinstance(var, SharedND): | |||
| var = var.symvar(comp_graph) | |||
| elif isinstance(var, mgb.SharedScalar): | |||
| var = var._as_sym_var(comp_graph, get_comp_node()) | |||
| elif hasattr(var, "__mgb_symvar__"): | |||
| try: | |||
| cn = get_comp_node() | |||
| except MegBrainError: | |||
| cn = None | |||
| var = var.__mgb_symvar__(comp_graph=comp_graph, comp_node=cn) | |||
| elif hasattr(var, "symvar"): | |||
| var = var.symvar | |||
| else: | |||
| var = make_immutable(get_comp_node(), comp_graph, var) | |||
| inputs[idx] = var | |||
| return inputs | |||
| def cvt_to_vector_of_shape(shapes): | |||
| """convert ``[[int]]`` to nested ``std::vector`` of ``size_t``""" | |||
| ret = mgb._VectorTensorShape() | |||
| for i in shapes: | |||
| val = tuple(i) | |||
| assert val and all( | |||
| j > 0 and isinstance(j, int) for j in val | |||
| ), "something returns bad shape in infer_shape(): {}".format(val) | |||
| ret.push_back(val) | |||
| return ret | |||
| def cvt_to_opr_param_def(param, ptype, kwargs): | |||
| if param is not None: | |||
| if isinstance(param, ptype): | |||
| return param | |||
| param = [param] | |||
| assert len(param) == len( | |||
| ptype.__slots__ | |||
| ), "{} needs {} params, but {} are provided".format( | |||
| ptype, len(ptype.__slots__), len(param) | |||
| ) | |||
| return ptype(*param) | |||
| ckw = {} | |||
| for i in ptype.__slots__: | |||
| val = kwargs.pop(i, ckw) | |||
| if val is not ckw: | |||
| ckw[i] = val | |||
| return ptype(**ckw) | |||
| def cvt_getitem_to_idx_desc(inpvar, tuple_val, *, allow_newaxis=True): | |||
| """convert ``__getitem__`` args to index desc | |||
| :return: ``(new_var, index_desc)`` where new_var is inpvar with | |||
| ``np.newaxis`` applied; note that ``index_desc`` can be ``None``. | |||
| """ | |||
| assert isinstance(inpvar, SymbolVar), "bad input: {!r}".format(inpvar) | |||
| if not isinstance(tuple_val, tuple): | |||
| tuple_val = (tuple_val,) | |||
| axis_indexer = mgb._VectorAxisIndexer() | |||
| config = mgb.make_opr_config(comp_node=inpvar.comp_node) | |||
| graph = inpvar.owner_graph | |||
| def as_symvar(v, *, allow_list=True): | |||
| if isinstance(v, SymbolVar): | |||
| return v | |||
| vi = np.ascontiguousarray(v, dtype=np.int32) | |||
| assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v) | |||
| return mgb._make_immutable(graph, vi, None, config) | |||
| def _s(v): # convert slice item | |||
| if v is None: | |||
| return SymbolVar() | |||
| return as_symvar(v, allow_list=False) | |||
| new_axes = [] | |||
| cur_axis = -1 | |||
| for i_idx, i in enumerate(tuple_val): | |||
| cur_axis += 1 | |||
| if i is np.newaxis: | |||
| if cur_axis >= 0: | |||
| new_axes.append(cur_axis) | |||
| continue | |||
| if i is Ellipsis: | |||
| cur_axis = -1 | |||
| for j in tuple_val[:i_idx:-1]: | |||
| if j is Ellipsis: | |||
| raise IndexError("only one ellipsis is allowed") | |||
| if j is np.newaxis: | |||
| new_axes.append(cur_axis) | |||
| cur_axis -= 1 | |||
| continue | |||
| if isinstance(i, slice): | |||
| if i.start is None and i.stop is None and i.step is None: | |||
| continue | |||
| cur = mgb._AxisIndexer.make_interval( | |||
| cur_axis, _s(i.start), _s(i.stop), _s(i.step) | |||
| ) | |||
| else: | |||
| cur = mgb._AxisIndexer.make_index(cur_axis, as_symvar(i)) | |||
| axis_indexer.push_back(cur) | |||
| if new_axes: | |||
| if not allow_newaxis: | |||
| raise IndexError("newaxis is not allowed here") | |||
| inpvar = mgb._Opr.add_axis(inpvar, new_axes, mgb.make_opr_config()) | |||
| if axis_indexer.empty(): | |||
| axis_indexer = None | |||
| return inpvar, axis_indexer | |||
| def cvt_to_reshape_unspec_axis(unspec_axis, tshape): | |||
| assert isinstance(unspec_axis, OptionalAxisV1), repr(unspec_axis) | |||
| unspec_axis = unspec_axis.axis | |||
| assert abs(unspec_axis) <= OptionalAxisV1.MAX_NDIM | |||
| if not isinstance(tshape, SymbolVar): | |||
| for idx, val in enumerate(tshape): | |||
| if val == -1: | |||
| assert ( | |||
| unspec_axis == OptionalAxisV1.INVALID_AXIS | |||
| ), "multiple unknown dimensions for reshape" | |||
| unspec_axis = idx | |||
| return OptionalAxisV1(unspec_axis) | |||
| def gen_config(name, comp_node, config, output_dtype=None): | |||
| if config is None: | |||
| config = mgb.make_opr_config(name, comp_node, output_dtype) | |||
| else: | |||
| assert isinstance(config, mgb.OperatorNodeConfig) | |||
| assert name is None and comp_node is None | |||
| return config | |||
| def cvt_opr_result(rst, *, explode_single=True): | |||
| """:param explode_single: whether to return the content of a single-item | |||
| list rather than the list itself""" | |||
| if not isinstance(rst, mgb.SymbolVar): | |||
| assert isinstance(rst, (list, tuple)) | |||
| if len(rst) == 1 and explode_single: | |||
| return cvt_opr_result(rst[0]) | |||
| return tuple(map(cvt_opr_result, rst)) | |||
| if not rst.valid: | |||
| return None | |||
| # TODO Because the __init__ of SwigObject can not be modified to keep the | |||
| # reference of graph, we get owner graph explicitly here. The correct | |||
| # handling is moving the reference to SwigWrapper, but it is unsupported to | |||
| # add a member variable to SwigWrapper, so we should wrap the SymbolVar | |||
| # manually in megbrain_wrap.h | |||
| rst.owner_graph | |||
| f32 = np.float32 | |||
| if not hasattr(cvt_opr_result, "_cvt_to_float32"): | |||
| import os | |||
| from .logconf import get_logger | |||
| cvt_opr_result._cvt_to_float32 = os.getenv("MGB_ALL_FLOAT32") | |||
| if cvt_opr_result._cvt_to_float32: | |||
| get_logger().warn( | |||
| "\n" | |||
| "+=====================================================+\n" | |||
| "| MGB_ALL_FLOAT32 is set, so all megbrain opr result |\n" | |||
| "| would to converted to float32; this should only be |\n" | |||
| "| used for loading old models. |\n" | |||
| "+=====================================================+" | |||
| ) | |||
| if cvt_opr_result._cvt_to_float32 and rst.dtype != f32: | |||
| rst = rst.astype(f32) | |||
| return rst | |||
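| # Usage note (a sketch; the script name is hypothetical): the float32 | |||
| # fallback is driven purely by the environment, e.g. | |||
| # | |||
| #     $ MGB_ALL_FLOAT32=1 python load_old_model.py | |||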
| @@ -1,54 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import logging | |||
| import os | |||
| _replaced_logger = None | |||
| def get_logger(): | |||
| global _replaced_logger | |||
| if _replaced_logger is not None: | |||
| return _replaced_logger | |||
| logger = logging.getLogger("megbrain") | |||
| logger.propagate = False | |||
| logger.setLevel(logging.INFO) | |||
| handler = logging.StreamHandler() | |||
| handler.setFormatter(MgbLogFormatter(datefmt="%d %H:%M:%S")) | |||
| handler.setLevel(0) | |||
| del logger.handlers[:] | |||
| logger.addHandler(handler) | |||
| _replaced_logger = logger | |||
| return logger | |||
| class MgbLogFormatter(logging.Formatter): | |||
| def format(self, record): | |||
| date = "\x1b[32m[%(asctime)s %(lineno)d@%(filename)s:%(name)s]\x1b[0m" | |||
| msg = "%(message)s" | |||
| if record.levelno == logging.DEBUG: | |||
| fmt = "{} \x1b[32mDBG\x1b[0m {}".format(date, msg) | |||
| elif record.levelno == logging.WARNING: | |||
| fmt = "{} \x1b[1;31mWRN\x1b[0m {}".format(date, msg) | |||
| elif record.levelno == logging.ERROR: | |||
| fmt = "{} \x1b[1;4;31mERR\x1b[0m {}".format(date, msg) | |||
| else: | |||
| fmt = date + " " + msg | |||
| self._style._fmt = fmt | |||
| return super().format(record) | |||
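| # A minimal usage sketch of the default logger: | |||
| # | |||
| #     >>> log = get_logger() | |||
| #     >>> log.warning("out of range")   # rendered as "... WRN out of range" | |||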
| def set_logger(logger): | |||
| """replace the logger""" | |||
| global _replaced_logger | |||
| _replaced_logger = logger | |||
| from .mgb import _register_logger | |||
| _register_logger(logger) | |||
| @@ -1,87 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| """helper utils for the core mgb module""" | |||
| import collections | |||
| import inspect | |||
| import json | |||
| import threading | |||
| from abc import ABCMeta, abstractmethod | |||
| class callback_lazycopy: | |||
| """wraps around a callable to be passed to :meth:`.CompGraph.compile`. | |||
| This is used to disable eager copy, so we could get rid of an h2d copy and | |||
| a d2h copy if values are to be passed from one callback to another | |||
| :class:`.SharedND`. | |||
| """ | |||
| def __init__(self, func): | |||
| assert isinstance(func, collections.Callable) | |||
| self.__func = func | |||
| @property | |||
| def func(self): | |||
| return self.__func | |||
| class SharedNDLazyInitializer(metaclass=ABCMeta): | |||
| """lazy initialization policy for :class:`.SharedND`""" | |||
| @abstractmethod | |||
| def get_shape(self): | |||
| """get shape, without loading value""" | |||
| @abstractmethod | |||
| def get_value(self): | |||
| """get value as numpy ndarray""" | |||
| class copy_output: | |||
| """wraps a :class:`.SymbolVar` in outspec for :meth:`.CompGraph.compile`, | |||
| to copy the output to function return value""" | |||
| symvar = None | |||
| borrow_mem = None | |||
| def __init__(self, symvar, *, borrow_mem=False): | |||
| """ | |||
| :param borrow_mem: see :meth:`.CompGraphCallbackValueProxy.get_value` | |||
| """ | |||
| from .mgb import SymbolVar | |||
| assert isinstance( | |||
| symvar, SymbolVar | |||
| ), "copy_output expects an SymbolVar, got {} instead".format(symvar) | |||
| self.symvar = symvar | |||
| self.borrow_mem = borrow_mem | |||
| class FuncOutputSaver: | |||
| """instance could be used as callbacks for :meth:`.CompGraph.compile` to | |||
| copy output to host buffer | |||
| """ | |||
| _value = None | |||
| _borrow_mem = None | |||
| def __init__(self, borrow_mem=False): | |||
| self._borrow_mem = borrow_mem | |||
| def __call__(self, v): | |||
| self._value = v.get_value(borrow_mem=self._borrow_mem) | |||
| def get(self): | |||
| assert ( | |||
| self._value is not None | |||
| ), "{} not called; maybe due to unwaited async func".format(self) | |||
| return self._value | |||
| @@ -1,3 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # Copyright (c) 2015-2019 Megvii Inc. All rights reserved. | |||
| @@ -1,90 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import argparse | |||
| import getpass | |||
| import json | |||
| import os | |||
| import shelve | |||
| from .logconf import get_logger | |||
| from .mgb import _PersistentCache | |||
| from .version import __version__ | |||
| class _FakeRedisConn: | |||
| def __init__(self): | |||
| try: | |||
| from ..hub.hub import _get_megengine_home | |||
| cache_dir = os.path.expanduser( | |||
| os.path.join(_get_megengine_home(), "persistent_cache") | |||
| ) | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| cache_file = os.path.join(cache_dir, "cache") | |||
| self._dict = shelve.open(cache_file) | |||
| self._is_shelve = True | |||
| except: | |||
| self._dict = {} | |||
| self._is_shelve = False | |||
| def get(self, key): | |||
| if self._is_shelve and isinstance(key, bytes): | |||
| key = key.decode("utf-8") | |||
| return self._dict.get(key) | |||
| def set(self, key, val): | |||
| if self._is_shelve and isinstance(key, bytes): | |||
| key = key.decode("utf-8") | |||
| self._dict[key] = val | |||
| def __del__(self): | |||
| if self._is_shelve: | |||
| self._dict.close() | |||
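| # The shelve-backed store speaks a tiny redis-like protocol; bytes keys | |||
| # are decoded before hitting the shelve backend. A sketch: | |||
| # | |||
| #     >>> conn = _FakeRedisConn() | |||
| #     >>> conn.set(b"mgbcache:k", b"v") | |||
| #     >>> conn.get(b"mgbcache:k") | |||
| #     b'v' | |||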
| class PersistentCacheOnServer(_PersistentCache): | |||
| _cached_conn = None | |||
| _prefix = None | |||
| _prev_get_refkeep = None | |||
| @property | |||
| def _conn(self): | |||
| """get redis connection""" | |||
| if self._cached_conn is None: | |||
| self._cached_conn = _FakeRedisConn() | |||
| self._prefix = self.make_user_prefix() | |||
| return self._cached_conn | |||
| @classmethod | |||
| def make_user_prefix(cls): | |||
| return "mgbcache:{}".format(getpass.getuser()) | |||
| def _make_key(self, category, key): | |||
| prefix_with_version = "{}:MGB{}".format(self._prefix, __version__) | |||
| return b"@".join( | |||
| (prefix_with_version.encode("ascii"), category.encode("ascii"), key) | |||
| ) | |||
| def put(self, category, key, value): | |||
| conn = self._conn | |||
| key = self._make_key(category, key) | |||
| conn.set(key, value) | |||
| def get(self, category, key): | |||
| conn = self._conn | |||
| key = self._make_key(category, key) | |||
| self._prev_get_refkeep = conn.get(key) | |||
| return self._prev_get_refkeep | |||
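| # Usage sketch (category and key below are made up for illustration): | |||
| # | |||
| #     >>> cache = PersistentCacheOnServer() | |||
| #     >>> cache.put("conv_algo", b"workload0", b"algo1") | |||
| #     >>> cache.get("conv_algo", b"workload0") | |||
| #     b'algo1' | |||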
| @@ -1,261 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| """plugins associated with computing graph""" | |||
| import atexit | |||
| import collections | |||
| import json | |||
| import os | |||
| import platform | |||
| import signal | |||
| import struct | |||
| import numpy as np | |||
| from . import mgb as _mgb | |||
| from .logconf import get_logger | |||
| InfkernFinderInputValueRec = collections.namedtuple( | |||
| "InfkernFinderInputValueRec", ["var_name", "var_id", "run_id", "value"] | |||
| ) | |||
| class CompGraphProfiler(_mgb._CompGraphProfilerImpl): | |||
| """a plugin to profile computing graphs""" | |||
| def __init__(self, comp_graph): | |||
| super().__init__(comp_graph) | |||
| def get(self): | |||
| """get visualizable profiling result on a function""" | |||
| return json.loads(self._get_result()) | |||
| def write_json(self, fobj): | |||
| """write the result to a json file | |||
| :param fobj: a file-like object, or a string | |||
| """ | |||
| if isinstance(fobj, str): | |||
| with open(fobj, "w") as fout: | |||
| return self.write_json(fout) | |||
| fobj.write(self._get_result()) | |||
| class NumRangeChecker(_mgb._NumRangeCheckerImpl): | |||
| """check that all numberical float values of variables in a computing graph | |||
| are within given range""" | |||
| def __init__(self, comp_graph, max_abs_val): | |||
| """:param max_abs_val: max absolute value""" | |||
| super().__init__(comp_graph, float(max_abs_val)) | |||
| class TextOprIODump(_mgb._TextOprIODumpImpl): | |||
| """dump all internal results as text to a file""" | |||
| def __init__(self, comp_graph, fpath, *, print_addr=None, max_size=None): | |||
| super().__init__(comp_graph, fpath) | |||
| if print_addr is not None: | |||
| self.print_addr(print_addr) | |||
| if max_size is not None: | |||
| self.max_size(max_size) | |||
| def print_addr(self, flag): | |||
| """set whether to print var address | |||
| :return: self | |||
| """ | |||
| self._print_addr(flag) | |||
| return self | |||
| def max_size(self, size): | |||
| """set the number of elements to be printed for each var | |||
| :return: self | |||
| """ | |||
| self._max_size(size) | |||
| return self | |||
| class BinaryOprIODump(_mgb._BinaryOprIODumpImpl): | |||
| """dump all internal results binary files to a directory; the values can be | |||
| loaded by :func:`load_tensor_binary` | |||
| """ | |||
| def __init__(self, comp_graph, dir_path): | |||
| super().__init__(comp_graph, dir_path) | |||
| class InfkernFinder(_mgb._InfkernFinderImpl): | |||
| """a plugin to find kernels that cause infinite loops""" | |||
| def __init__(self, comp_graph, record_input_value): | |||
| """ | |||
| :param record_input_value: whether need to record input var values of | |||
| all operators | |||
| :type record_input_value: bool | |||
| """ | |||
| super().__init__(comp_graph, record_input_value) | |||
| def write_to_file(self, fpath): | |||
| """write current execution status to a text file | |||
| :return: ID of the first operator that is still not finished, | |||
| or None if all oprs are finished | |||
| :rtype: int or None | |||
| """ | |||
| v = self._write_to_file(fpath) | |||
| if v == 0: | |||
| return | |||
| return v - 1 | |||
| def get_input_values(self, opr_id): | |||
| """get recorded input values of a given operator. Return a list | |||
| of :class:`InfkernFinderInputValueRec`. Note that the value in | |||
| each item is either None (if it is not recorded) or a numpy | |||
| array | |||
| """ | |||
| ret = [] | |||
| for idx in range(self._get_input_values_prepare(opr_id)): | |||
| vn = self._get_input_values_var_name(idx) | |||
| vi = self._get_input_values_var_idx(idx) | |||
| ri = self._get_input_values_run_id(idx) | |||
| val = self._get_input_values_val(idx) | |||
| if not val.shape: | |||
| val = None | |||
| else: | |||
| val = val.get_value() | |||
| ret.append(InfkernFinderInputValueRec(vn, vi, ri, val)) | |||
| return ret | |||
| def fast_signal_hander(signum, callback): | |||
| """bypass python's signal handling system and registera handler that is | |||
| called ASAP in a dedicated thread (in contrary, python calls handlers in | |||
| the main thread) | |||
| :param callback: signal callback, taking the signal number as its sole | |||
| argument | |||
| """ | |||
| def cb_wrapped(): | |||
| try: | |||
| callback(signum) | |||
| except: | |||
| get_logger().exception("error calling signal handler for {}".format(signum)) | |||
| _mgb._FastSignal.register_handler(signum, cb_wrapped) | |||
| atexit.register(_mgb._FastSignal.shutdown) | |||
| class GlobalInfkernFinder: | |||
| """ | |||
| manage a list of :class:`InfkernFinder` objects; when this process is | |||
| signaled with SIGUSR1, an interactive IPython shell would be presented for | |||
| further investigation | |||
| """ | |||
| _signal = None | |||
| if platform.system() != "Windows": | |||
| _signal = signal.SIGUSR1 | |||
| else: | |||
| _signal = signal.CTRL_C_EVENT | |||
| _registry = [] | |||
| _shell_maker = None | |||
| @classmethod | |||
| def add_graph(cls, comp_graph): | |||
| """register a graph so it can be tracked by :class:`InfkernFinder`""" | |||
| enabled = os.getenv("MGB_DBG_INFKERN_FINDER") | |||
| if not enabled: | |||
| return | |||
| if enabled == "1": | |||
| record_input_value = False | |||
| else: | |||
| assert enabled == "2", ( | |||
| "MGB_DBG_INFKERN_FINDER must be either 1 or 2, indicating " | |||
| "whether to record input values" | |||
| ) | |||
| record_input_value = True | |||
| finder = InfkernFinder(comp_graph, record_input_value) | |||
| get_logger().warning( | |||
| "interactive InfkernFinder {} registered to graph {}; all input " | |||
| "var values would be recorded and the graph would never be " | |||
| "reclaimed. You can enter the interactive debug session by " | |||
| 'executing "kill -{} {}". record_input_value={}'.format( | |||
| finder, comp_graph, cls._signal, os.getpid(), record_input_value | |||
| ) | |||
| ) | |||
| if not cls._registry: | |||
| from IPython.terminal.embed import InteractiveShellEmbed | |||
| cls._shell_maker = InteractiveShellEmbed | |||
| fast_signal_hander(cls._signal, cls._on_signal) | |||
| cls._registry.append(finder) | |||
| @classmethod | |||
| def _on_signal(cls, signum): | |||
| shell = cls._shell_maker() | |||
| shell( | |||
| header="Enter interactive InfkernFinder session; the registered " | |||
| "finder objects can be found in variable f", | |||
| local_ns={"f": cls._registry}, | |||
| ) | |||
| def load_tensor_binary(fobj): | |||
| """load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual | |||
| tensor value dump is implemented by ``mgb::debug::dump_tensor``. | |||
| Multiple values can be compared by ``tools/compare_binary_iodump.py``. | |||
| :param fobj: file object, or a string that contains the file name | |||
| :return: tuple ``(tensor_value, tensor_name)`` | |||
| """ | |||
| if isinstance(fobj, str): | |||
| with open(fobj, "rb") as fin: | |||
| return load_tensor_binary(fin) | |||
| DTYPE_LIST = { | |||
| 0: np.float32, | |||
| 1: np.uint8, | |||
| 2: np.int8, | |||
| 3: np.int16, | |||
| 4: np.int32, | |||
| 5: _mgb.intb1, | |||
| 6: _mgb.intb2, | |||
| 7: _mgb.intb4, | |||
| 8: None, | |||
| 9: np.float16, | |||
| # quantized dtype start from 100000 | |||
| # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in | |||
| # dnn/include/megdnn/dtype.h | |||
| 100000: np.uint8, | |||
| 100001: np.int32, | |||
| 100002: np.int8, | |||
| } | |||
| header_fmt = struct.Struct("III") | |||
| name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size)) | |||
| assert ( | |||
| DTYPE_LIST[dtype] is not None | |||
| ), "Cannot load this tensor: dtype Byte is unsupported." | |||
| shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4))) | |||
| while shape[-1] == 0: | |||
| shape.pop(-1) | |||
| name = fobj.read(name_len).decode("ascii") | |||
| return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name | |||
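| # Usage sketch (the dump path below is hypothetical): | |||
| # | |||
| #     >>> value, name = load_tensor_binary("iodump/run0/var42") | |||
| #     >>> value.dtype, value.shape, name | |||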
| @@ -1,57 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| """version information for MegBrain package""" | |||
| import collections | |||
| from . import mgb as _mgb | |||
| class Version( | |||
| collections.namedtuple("VersionBase", ["major", "minor", "patch", "dev"]) | |||
| ): | |||
| """simple sematic version object""" | |||
| @classmethod | |||
| def __normalize(cls, v): | |||
| if isinstance(v, str): | |||
| v = v.split(".") | |||
| a, b, c = map(int, v) | |||
| return cls(a, b, c) | |||
| def __eq__(self, rhs): | |||
| return super().__eq__(self.__normalize(rhs)) | |||
| def __ne__(self, rhs): | |||
| return super().__ne__(self.__normalize(rhs)) | |||
| def __lt__(self, rhs): | |||
| return super().__lt__(self.__normalize(rhs)) | |||
| def __le__(self, rhs): | |||
| return super().__le__(self.__normalize(rhs)) | |||
| def __gt__(self, rhs): | |||
| return super().__gt__(self.__normalize(rhs)) | |||
| def __ge__(self, rhs): | |||
| return super().__ge__(self.__normalize(rhs)) | |||
| def __str__(self): | |||
| rst = "{}.{}.{}".format(self.major, self.minor, self.patch) | |||
| if self.dev: | |||
| rst += "-dev{}".format(self.dev) | |||
| return rst | |||
| Version.__new__.__defaults__ = (0,) # dev defaults to 0 | |||
| version_info = Version(*_mgb._get_mgb_version()) | |||
| __version__ = str(version_info) | |||
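| # ``Version`` compares directly against "major.minor.patch" strings; a sketch: | |||
| # | |||
| #     >>> Version(1, 2, 3) > "1.2.0" | |||
| #     True | |||
| #     >>> str(Version(1, 2, 3, dev=1)) | |||
| #     '1.2.3-dev1' | |||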
| @@ -1,20 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .device import ( | |||
| get_default_device, | |||
| get_device_count, | |||
| is_cuda_available, | |||
| set_default_device, | |||
| ) | |||
| from .function import Function | |||
| from .graph import Graph, dump | |||
| from .serialization import load, save | |||
| from .tensor import Tensor, TensorDict, tensor, wrap_io_tensor | |||
| from .tensor_factory import ones, zeros | |||
| from .tensor_nn import Buffer, Parameter | |||
| @@ -1,60 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import megengine._internal as mgb | |||
| _default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux") | |||
| def get_device_count(device_type: str) -> int: | |||
| """Gets number of devices installed on this system. | |||
| :param device_type: device type, one of 'gpu' or 'cpu' | |||
| """ | |||
| device_type_set = ("cpu", "gpu") | |||
| assert device_type in device_type_set, "device must be one of {}".format( | |||
| device_type_set | |||
| ) | |||
| return mgb.config.get_device_count(device_type) | |||
| def is_cuda_available() -> bool: | |||
| """Returns whether cuda device is available on this system. | |||
| """ | |||
| return mgb.config.get_device_count("gpu", warn=False) > 0 | |||
| def set_default_device(device: str = "xpux"): | |||
| r"""Sets default computing node. | |||
| :param device: default device type. The type can be 'cpu0', 'cpu1', etc., | |||
| or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use. | |||
| 'cpux' and 'gpux' can also be used to specify any number of cpu or gpu devices. | |||
| The 'multithread' device type is available for inference; it implements | |||
| multi-threading parallelism at the operator level. For example, | |||
| 'multithread4' will compute with 4 threads. | |||
| The default value is 'xpux' to specify any device available. | |||
| It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | |||
| """ | |||
| global _default_device # pylint: disable=global-statement | |||
| _default_device = device | |||
| def get_default_device() -> str: | |||
| r"""Gets default computing node. | |||
| It returns the value set by :func:`~.set_default_device`. | |||
| """ | |||
| return _default_device | |||
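| # A minimal usage sketch: | |||
| # | |||
| #     >>> set_default_device("cpu0") | |||
| #     >>> get_default_device() | |||
| #     'cpu0' | |||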
| @@ -1,176 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import copy | |||
| from abc import ABCMeta, abstractmethod | |||
| from typing import Iterable, Tuple, Union | |||
| import megengine._internal as mgb | |||
| from .tensor import Tensor | |||
| class _OverrideGradientCraniotome(mgb.craniotome.CraniotomeBase): | |||
| __nr_inputs__ = None | |||
| __nr_outputs__ = None | |||
| __expand_single_outputs__ = False | |||
| __allow_duplicate__ = False | |||
| grad_func = None | |||
| def setup(self, nr_inputs, nr_outputs, grad_func): | |||
| self.__nr_inputs__ = nr_inputs + nr_outputs | |||
| self.__nr_outputs__ = nr_outputs | |||
| self.grad_func = grad_func | |||
| def infer_shape(self, inp_shapes): | |||
| return inp_shapes[-self.__nr_outputs__ :] | |||
| def init_output_dtype(self, input_dtypes): | |||
| return input_dtypes[-self.__nr_outputs__ :] | |||
| def execute(self, inputs, outputs): | |||
| for ivar, ovar in zip(inputs[-self.__nr_outputs__ :], outputs): | |||
| ovar.set_value(ivar) | |||
| def grad(self, wrt_idx, inputs, outputs, out_grad): | |||
| # TODO: Make sure grad_values really have values in eager mode. | |||
| # Porting to the new imperative engine would solve this, but if that | |||
| # doesn't happen, EagerEvalManager should be changed. | |||
| grads = self.grad_func( | |||
| *(Tensor(x) if x is not None else None for x in out_grad) | |||
| ) | |||
| # pylint: disable=literal-comparison | |||
| if isinstance(grads, Tensor) or grads is None or grads is 0: | |||
| grads = (grads,) | |||
| assert ( | |||
| len(grads) == self.__nr_inputs__ - self.__nr_outputs__ | |||
| ), "Function.backward should return a tuple with len = {}, got {}".format( | |||
| self.__nr_inputs__ - self.__nr_outputs__, len(grads) | |||
| ) | |||
| # pylint: disable=literal-comparison | |||
| return ( | |||
| list(x._symvar if x is not None and x is not 0 else 0 for x in grads) | |||
| + [0] * self.__nr_outputs__ | |||
| ) | |||
| def get_serialize_params(self): | |||
| raise NotImplementedError("Serialization of Function is not implemented") | |||
| class Function(metaclass=ABCMeta): | |||
| """ | |||
| Defines a block of operations with customizable differentiation. | |||
| The computation should be defined in ``forward`` method, with gradient | |||
| computation defined in ``backward`` method. | |||
| Each instance of ``Function`` should be used only once during the forward pass. | |||
| Examples: | |||
| .. testcode:: | |||
| class Sigmoid(Function): | |||
| def forward(self, x): | |||
| y = 1 / (1 + F.exp(-x)) | |||
| self.save_for_backward(y) | |||
| return y | |||
| def backward(self, output_grads): | |||
| (y, ) = self.saved_tensors | |||
| return output_grads * y * (1-y) | |||
| """ | |||
| _has_saved_state = False | |||
| saved_tensors = None | |||
| def __init__(self): | |||
| self.saved_tensors = () | |||
| @abstractmethod | |||
| def forward(self, *inputs: Iterable[Tensor]) -> Union[Tuple[Tensor], Tensor]: | |||
| """ | |||
| Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses. | |||
| Users can call :meth:`~.function.Function.save_for_backward` in this method to save tensors. | |||
| :param inputs: input tensors. | |||
| :return: A tuple of Tensor or a single Tensor. | |||
| .. note:: | |||
| This method should return a tuple of Tensor or a single Tensor representing the output | |||
| of the function. | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def backward( | |||
| self, *output_grads: Iterable[Union[Tensor, None]] | |||
| ) -> Union[Tuple[Tensor], Tensor]: | |||
| """ | |||
| Computes the gradient of the forward function. It must be overridden by all subclasses. | |||
| :param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward` | |||
| .. note:: | |||
| In case when some tensors of outputs are not related to loss function, the corresponding | |||
| values in ``output_grads`` would be ``None``. | |||
| .. note:: | |||
| This method should return a tuple containing the gradients of all inputs, in the same order | |||
| as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned | |||
| instead if there is only one input. If users want to stop the propagation of some gradients, | |||
| the corresponding returned values should be set to ``None`` . | |||
| """ | |||
| raise NotImplementedError | |||
| def save_for_backward(self, *tensors: Iterable[Tensor]): | |||
| """ | |||
| Saves tensors needed for gradient computation. This method should be called only | |||
| once in :meth:`~.function.Function.forward`, additional calls will replace values saved previously. | |||
| The saved tensors can be accessed through the ``saved_tensors`` attribute. | |||
| """ | |||
| self.saved_tensors = tensors | |||
| def __deepcopy__(self, memo): | |||
| """ | |||
| Defines how the operator is deeply copied | |||
| """ | |||
| cls = self.__class__ | |||
| result = cls.__new__(cls) | |||
| tmp = self.saved_tensors | |||
| self.saved_tensors = None | |||
| memo[id(self)] = result | |||
| for k, v in self.__dict__.items(): | |||
| setattr(result, k, copy.deepcopy(v, memo)) | |||
| setattr(result, "saved_tensors", tmp) | |||
| self.saved_tensors = tmp | |||
| return result | |||
| def __call__(self, *inputs): | |||
| assert ( | |||
| not self._has_saved_state | |||
| ), "A Function instance should not be called multiple times" | |||
| outputs = self.forward(*inputs) | |||
| if isinstance(outputs, Tensor): | |||
| outputs = (outputs,) | |||
| self._has_saved_state = True | |||
| sv = (x._symvar for x in inputs + outputs) | |||
| outputs = _OverrideGradientCraniotome.make( | |||
| *sv, nr_inputs=len(inputs), nr_outputs=len(outputs), grad_func=self.backward | |||
| ) | |||
| outputs = tuple(map(Tensor, outputs)) | |||
| if len(outputs) == 1: | |||
| outputs = outputs[0] | |||
| return outputs | |||
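| # Usage sketch continuing the ``Sigmoid`` example from the class docstring | |||
| # (``x`` is assumed to be an input Tensor): | |||
| # | |||
| #     >>> s = Sigmoid() | |||
| #     >>> y = s(x)          # forward pass; self.backward is recorded | |||
| #     >>> s(x)              # a second call raises AssertionError | |||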
| @@ -1,158 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import threading | |||
| import megengine._internal as mgb | |||
| from .device import get_default_device | |||
| class _DefaultGraph(threading.local): | |||
| r""" | |||
| An implicit thread-local graph | |||
| """ | |||
| def __init__(self): | |||
| super(_DefaultGraph, self).__init__() | |||
| self._default_graph = None | |||
| def get_default(self): | |||
| r"""Returns a default Graph object for eager evaluation. | |||
| """ | |||
| if self._default_graph is None: | |||
| self._default_graph = Graph() | |||
| return self._default_graph | |||
| _default_graph = _DefaultGraph() | |||
| class Graph(mgb.CompGraph): | |||
| r""" | |||
| A computing graph that supports context management. | |||
| :param check_env_var: whether to check environment vars including ``MGB_COMP_GRAPH_OPT``. | |||
| :param eager_evaluation: use a dynamic graph (``True``) or a static graph (``False``). | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| from megengine.core import Graph | |||
| with Graph(eager_evaluation=True): | |||
| x = tensor([1, 2]) | |||
| print(x) | |||
| Outputs: | |||
| .. testoutput:: | |||
| Tensor([1 2], dtype=int32) | |||
| """ | |||
| __saved_graph = None | |||
| def __new__( | |||
| cls, *, check_env_var: bool = True, eager_evaluation: bool = True, **kwargs | |||
| ): | |||
| kwargs.update(eager_evaluation=eager_evaluation) | |||
| self = mgb.comp_graph(extra_opts=kwargs, check_env_var=check_env_var) | |||
| self.__class__ = cls | |||
| return self | |||
| def __init__( | |||
| self, *, check_env_var: bool = True, eager_evaluation: bool = True, **kwargs | |||
| ): | |||
| # pylint: disable=super-init-not-called | |||
| pass | |||
| def __enter__(self): | |||
| self.__saved_graph = _default_graph._default_graph | |||
| _default_graph._default_graph = self | |||
| return self | |||
| def __exit__(self, type, value, traceback): | |||
| _default_graph._default_graph = self.__saved_graph | |||
| del self.__saved_graph | |||
| def _use_default_if_none(device, comp_graph): | |||
| if device is None: | |||
| device = get_default_device() | |||
| if comp_graph is None: | |||
| comp_graph = get_default_graph() | |||
| return device, comp_graph | |||
| def dump(outputs, fpath, optimize_options=None, **kwargs): | |||
| r""" | |||
| Serializes this computing graph and writes it to a file. | |||
| :type outputs: ``Tensor`` or a collection of ``Tensor`` | |||
| :param outputs: output variables that need to be retrieved when | |||
| deserializing | |||
| :type fpath: ``str`` | |||
| :param fpath: path for the output file | |||
| :type optimize_options: ``list`` | |||
| :param optimize_options: a list drawn from ``['f16_io_f32_comp', 'f16_io_comp', 'use_nhwcd4', 'fuse_conv_bias_nonlinearity']``; all four elements are optional, so it can be an empty list, ``None``, or a list containing any of them. | |||
| .. note:: | |||
| ``f16_io_f32_comp`` – whether to use float16 for I/O between oprs and use float32 as internal computation precision. Note the output var would be changed to float16; | |||
| ``f16_io_comp`` – whether to use float16 for both I/O and computation precision; | |||
| ``use_nhwcd4`` – whether to use NHWCD4 data format. This is faster on some OpenCL devices; | |||
| ``fuse_conv_bias_nonlinearity`` – whether to fuse conv+bias+nonlinearity into one opr. This is supported only when ``use_nhwcd4`` is set. | |||
| """ | |||
| from .tensor import Tensor | |||
| assert optimize_options is None or isinstance( | |||
| optimize_options, list | |||
| ), "optimize_options must be a list" | |||
| if isinstance(outputs, Tensor): | |||
| outputs = [outputs] | |||
| else: | |||
| assert isinstance(outputs, collections.Iterable), "{} not iterable".format( | |||
| outputs | |||
| ) | |||
| outputs = list(outputs) | |||
| for output in outputs: | |||
| assert isinstance(output, Tensor), "All outputs must be Tensors." | |||
| outputs = [o._symvar for o in outputs] | |||
| if optimize_options: | |||
| opt_dict = dict.fromkeys(optimize_options, True) | |||
| mgb.optimize_for_inference(outputs, **opt_dict) | |||
| mgb.serialize_comp_graph_to_file(fpath, outputs, **kwargs) | |||
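| # Usage sketch (illustrative, not part of the original file): `net` and `x` | |||
| # are hypothetical; any output Tensor works. This serializes the graph that | |||
| # produced `out`, applying the optional inference optimizations listed above. | |||
| # | |||
| #     out = net(x) | |||
| #     dump(out, "model.mge", optimize_options=["f16_io_comp"]) | |||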
| def set_default_graph(default_graph): | |||
| r""" | |||
| Sets a global default Graph object. | |||
| """ | |||
| global _default_graph # pylint: disable=global-statement | |||
| _default_graph._default_graph = default_graph | |||
| def get_default_graph(): | |||
| r""" | |||
| Returns the default Graph object, typically used for eager evaluation. | |||
| """ | |||
| return _default_graph.get_default() | |||
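| # Illustrative sketch (not part of the original file): how the context | |||
| # manager and the default-graph helpers above combine. | |||
| # | |||
| #     with Graph(eager_evaluation=True) as g: | |||
| #         assert get_default_graph() is g   # inside the context, g is the default | |||
| #     set_default_graph(Graph(eager_evaluation=False))  # switch to a static graph | |||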
| @@ -1,128 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import pickle | |||
| import megengine._internal as mgb | |||
| from ..utils.max_recursion_limit import max_recursion_limit | |||
| from .device import get_default_device | |||
| def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL): | |||
| r"""Save an object to disk file. | |||
| :type obj: object | |||
| :param obj: object to save. Only ``module`` or ``state_dict`` are allowed. | |||
| :type f: str or file object | |||
| :param f: a file name or a writable binary file object to which ``obj`` is saved. | |||
| :type pickle_module: | |||
| :param pickle_module: Default: ``pickle``. | |||
| :type pickle_protocol: | |||
| :param pickle_protocol: Default: ``pickle.HIGHEST_PROTOCOL``. | |||
| """ | |||
| if isinstance(f, str): | |||
| with open(f, "wb") as fout: | |||
| save( | |||
| obj, fout, pickle_module=pickle_module, pickle_protocol=pickle_protocol | |||
| ) | |||
| return | |||
| with max_recursion_limit(): | |||
| assert hasattr(f, "write"), "{} does not support write".format(f) | |||
| pickle_module.dump(obj, f, pickle_protocol) | |||
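| # Usage sketch (illustrative): `net` is a hypothetical module; both a path | |||
| # and an open binary file are accepted, mirroring the branches above. | |||
| # | |||
| #     save(net.state_dict(), "checkpoint.mge") | |||
| #     with open("checkpoint.mge", "wb") as fout: | |||
| #         save(net.state_dict(), fout) | |||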
| class dmap: | |||
| def __init__(self, map_location): | |||
| self.map_location = map_location | |||
| def __enter__(self): | |||
| mgb.add_device_map(self.map_location) | |||
| return self | |||
| def __exit__(self, type, value, traceback): | |||
| mgb.del_device_map() | |||
| def _get_callable_map_location(map_location): | |||
| if map_location is None: | |||
| def callable_map_location(state): | |||
| return str(get_default_device()) | |||
| elif isinstance(map_location, str): | |||
| def callable_map_location(state): | |||
| return map_location | |||
| elif isinstance(map_location, dict): | |||
| locator_map = {} | |||
| for key, value in map_location.items(): | |||
| locator_key = mgb.config.parse_locator(key)[:2] | |||
| locator_map[locator_key] = value | |||
| def callable_map_location(state): | |||
| orig = mgb.config.parse_locator(state)[:2] | |||
| if orig in locator_map.keys(): | |||
| state = locator_map[orig] | |||
| return state | |||
| else: | |||
| assert callable(map_location), "map_location should be str, dict or function" | |||
| callable_map_location = map_location | |||
| return callable_map_location | |||
| def load(f, map_location=None, pickle_module=pickle): | |||
| r"""Load an object saved with save() from a file. | |||
| :type f: str or file object | |||
| :param f: a file name or a readable binary file object from which to load. | |||
| :type map_location: str, dict or a function specifying the map rules | |||
| :param map_location: Default: ``None``. | |||
| .. note:: | |||
| map_location changes the logical locator when loading models, | |||
| avoiding tensors being loaded onto a non-existent device. If you want to | |||
| add a mapping between logical and physical locators at runtime, | |||
| please call :func:`mge.set_device_map()` | |||
| :type pickle_module: | |||
| :param pickle_module: Default: ``pickle``. | |||
| .. note:: | |||
| If you are going to call :func:`mge.set_default_device()`, please do it | |||
| before :func:`mge.load()`. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| mge.load('model.mge') | |||
| # Load all tensors based on logical location. | |||
| mge.load('model.mge', map_location='gpu0') | |||
| # Load all tensors onto the device: GPU0 | |||
| mge.load('model.mge', map_location={'gpu0':'cpu0'}) | |||
| # Load all tensors based on logical location, but 'GPU0' will be renamed to 'CPU0' | |||
| mge.load('model.mge', map_location=lambda dev: 'cpu0') | |||
| # Load all tensors onto the device" CPU0 | |||
| """ | |||
| if isinstance(f, str): | |||
| with open(f, "rb") as fin: | |||
| return load(fin, map_location=map_location, pickle_module=pickle_module) | |||
| map_location = _get_callable_map_location(map_location) # callable map_location | |||
| with dmap(map_location): | |||
| return pickle_module.load(f) | |||
| @@ -1,771 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import copy | |||
| import functools | |||
| import itertools | |||
| import weakref | |||
| from typing import Callable, Tuple, Union | |||
| import numpy as np | |||
| import megengine._internal as mgb | |||
| from .graph import _use_default_if_none, get_default_graph | |||
| def wrap_io_tensor(func): | |||
| r"""A wrapper to make ``func`` compatible with functions in ``_internal.opr``. | |||
| """ | |||
| @functools.wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| comp_graph = None | |||
| for i in itertools.chain(args, kwargs.values()): | |||
| if isinstance(i, Tensor) and i._comp_graph: | |||
| comp_graph = i._comp_graph | |||
| break | |||
| else: | |||
| comp_graph = get_default_graph() | |||
| new_args = ( | |||
| arg._attach(comp_graph) if isinstance(arg, Tensor) else arg for arg in args | |||
| ) | |||
| new_kwargs = { | |||
| k: v._attach(comp_graph) if isinstance(v, Tensor) else v | |||
| for k, v in kwargs.items() | |||
| } | |||
| ret = func(*new_args, **new_kwargs) | |||
| if isinstance(ret, mgb.SymbolVar): | |||
| ret = Tensor(ret) | |||
| elif isinstance(ret, list): | |||
| ret = [Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret] | |||
| elif isinstance(ret, tuple): | |||
| ret = tuple(Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret) | |||
| return ret | |||
| return wrapper | |||
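| # Illustrative sketch (not part of the original file): wrap_io_tensor lifts a | |||
| # SymbolVar-level function to a Tensor-level one. `mgb.opr.relu` is assumed | |||
| # here purely for illustration; any `_internal.opr` function works the same way. | |||
| # | |||
| #     relu = wrap_io_tensor(mgb.opr.relu) | |||
| #     y = relu(x)   # x: Tensor -> attached to a graph -> result wrapped as Tensor | |||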
| def _wrap_symbolvar_binary_op(f): | |||
| @functools.wraps(f) | |||
| def wrapped(self, other): | |||
| comp_graph = ( | |||
| isinstance(other, Tensor) | |||
| and other._comp_graph | |||
| or self._comp_graph | |||
| or get_default_graph() | |||
| ) | |||
| if isinstance(other, Tensor): | |||
| other = other._attach(comp_graph) | |||
| return Tensor(f(self._attach(comp_graph), other)) | |||
| return wrapped | |||
| def _wrap_slice(inp: slice): | |||
| r""" | |||
| A wrapper to handle Tensor values in ``inp`` slice. | |||
| """ | |||
| start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start | |||
| stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop | |||
| step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step | |||
| return slice(start, stop, step) | |||
| def _wrap_idx(idx: Tuple[Union[int, "Tensor"], ...]): | |||
| r""" | |||
| A wrapper to handle Tensor values in ``idx``. | |||
| """ | |||
| if not isinstance(idx, tuple): | |||
| idx = (idx,) | |||
| idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx) | |||
| idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx) | |||
| return idx | |||
| class _MGBIndexWrapper: | |||
| r""" | |||
| A wrapper class to handle ``__getitem__`` for index containing Tensor values. | |||
| :param dest: a destination Tensor to do indexing on. | |||
| :param mgb_index: an ``_internal`` helper function indicating how to index. | |||
| :param val: an optional Tensor parameter used for ``mgb_index``. | |||
| """ | |||
| def __init__(self, dest: "Tensor", mgb_index: Callable, val=None): | |||
| self.dest = dest | |||
| self.val = val | |||
| self.mgb_index = mgb_index | |||
| def __getitem__(self, idx): | |||
| if self.val is None: | |||
| return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)( | |||
| _wrap_idx(idx) | |||
| ) | |||
| else: | |||
| return wrap_io_tensor( | |||
| self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__ | |||
| )(_wrap_idx(idx)) | |||
| class _Guard: | |||
| r""" | |||
| A wrapper class with custom ``__del__`` method calling ``deleter``. | |||
| :param deleter: a function to be called in ``__del__``. | |||
| """ | |||
| def __init__(self, deleter: Callable): | |||
| self.deleter = deleter | |||
| def __del__(self): | |||
| self.deleter() | |||
| class Tensor: | |||
| r"""The main data container in MegEngine. | |||
| Use :func:`~.tensor` to create a Tensor with existing data. | |||
| """ | |||
| requires_grad = False | |||
| grad = None | |||
| def __init__(self, val=None, *, requires_grad=None): | |||
| self._reset(val, requires_grad=requires_grad) | |||
| self.q_dict = {"mode": None, "scale": None, "zero_point": None} | |||
| def _reset(self, val=None, *, requires_grad=None): | |||
| self.__sym_override = None | |||
| if val is None: | |||
| self.__val = None | |||
| self.__sym = None | |||
| elif isinstance(val, mgb.SharedND): | |||
| self.__val = val | |||
| self.__sym = None | |||
| elif isinstance(val, mgb.SymbolVar): | |||
| self.__val = None | |||
| self.__sym = val | |||
| else: | |||
| raise TypeError("must be initialized with SymbolVar or SharedND") | |||
| self.requires_grad = requires_grad | |||
| def _as_tensor(self, obj): | |||
| r"""Convert the data into a ``Tensor``. If the data is already a Tensor | |||
| with the same dtype and device, no copy will be performed. Otherwise a | |||
| new Tensor will be returned with computational graph retained. | |||
| """ | |||
| if isinstance(obj, Tensor): | |||
| return obj | |||
| if isinstance(obj, mgb.SymbolVar): | |||
| return Tensor(obj) | |||
| if isinstance(obj, mgb.SharedScalar): | |||
| return Tensor(obj._as_sym_var(self._comp_graph, self._comp_node)) | |||
| return tensor(data=obj, device=self.device) | |||
| def numpy(self): | |||
| r"""Return the tensor value in numpy.ndarray format. | |||
| """ | |||
| if self.__val is not None: | |||
| assert self.__sym is None | |||
| return self.__val.get_value() | |||
| if self.__sym is None: | |||
| raise ValueError("uninitialized") | |||
| if self.__sym.eager_val is not None: | |||
| return self.__sym.eager_val.get_value() | |||
| return self.__sym.inferred_value | |||
| def item(self): | |||
| r"""If tensor only has only one value, return it.""" | |||
| return self.numpy().item() | |||
| def _attach(self, comp_graph, *, volatile=True): | |||
| sym = self.__sym_override or self.__sym | |||
| if sym: | |||
| if sym.owner_graph != comp_graph: | |||
| raise RuntimeError("internal error") | |||
| return sym | |||
| if self.__val: | |||
| return self.__val.symvar(comp_graph, volatile=volatile) | |||
| else: | |||
| raise ValueError("uninitialized") | |||
| @property | |||
| def _symvar(self): | |||
| if self.__sym_override: | |||
| return self.__sym_override | |||
| if self.__sym: | |||
| assert not self.__val | |||
| return self.__sym | |||
| if not self.__val: | |||
| raise ValueError("uninitialized") | |||
| return self._attach(get_default_graph()) | |||
| def __mgb_symvar__(self, comp_graph=None, **_): | |||
| if self.__sym_override: | |||
| return self.__sym_override | |||
| if self.__val and comp_graph: | |||
| return self._attach(comp_graph) | |||
| return self._symvar # read by mgb.opr | |||
| def _override_symvar_during_trace(self, trace, symvar): | |||
| assert self.__val and not self.__sym | |||
| assert trace is type(trace)._active_instance | |||
| deleters = trace._user_cache.setdefault(Tensor, set()) | |||
| self_ref = weakref.ref(self) | |||
| def restore(): | |||
| self = self_ref() | |||
| if self is not None: | |||
| self.__sym_override = None | |||
| deleters.add(_Guard(restore)) | |||
| self.__sym_override = symvar | |||
| @property | |||
| def dtype(self): | |||
| r"""Return the data type of the tensor. | |||
| """ | |||
| if self.__val is not None: | |||
| return self.__val.dtype | |||
| return self._symvar.dtype | |||
| @dtype.setter | |||
| def dtype(self, dtype: str = None): | |||
| r"""Set the data type of the tensor. | |||
| """ | |||
| if self.__val is not None: | |||
| self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy()) | |||
| elif self.__sym_override is not None: | |||
| self.__sym_override = self.__sym_override.astype(dtype) | |||
| elif self.__sym is not None: | |||
| self.__sym = self.__sym.astype(dtype) | |||
| @property | |||
| def name(self): | |||
| r"""Get the tensor name, does not support Parameter and Buffer. | |||
| """ | |||
| return self._symvar.name | |||
| @name.setter | |||
| def name(self, name: str = None): | |||
| r"""Set the tensor name, does not support Parameter and Buffer. | |||
| """ | |||
| if self.__val is not None: | |||
| raise ValueError("name setting is not available for Parameter or Buffer.") | |||
| if self.__sym_override is not None: | |||
| self.__sym_override = self.__sym_override.rename(name) | |||
| if self.__sym is not None: | |||
| assert not self.__val | |||
| self.__sym = self.__sym.rename(name) | |||
| @property | |||
| def _comp_node(self): | |||
| if self.__val is not None: | |||
| return self.__val.comp_node | |||
| return self._symvar.comp_node | |||
| device = _comp_node | |||
| @property | |||
| def _comp_graph(self): | |||
| if self.__sym is not None: | |||
| return self.__sym.owner_graph | |||
| return None | |||
| @property | |||
| def shape(self): | |||
| r"""Return an int tuple that is the shape/layout of the tensor. | |||
| Could be invalid in static graph mode. | |||
| """ | |||
| from ..jit import trace | |||
| if trace._active_instance: # pylint: disable=protected-access | |||
| # NOTE: this is a hack | |||
| shape = mgb.opr.get_var_shape(self._symvar) | |||
| return tuple(Tensor(shape[i]) for i in range(self.ndim)) | |||
| return self._symvar.imm_shape | |||
| def set_value(self, value, *, sync=True, inplace=False, share=False): | |||
| r"""Set value to the tensor. | |||
| """ | |||
| if not self.__val: | |||
| raise ValueError("not detached") | |||
| if isinstance(value, Tensor): | |||
| value = value.__val or value.__sym.eager_val | |||
| self.__val.set_value(value, sync=sync, inplace=inplace, share=share) | |||
| def fill(self, value): | |||
| r"""Fills the tensor with the specified value. | |||
| """ | |||
| self.set_value(np.full(self.shape, value, dtype=self.dtype)) | |||
| def reset_zero(self): | |||
| r"""Reset the tensor and fills with zeros. | |||
| """ | |||
| if not self.__val: | |||
| raise ValueError("not detached") | |||
| self.__val.reset_zero() | |||
| def to(self, device): | |||
| r"""Performs Tensor device conversion, returns Tensor with the specified device. | |||
| """ | |||
| return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device) | |||
| # https://docs.python.org/3/reference/datamodel.html#object.__hash__ | |||
| # > If a class does not define an __eq__() method it should not define a | |||
| # > __hash__() operation either | |||
| __hash__ = None # type: ignore[assignment] | |||
| def __eq__(self, rhs): | |||
| rhs = self._as_tensor(rhs) | |||
| return Tensor(self._symvar._binary_opr("EQ", rhs._symvar)) | |||
| def __ne__(self, rhs): | |||
| return 1 - self.__eq__(rhs) | |||
| def __len__(self): | |||
| if self._symvar.eager_val is not None: | |||
| return self._symvar.eager_val.shape[0] | |||
| raise TypeError( | |||
| "__len__ and __iter__ is not available for tensors on non eager graph." | |||
| ) | |||
| __add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__) | |||
| __radd__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__radd__) | |||
| __sub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__sub__) | |||
| __rsub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rsub__) | |||
| __mul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mul__) | |||
| __rmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmul__) | |||
| __matmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__matmul__) | |||
| __rmatmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmatmul__) | |||
| __lshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lshift__) | |||
| __rshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rshift__) | |||
| __truediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__truediv__) | |||
| __rtruediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rtruediv__) | |||
| __floordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__floordiv__) | |||
| __rfloordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rfloordiv__) | |||
| __mod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mod__) | |||
| __rmod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmod__) | |||
| __pow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__pow__) | |||
| __rpow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rpow__) | |||
| __lt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lt__) | |||
| __gt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__gt__) | |||
| __le__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__le__) | |||
| __ge__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__ge__) | |||
| __neg__ = wrap_io_tensor(mgb.SymbolVar.__neg__) | |||
| sum = wrap_io_tensor(mgb.SymbolVar.sum) | |||
| """ | |||
| Return the sum of all elements in the tensor. | |||
| """ | |||
| max = wrap_io_tensor(mgb.SymbolVar.max) | |||
| """ | |||
| Return the maximum value of the given tensor. | |||
| """ | |||
| min = wrap_io_tensor(mgb.SymbolVar.min) | |||
| """ | |||
| Return the minimum value of the given tensor. | |||
| """ | |||
| prod = wrap_io_tensor(mgb.SymbolVar.prod) | |||
| """ | |||
| Return the product of all elements in the tensor. | |||
| """ | |||
| mean = wrap_io_tensor(mgb.SymbolVar.mean) | |||
| """ | |||
| Return the mean value of the given tensor. | |||
| """ | |||
| dimshuffle = wrap_io_tensor(mgb.SymbolVar.dimshuffle) | |||
| """ | |||
| See more details in :func:`~.functional.tensor.dimshuffle`. | |||
| """ | |||
| astype = wrap_io_tensor(mgb.SymbolVar.astype) | |||
| """ | |||
| Cast the tensor to a specified type. | |||
| """ | |||
| def reshape(self, *target_shape): | |||
| r"""Return a tensor which has given target shape | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| inp = tensor(np.arange(1, 17, dtype=np.int32).reshape(4,4)) | |||
| out = tensor(np.arange(100, 116, dtype=np.int32).reshape(1,16)) | |||
| out = out.reshape(inp.shape) | |||
| print(out.numpy()) | |||
| .. testoutput:: | |||
| [[100 101 102 103] | |||
| [104 105 106 107] | |||
| [108 109 110 111] | |||
| [112 113 114 115]] | |||
| """ | |||
| if isinstance(target_shape[0], tuple): | |||
| if len(target_shape) > 1: | |||
| raise ValueError("Only single tuple is accepted in reshape") | |||
| target_shape = target_shape[0] | |||
| target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape) | |||
| return Tensor(mgb.SymbolVar.reshape(self._symvar, *target_shape)) | |||
| def broadcast(self, *target_shape): | |||
| r"""Return a tesnor broadcasted by current tensor to given target shape | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| data = tensor(np.arange(100, 104, dtype=np.int32).reshape(1,4)) | |||
| data = data.broadcast((4,4)) | |||
| print(data.numpy()) | |||
| .. testoutput:: | |||
| [[100 101 102 103] | |||
| [100 101 102 103] | |||
| [100 101 102 103] | |||
| [100 101 102 103]] | |||
| """ | |||
| if isinstance(target_shape[0], tuple): | |||
| if len(target_shape) > 1: | |||
| raise ValueError("Only single tuple is accepted in broadcast") | |||
| target_shape = target_shape[0] | |||
| target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape) | |||
| return Tensor(mgb.SymbolVar.broadcast(self._symvar, *target_shape)) | |||
| # Prefer operators on Tensor instead of convert to numpy | |||
| __array_priority__ = 1000 | |||
| # mgb indexing family | |||
| def __getitem__(self, idx): | |||
| return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx)) | |||
| def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: | |||
| r""" | |||
| Return an object that supports using ``__getitem__`` to set a subtensor. | |||
| ``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``. | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) | |||
| def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: | |||
| r""" | |||
| Return an object that supports using ``__getitem__`` to increment a subtensor. | |||
| ``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``. | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) | |||
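| # Illustrative sketch (not part of the original file) of the two helpers above: | |||
| # | |||
| #     a = tensor(np.zeros((4,), dtype="float32")) | |||
| #     b = tensor(np.ones((2,), dtype="float32")) | |||
| #     c = a.set_subtensor(b)[1:3]    # like c = a.copy(); c[1:3] = b | |||
| #     d = a.incr_subtensor(b)[1:3]   # like d = a.copy(); d[1:3] += b | |||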
| @property | |||
| def ai(self) -> _MGBIndexWrapper: | |||
| r""" | |||
| Return an object that supports advanced indexing to get a subtensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4))) | |||
| print(a.ai[:, [2, 3]]) | |||
| Outputs: | |||
| .. testoutput:: | |||
| Tensor([[ 2. 3.] | |||
| [ 6. 7.] | |||
| [10. 11.] | |||
| [14. 15.]]) | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.advanced_indexing) | |||
| def set_ai(self, val: "Tensor") -> _MGBIndexWrapper: | |||
| r""" | |||
| Equivalent to :meth:`~.Tensor.set_subtensor`, but with advanced indexing. | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val) | |||
| def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper: | |||
| r""" | |||
| Equivalent to :meth:`~.Tensor.incr_subtensor`, but with advanced indexing. | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val) | |||
| @property | |||
| def mi(self) -> _MGBIndexWrapper: | |||
| r""" | |||
| Return an object that supports getting a subtensor at the coordinates | |||
| given by the Cartesian product of the given indices. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4))) | |||
| print(a.mi[[1, 2], [2, 3]]) | |||
| # is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]] | |||
| # a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11 | |||
| Outputs: | |||
| .. testoutput:: | |||
| Tensor([[ 6. 7.] | |||
| [10. 11.]]) | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.mesh_indexing) | |||
| def set_mi(self, val: "Tensor") -> _MGBIndexWrapper: | |||
| r""" | |||
| Equivalent to :meth:`~.Tensor.set_subtensor`, but using mesh indexing. | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val) | |||
| def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper: | |||
| r""" | |||
| Equivalent to :meth:`~.Tensor.incr_subtensor`, but using mesh indexing. | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val) | |||
| @property | |||
| def batched_mi(self) -> _MGBIndexWrapper: | |||
| r""" | |||
| Return an object that supports getting a subtensor by | |||
| batched mesh indexing. | |||
| For a Tensor ``a`` and an index ``idx``, each value of ``idx`` needs to be a 2-dim matrix or a slice. | |||
| The Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` forms a subtensor of ``a[i]``. | |||
| Each matrix ``idx[k]`` should have the number of rows given by the ``batched_dim`` that ``idx[0]`` indicates. | |||
| A slice value applies the same slice to every ``batched_dim``. For more details see the example below. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4))) | |||
| print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]]) | |||
| # is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1]) | |||
| # and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1]) | |||
| # a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77 | |||
| print(a.batched_mi[:2, [[0],[1]], :2, :1]) | |||
| # is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]`` | |||
| Outputs: | |||
| .. testoutput:: | |||
| Tensor([[[[ 0.] | |||
| [ 4.]]] | |||
| [[[73.] | |||
| [77.]]]]) | |||
| Tensor([[[[ 0.] | |||
| [ 4.]]] | |||
| [[[64.] | |||
| [68.]]]]) | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing) | |||
| def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper: | |||
| r""" | |||
| Equivalent to :meth:`~.Tensor.set_subtensor`, but using batched mesh indexing. | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val) | |||
| def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper: | |||
| r""" | |||
| Equivalent to :meth:`~.Tensor.incr_subtensor`, but using batched mesh indexing. | |||
| """ | |||
| return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val) | |||
| def __array__(self, dtype=None): | |||
| if dtype is None: | |||
| return self.numpy() | |||
| else: | |||
| return self.numpy().astype(dtype, copy=False) | |||
| def __int__(self): | |||
| return int(self.item()) | |||
| def __index__(self): | |||
| return int(self.item()) | |||
| def __round__(self, ndigits=0): | |||
| if ndigits != 0: | |||
| raise ValueError("ndigits must be 0 for Tensor.round") | |||
| return Tensor(mgb.opr.elemwise([self._symvar], mode="ROUND")) | |||
| round = __round__ | |||
| def sqrt(self): | |||
| r"""Return a tensor that each element is the square root of its | |||
| original value. | |||
| """ | |||
| return Tensor(mgb.opr.sqrt(self._symvar)) | |||
| def shapeof(self, axis=None): | |||
| r"""Return a Tensor that represent the shape of the tensor. | |||
| """ | |||
| return Tensor(mgb.opr.get_var_shape(self._symvar, axis=axis)) | |||
| @property | |||
| def ndim(self): | |||
| r"""Return the number of dimensions of the tensor. | |||
| """ | |||
| return len(self._symvar.imm_shape) | |||
| def __repr__(self): | |||
| piece = "Tensor(" | |||
| with np.printoptions(precision=4, suppress=True): | |||
| piece += "{}".format(str(self.numpy())) | |||
| if self.dtype != np.float32: | |||
| piece += ", dtype={}".format(np.dtype(self.dtype).name) | |||
| if self._comp_node.locator_logical != ("XPU", -1, 0): | |||
| piece += ", device={}".format(self.device) | |||
| piece += ")" | |||
| return piece | |||
| def __bool__(self): | |||
| raise RuntimeError( | |||
| "Tensor object should not be converted to bool or used in a if statement. Use .numpy(), int() or float() if you want to use its value in if statement, be aware that this may lead to incorrect result in non-eager mode." | |||
| ) | |||
| def __getstate__(self): | |||
| r""" __getstate__ will be called for pickle serialization or deep copy | |||
| """ | |||
| assert (self.__val is not None) and ( | |||
| self.__sym is None | |||
| ), "Only SharedND initialized Tensor can be serialized or deep copied" | |||
| metadata = {"requires_grad": self.requires_grad} | |||
| state = { | |||
| "data": self.numpy(), | |||
| "device": self.device, | |||
| "dtype": self.dtype, | |||
| "metadata": metadata, | |||
| } | |||
| return state | |||
| def __setstate__(self, state): | |||
| data = state.pop("data") | |||
| device = state.pop("device") | |||
| dtype = state.pop("dtype") | |||
| metadata = state.pop("metadata", {}) | |||
| requires_grad = metadata.pop("requires_grad", None) | |||
| snd = mgb.make_shared(device, value=data, dtype=dtype) | |||
| self._reset(snd, requires_grad=requires_grad) | |||
| def __deepcopy__(self, memo): | |||
| """ | |||
| The default deepcopy ignores attributes other than those handled by the | |||
| __getstate__ and __setstate__ methods, so we add a __deepcopy__ method | |||
| to deepcopy the correct attributes. | |||
| """ | |||
| assert (self.__val is not None) and ( | |||
| self.__sym is None | |||
| ), "Only SharedND initialized Tensor can be serialized or deep copied" | |||
| cls = self.__class__ | |||
| result = cls.__new__(cls) | |||
| memo[id(self)] = result | |||
| for k, v in self.__dict__.items(): | |||
| setattr(result, k, copy.deepcopy(v, memo)) | |||
| return result | |||
| def tensor( | |||
| data: Union[list, np.ndarray] = None, | |||
| *, | |||
| dtype: str = None, | |||
| device: mgb.CompNode = None, | |||
| requires_grad: bool = None | |||
| ): | |||
| r"""A helper function to create a :class:`~.Tensor` using existing data. | |||
| :param data: an existing data array; must be a Python list, a NumPy array or None. | |||
| :param dtype: target Tensor data type, one of ``("uint8", "int8", "int16", "int32", "float32", "float16")``. | |||
| :param device: target device for Tensor storing. | |||
| :param requires_grad: whether its gradient will be calculated during :meth:`~.Optimizer.backward` | |||
| """ | |||
| supported_dtypes = ("uint8", "int8", "int16", "int32", "float32", "float16") | |||
| if isinstance(data, Tensor): | |||
| raise NotImplementedError | |||
| if dtype is not None and np.dtype(dtype).name not in supported_dtypes: | |||
| raise TypeError("unsupported dtype {}".format(dtype)) | |||
| if data is not None: | |||
| if not isinstance(data, np.ndarray): | |||
| data = np.array(data, dtype=dtype) | |||
| # In order to accept tensor([1]), automatically convert to a 32-bit number | |||
| # instead of numpy's default 64-bit when the input data is not an ndarray. | |||
| dtype = mgb.to_mgb_supported_dtype(data.dtype) | |||
| if dtype is None: | |||
| if data.dtype.name not in supported_dtypes: | |||
| raise TypeError("unsupported dtype {}".format(data.dtype)) | |||
| device, _ = _use_default_if_none(device, None) | |||
| shared_nd = mgb.make_shared(device, value=data, dtype=dtype) | |||
| return Tensor(shared_nd, requires_grad=requires_grad) | |||
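| # Usage sketch (illustrative): plain Python lists are converted to 32-bit | |||
| # dtypes, as the comment in the function body above explains. | |||
| # | |||
| #     x = tensor([1, 2, 3])                                   # int32 | |||
| #     w = tensor(np.ones((2, 2)), dtype="float32", requires_grad=True) | |||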
| class TensorDict(collections.abc.MutableMapping): | |||
| r""" | |||
| A helper class to maintain dict with Tensor key. | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| self.data = {} | |||
| for i in args: | |||
| self.update(i) | |||
| self.update(**kwargs) | |||
| class keyfn: | |||
| def __new__(cls, x: Tensor): | |||
| if not isinstance(x, Tensor): | |||
| return x | |||
| return super().__new__(cls) | |||
| def __init__(self, x: Tensor): | |||
| self._data = x # do not save id directly to make pickle work | |||
| def __hash__(self): | |||
| return id(self._data) | |||
| def __eq__(self, other): | |||
| return isinstance(other, type(self)) and id(self._data) == id(other._data) | |||
| def __getitem__(self, key): | |||
| _, v = self.data[self.keyfn(key)] | |||
| return v | |||
| def __setitem__(self, key, value): | |||
| self.data[self.keyfn(key)] = key, value | |||
| def __delitem__(self, key): | |||
| del self.data[self.keyfn(key)] | |||
| def __iter__(self): | |||
| for _, (k, _) in self.data.items(): | |||
| yield k | |||
| def __len__(self): | |||
| return len(self.data) | |||
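| # Illustrative sketch (not part of the original file): keys are hashed by | |||
| # object identity via `keyfn`, so two equal-valued Tensors stay distinct. | |||
| # | |||
| #     d = TensorDict() | |||
| #     a, b = tensor([1]), tensor([1]) | |||
| #     d[a], d[b] = "grad_a", "grad_b" | |||
| #     assert len(d) == 2 | |||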
| @@ -1,109 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Optional, Union | |||
| import megengine._internal as mgb | |||
| from .graph import _use_default_if_none | |||
| from .tensor import Tensor | |||
| __all__ = ["zeros", "ones"] | |||
| def scalar( | |||
| value, | |||
| dtype: type = None, | |||
| device: Optional[mgb.CompNode] = None, | |||
| comp_graph: Optional[mgb.CompGraph] = None, | |||
| ) -> Tensor: | |||
| """ | |||
| Convert ``value`` to a :class:`~.Tensor`. | |||
| """ | |||
| device, comp_graph = _use_default_if_none(device, comp_graph) | |||
| return Tensor(mgb.make_immutable(device, comp_graph, value, dtype=dtype, name=None)) | |||
| def zeros( | |||
| shape: Union[int, Iterable[int], Tensor], | |||
| dtype: type = None, | |||
| device: Optional[mgb.CompNode] = None, | |||
| comp_graph: Optional[mgb.CompGraph] = None, | |||
| ) -> Tensor: | |||
| """ | |||
| Create a tensor filled with 0. | |||
| :param shape: tensor shape | |||
| :param dtype: data type, Default: "int32" | |||
| :param device: compute node on which the tensor resides, Default: None | |||
| :param comp_graph: compute graph that owns the tensor, Default: None | |||
| :return: tensor of zeros | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| t = mge.zeros((2, 2), dtype="int32") | |||
| print(t.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[0 0] | |||
| [0 0]] | |||
| """ | |||
| device, comp_graph = _use_default_if_none(device, comp_graph) | |||
| if isinstance(shape, (int, Tensor)): | |||
| shape = (shape,) | |||
| tensor = scalar(0, dtype=dtype, device=device, comp_graph=comp_graph) | |||
| tensor = tensor.broadcast(*shape) | |||
| return tensor | |||
| def ones( | |||
| shape: Union[int, Iterable[int], Tensor], | |||
| dtype: type = None, | |||
| device: Optional[mgb.CompNode] = None, | |||
| comp_graph: Optional[mgb.CompGraph] = None, | |||
| ) -> Tensor: | |||
| """ | |||
| Create a tensor filled with 1. | |||
| :param shape: tensor shape | |||
| :param dtype: data type, Default: "int32" | |||
| :param device: compute node on which the tensor resides, Default: None | |||
| :param comp_graph: compute graph that owns the tensor, Default: None | |||
| :return: tensor of ones | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| t = mge.ones((2, 2), dtype="float32") | |||
| print(t.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[1. 1.] | |||
| [1. 1.]] | |||
| """ | |||
| device, comp_graph = _use_default_if_none(device, comp_graph) | |||
| if isinstance(shape, (int, Tensor)): | |||
| shape = (shape,) | |||
| tensor = scalar(1, dtype=dtype, device=device, comp_graph=comp_graph) | |||
| tensor = tensor.broadcast(*shape) | |||
| return tensor | |||
| @@ -1,45 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .tensor import Tensor, tensor | |||
| class Buffer(Tensor): | |||
| r"""A kind of Tensor with ``requires_grad=False``. | |||
| """ | |||
| def __init__(self, value, *, dtype=None, device=None, requires_grad=False): | |||
| # pylint: disable=super-init-not-called | |||
| t = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad) | |||
| self.__dict__.update(t.__dict__) | |||
| class Parameter(Tensor): | |||
| r"""A kind of Tensor that is to be considered a module parameter. | |||
| """ | |||
| def __init__(self, value, *, dtype=None, device=None, requires_grad=True): | |||
| # pylint: disable=super-init-not-called | |||
| if isinstance(value, Tensor): | |||
| t = value | |||
| else: | |||
| t = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad) | |||
| self.__dict__.update(t.__dict__) | |||
| # broadcast and allreduce will not be performed in optimizer if replica_mode is False | |||
| self.replica_mode = True | |||
| @property | |||
| def shape(self): | |||
| r"""Return shape of parameter. | |||
| """ | |||
| if self._Tensor__val is not None: | |||
| return self._Tensor__val.shape | |||
| elif self._Tensor__sym is not None: | |||
| return self._Tensor__sym.imm_shape | |||
| return None | |||
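| # Usage sketch (illustrative): both classes wrap the tensor() factory above. | |||
| # | |||
| #     w = Parameter(np.random.randn(4, 4).astype("float32"))   # requires_grad=True | |||
| #     running_mean = Buffer(np.zeros((4,), dtype="float32"))   # requires_grad=False | |||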
| @@ -1,17 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .collator import Collator | |||
| from .dataloader import DataLoader | |||
| from .sampler import ( | |||
| Infinite, | |||
| RandomSampler, | |||
| ReplacementSampler, | |||
| Sampler, | |||
| SequentialSampler, | |||
| ) | |||
| @@ -1,144 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import binascii | |||
| import os | |||
| import queue | |||
| import subprocess | |||
| from multiprocessing import Queue | |||
| import pyarrow | |||
| import pyarrow.plasma as plasma | |||
| MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB | |||
| # Each process only needs to start one plasma store, so we set it as a global variable. | |||
| # TODO: how to share between different processes? | |||
| MGE_PLASMA_STORE_MANAGER = None | |||
| def _clear_plasma_store(): | |||
| # `_PlasmaStoreManager.__del__` will not be called automatically in a subprocess, | |||
| # so this function should be called explicitly | |||
| global MGE_PLASMA_STORE_MANAGER | |||
| if MGE_PLASMA_STORE_MANAGER is not None and MGE_PLASMA_STORE_MANAGER.refcount == 0: | |||
| del MGE_PLASMA_STORE_MANAGER | |||
| MGE_PLASMA_STORE_MANAGER = None | |||
| class _PlasmaStoreManager: | |||
| __initialized = False | |||
| def __init__(self): | |||
| self.socket_name = "/tmp/mge_plasma_{}".format( | |||
| binascii.hexlify(os.urandom(8)).decode() | |||
| ) | |||
| debug_flag = bool(os.environ.get("MGE_DATALOADER_PLASMA_DEBUG", 0)) | |||
| # NOTE: this is a hack. Directly using `plasma_store` may make it difficult | |||
| # for a subprocess to handle exceptions raised in `plasma-store-server`, | |||
| # because `plasma_store` is just a wrapper of `plasma-store-server`, which | |||
| # uses `os.execv` to call the executable `plasma-store-server`. | |||
| cmd_path = os.path.join(pyarrow.__path__[0], "plasma-store-server") | |||
| self.plasma_store = subprocess.Popen( | |||
| [cmd_path, "-s", self.socket_name, "-m", str(MGE_PLASMA_MEMORY),], | |||
| stdout=None if debug_flag else subprocess.DEVNULL, | |||
| stderr=None if debug_flag else subprocess.DEVNULL, | |||
| ) | |||
| self.__initialized = True | |||
| self.refcount = 1 | |||
| def __del__(self): | |||
| if self.__initialized and self.plasma_store.returncode is None: | |||
| self.plasma_store.kill() | |||
| class PlasmaShmQueue: | |||
| def __init__(self, maxsize: int = 0): | |||
| r"""Use pyarrow in-memory plasma store to implement shared memory queue. | |||
| Compared to the native `multiprocessing.Queue`, `PlasmaShmQueue` avoids | |||
| pickle/unpickle and communication overhead, leading to better performance | |||
| in multi-process applications. | |||
| :type maxsize: int | |||
| :param maxsize: maximum size of the queue; ``0`` means no limit. (default: ``0``) | |||
| """ | |||
| # Lazy start the plasma store manager | |||
| global MGE_PLASMA_STORE_MANAGER | |||
| if MGE_PLASMA_STORE_MANAGER is None: | |||
| try: | |||
| MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager() | |||
| except Exception as e: | |||
| err_info = ( | |||
| "Please make sure pyarrow installed correctly!\n" | |||
| "You can try reinstall pyarrow and see if you can run " | |||
| "`plasma_store -s /tmp/mge_plasma_xxx -m 1000` normally." | |||
| ) | |||
| raise RuntimeError( | |||
| "Exception happened in starting plasma_store: {}\n" | |||
| "Tips: {}".format(str(e), err_info) | |||
| ) | |||
| else: | |||
| MGE_PLASMA_STORE_MANAGER.refcount += 1 | |||
| self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name | |||
| # TODO: how to catch the exception happened in `plasma.connect`? | |||
| self.client = None | |||
| # Used to store the headers (ObjectIDs) for the data. | |||
| self.queue = Queue(maxsize) # type: Queue | |||
| def put(self, data, block=True, timeout=None): | |||
| if self.client is None: | |||
| self.client = plasma.connect(self.socket_name) | |||
| try: | |||
| object_id = self.client.put(data) | |||
| except plasma.PlasmaStoreFull: | |||
| raise RuntimeError("plasma store out of memory!") | |||
| try: | |||
| self.queue.put(object_id, block, timeout) | |||
| except queue.Full: | |||
| self.client.delete([object_id]) | |||
| raise queue.Full | |||
| def get(self, block=True, timeout=None): | |||
| if self.client is None: | |||
| self.client = plasma.connect(self.socket_name) | |||
| object_id = self.queue.get(block, timeout) | |||
| if not self.client.contains(object_id): | |||
| raise RuntimeError( | |||
| "ObjectID: {} not found in plasma store".format(object_id) | |||
| ) | |||
| data = self.client.get(object_id) | |||
| self.client.delete([object_id]) | |||
| return data | |||
| def qsize(self): | |||
| return self.queue.qsize() | |||
| def empty(self): | |||
| return self.queue.empty() | |||
| def join(self): | |||
| self.queue.join() | |||
| def disconnect_client(self): | |||
| if self.client is not None: | |||
| self.client.disconnect() | |||
| def close(self): | |||
| self.queue.close() | |||
| self.disconnect_client() | |||
| global MGE_PLASMA_STORE_MANAGER | |||
| MGE_PLASMA_STORE_MANAGER.refcount -= 1 | |||
| _clear_plasma_store() | |||
| def cancel_join_thread(self): | |||
| self.queue.cancel_join_thread() | |||
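| # Usage sketch (illustrative, single process; in the DataLoader the producer | |||
| # and consumer are separate processes, but the API is identical): | |||
| # | |||
| #     q = PlasmaShmQueue(maxsize=2) | |||
| #     q.put({"batch": np.zeros((8, 3, 224, 224))})  # stored in the plasma store | |||
| #     batch = q.get()                               # fetched, then deleted from the store | |||
| #     q.close() | |||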
| @@ -1,76 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # Copyright (c) 2016- Facebook, Inc (Adam Paszke) | |||
| # Copyright (c) 2014- Facebook, Inc (Soumith Chintala) | |||
| # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | |||
| # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) | |||
| # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | |||
| # Copyright (c) 2011-2013 NYU (Clement Farabet) | |||
| # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | |||
| # Copyright (c) 2006 Idiap Research Institute (Samy Bengio) | |||
| # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | |||
| # --------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # ---------------------------------------------------------------------- | |||
| import collections.abc | |||
| import re | |||
| import numpy as np | |||
| np_str_obj_array_pattern = re.compile(r"[aO]") | |||
| default_collate_err_msg_format = ( | |||
| "default_collator: inputs must contain numpy arrays, numbers, " | |||
| "Unicode strings, bytes, dicts or lists; found {}" | |||
| ) | |||
| class Collator: | |||
| r""" | |||
| Used to merge a list of samples to form a mini-batch of Tensor(s), | |||
| when batched loading from a dataset. | |||
| Modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py | |||
| """ | |||
| def apply(self, inputs): | |||
| """ | |||
| input : sequence_N(tuple(CHW, C, CK)) | |||
| output : tuple(NCHW, NC, NCK) | |||
| """ | |||
| elem = inputs[0] | |||
| elem_type = type(elem) | |||
| if ( | |||
| elem_type.__module__ == "numpy" | |||
| and elem_type.__name__ != "str_" | |||
| and elem_type.__name__ != "string_" | |||
| ): | |||
| elem = inputs[0] | |||
| if elem_type.__name__ == "ndarray": | |||
| # array of string classes and object | |||
| if np_str_obj_array_pattern.search(elem.dtype.str) is not None: | |||
| raise TypeError(default_collate_err_msg_format.format(elem.dtype)) | |||
| return np.ascontiguousarray(np.stack(inputs)) | |||
| elif elem.shape == (): # scalars | |||
| return np.array(inputs) | |||
| elif isinstance(elem, float): | |||
| return np.array(inputs, dtype=np.float64) | |||
| elif isinstance(elem, int): | |||
| return np.array(inputs) | |||
| elif isinstance(elem, (str, bytes)): | |||
| return inputs | |||
| elif isinstance(elem, collections.abc.Mapping): | |||
| return {key: self.apply([d[key] for d in inputs]) for key in elem} | |||
| elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple | |||
| return elem_type(*(self.apply(samples) for samples in zip(*inputs))) | |||
| elif isinstance(elem, collections.abc.Sequence): | |||
| transposed = zip(*inputs) | |||
| return [self.apply(samples) for samples in transposed] | |||
| raise TypeError(default_collate_err_msg_format.format(elem_type)) | |||
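| # Illustrative sketch (not part of the original file): three (CHW, label) | |||
| # samples collate into an (NCHW, N) mini-batch, per the docstring above. | |||
| # | |||
| #     samples = [(np.zeros((3, 4, 4), "float32"), 1) for _ in range(3)] | |||
| #     images, labels = Collator().apply(samples) | |||
| #     assert images.shape == (3, 3, 4, 4) and labels.shape == (3,) | |||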
| @@ -1,500 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import math | |||
| import multiprocessing | |||
| import queue | |||
| import random | |||
| import time | |||
| import numpy as np | |||
| from ..logger import get_logger | |||
| from ..random.rng import _random_seed_generator | |||
| from .collator import Collator | |||
| from .dataset import Dataset | |||
| from .sampler import Sampler, SequentialSampler | |||
| from .transform import PseudoTransform, Transform | |||
| logger = get_logger(__name__) | |||
| MP_QUEUE_GET_TIMEOUT = 5 | |||
| class DataLoader: | |||
| __initialized = False | |||
| def __init__( | |||
| self, | |||
| dataset: Dataset, | |||
| sampler: Sampler = None, | |||
| transform: Transform = None, | |||
| collator: Collator = None, | |||
| num_workers: int = 0, | |||
| timeout: int = 0, | |||
| divide: bool = False, | |||
| ): | |||
| r"""Provides a convenient way to iterate on a given dataset. | |||
| `DataLoader` combines a dataset with a sampler, a transform and a collator, | |||
| making it flexible to get minibatches continually from the dataset. | |||
| :type dataset: Dataset | |||
| :param dataset: dataset from which to load the minibatch. | |||
| :type sampler: Sampler | |||
| :param sampler: defines the strategy to sample data from the dataset. | |||
| If not specified, a :class:`SequentialSampler` with ``batch_size=1`` is used. | |||
| :type transform: Transform | |||
| :param transform: defines the transform strategy for a sampled batch. | |||
| (default: ``None``) | |||
| :type collator: Collator | |||
| :param collator: defines the merge strategy for a transformed batch. | |||
| (default: ``None``) | |||
| :type num_workers: int | |||
| :param num_workers: the number of sub-processes used to load, transform and | |||
| collate the batch. ``0`` means using a single process. (default: ``0``) | |||
| :type timeout: int | |||
| :param timeout: if positive, the timeout value (in seconds) for collecting a | |||
| batch from workers. (default: ``0``) | |||
| :type divide: bool | |||
| :param divide: defines the parallelization strategy in multi-processing mode. | |||
| ``True`` means one batch is divided into :attr:`num_workers` pieces, and | |||
| the workers process these pieces in parallel. ``False`` means different | |||
| sub-processes process different batches. (default: ``False``) | |||
| """ | |||
| if num_workers < 0: | |||
| raise ValueError("num_workers should not be negative") | |||
| if timeout < 0: | |||
| raise ValueError("timeout should not be negative") | |||
| if divide and num_workers <= 1: | |||
| raise ValueError("divide should not be set to True when num_workers <= 1") | |||
| self.dataset = dataset | |||
| self.num_workers = num_workers | |||
| self.timeout = timeout | |||
| self.divide = divide | |||
| if sampler is None: | |||
| self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) | |||
| else: | |||
| self.sampler = sampler | |||
| if divide: | |||
| if self.sampler.batch_size <= self.num_workers: | |||
| raise ValueError( | |||
| "batch size must not smaller than num_workers in divide mode." | |||
| ) | |||
| elif self.sampler.batch_size % self.num_workers: | |||
| logger.warning( | |||
| "batch size is not divisible by num_workers, may lose performance in divide mode." | |||
| ) | |||
| if transform is None: | |||
| self.transform = PseudoTransform() | |||
| else: | |||
| self.transform = transform | |||
| if collator is None: | |||
| self.collator = Collator() | |||
| else: | |||
| self.collator = collator | |||
| self.__initialized = True | |||
| def __iter__(self): | |||
| if self.num_workers == 0: | |||
| return _SerialDataLoaderIter(self) | |||
| else: | |||
| return _ParallelDataLoaderIter(self) | |||
| def __len__(self): | |||
| return len(self.sampler) | |||
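| # Usage sketch (illustrative): `my_dataset` is a hypothetical Dataset instance. | |||
| # | |||
| #     sampler = SequentialSampler(my_dataset, batch_size=32, drop_last=False) | |||
| #     loader = DataLoader(my_dataset, sampler=sampler, num_workers=4) | |||
| #     for batch in loader: | |||
| #         ...  # each batch was transformed, then merged by the collator | |||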
| class _BaseDataLoaderIter: | |||
| def __init__(self, loader): | |||
| self.dataset = loader.dataset | |||
| self.sampler = loader.sampler | |||
| self.seed = _random_seed_generator().__next__() | |||
| self.transform = loader.transform | |||
| self.collator = loader.collator | |||
| self.num_workers = loader.num_workers | |||
| self.timeout = loader.timeout | |||
| self.divide = loader.divide | |||
| self.num_processed = 0 | |||
| def _get_next_batch(self): | |||
| raise NotImplementedError | |||
| def __len__(self): | |||
| return len(self.sampler) | |||
| def __iter__(self): | |||
| return self | |||
| def __next__(self): | |||
| if self.num_processed >= len(self): | |||
| raise StopIteration | |||
| minibatch = self._get_next_batch() | |||
| self.num_processed += 1 | |||
| return minibatch | |||
| class _SerialDataLoaderIter(_BaseDataLoaderIter): | |||
| def __init__(self, loader): | |||
| super(_SerialDataLoaderIter, self).__init__(loader) | |||
| self.indices_iter = iter(self.sampler) | |||
| def _get_next_batch(self): | |||
| indices = next(self.indices_iter) | |||
| items = [self.dataset[idx] for idx in indices] | |||
| trans_items = self.transform.apply_batch(items) | |||
| return self.collator.apply(trans_items) | |||
| class _ParallelDataLoaderIter(_BaseDataLoaderIter): | |||
| __initialized = False | |||
| def __init__(self, loader): | |||
| super(_ParallelDataLoaderIter, self).__init__(loader) | |||
| self.task_queues = [ | |||
| multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) | |||
| ] | |||
| self.feed_batch_idx = multiprocessing.Value("i", 0) | |||
| self.target_batch_idx = multiprocessing.Value("i", 0) | |||
| self.shutdown_flag = multiprocessing.Value("i", 0) | |||
| self.trans_data_queues = [ | |||
| multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) | |||
| ] | |||
| # use shared-memory queue implemented by pyarrow plasma store. | |||
| from ._queue import PlasmaShmQueue | |||
| self.batch_queue = PlasmaShmQueue(maxsize=2) | |||
| self.task_feeding_worker = multiprocessing.Process( | |||
| target=_task_feeding_loop, | |||
| args=( | |||
| iter(self.sampler), | |||
| self.task_queues, | |||
| self.num_workers, | |||
| self.divide, | |||
| self.shutdown_flag, | |||
| self.feed_batch_idx, | |||
| ), | |||
| daemon=True, | |||
| ) | |||
| self.task_feeding_worker.start() | |||
| self.workers = [] | |||
| for worker_id in range(self.num_workers): | |||
| worker = multiprocessing.Process( | |||
| target=_worker_loop, | |||
| args=( | |||
| self.dataset, | |||
| self.task_queues[worker_id], | |||
| self.trans_data_queues[worker_id], | |||
| self.transform, | |||
| self.seed + worker_id + 1, | |||
| self.shutdown_flag, | |||
| ), | |||
| daemon=True, | |||
| ) | |||
| worker.start() | |||
| self.workers.append(worker) | |||
| if self.divide: | |||
| self.data_collecting_worker = multiprocessing.Process( | |||
| target=_data_gathering_loop, | |||
| args=( | |||
| self.trans_data_queues, | |||
| self.batch_queue, | |||
| self.collator, | |||
| len(self), | |||
| self.num_workers, | |||
| self.shutdown_flag, | |||
| self.target_batch_idx, | |||
| ), | |||
| daemon=True, | |||
| ) | |||
| else: | |||
| self.data_collecting_worker = multiprocessing.Process( | |||
| target=_data_selecting_loop, | |||
| args=( | |||
| self.trans_data_queues, | |||
| self.batch_queue, | |||
| self.collator, | |||
| len(self), | |||
| self.num_workers, | |||
| self.shutdown_flag, | |||
| self.target_batch_idx, | |||
| ), | |||
| daemon=True, | |||
| ) | |||
| self.data_collecting_worker.start() | |||
| self.__initialized = True | |||
| def _check_workers(self): | |||
| # Check the status of each worker. | |||
| if not self.data_collecting_worker.is_alive(): | |||
| exitcode = self.data_collecting_worker.exitcode | |||
| if exitcode != 0: | |||
| raise RuntimeError("data collecting worker died. {}".format(exitcode)) | |||
| if not self.task_feeding_worker.is_alive(): | |||
| exitcode = self.task_feeding_worker.exitcode | |||
| if exitcode != 0: | |||
| raise RuntimeError("task feeding worker died. {}".format(exitcode)) | |||
| for worker_id, worker in enumerate(self.workers): | |||
| if not worker.is_alive(): | |||
| exitcode = worker.exitcode | |||
| if exitcode != 0: | |||
| raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode)) | |||
| logger.debug("all workers are alive.") | |||
| def _try_get_next_batch(self): | |||
| start_time = time.time() | |||
| while True: | |||
| self._check_workers() | |||
| try: | |||
| return self.batch_queue.get(timeout=1) | |||
| except queue.Empty: | |||
| logger.debug("batch queue empty!") | |||
| waited_time = time.time() - start_time | |||
| if self.timeout > 0: | |||
| if waited_time > self.timeout: | |||
| raise RuntimeError("get_next_batch timeout!") | |||
| def _get_next_batch(self): | |||
| batch_data = self._try_get_next_batch() | |||
| return batch_data | |||
| def _shutdown(self): | |||
| with self.shutdown_flag.get_lock(): | |||
| self.shutdown_flag.value = 1 | |||
| if self.task_feeding_worker.is_alive(): | |||
| self.task_feeding_worker.terminate() | |||
| self.task_feeding_worker.join() | |||
| if self.data_collecting_worker.is_alive(): | |||
| self.data_collecting_worker.terminate() | |||
| self.data_collecting_worker.join() | |||
| for worker in self.workers: | |||
| if worker.is_alive(): | |||
| worker.terminate() | |||
| worker.join() | |||
| for q in self.trans_data_queues: | |||
| q.cancel_join_thread() | |||
| q.close() | |||
| for q in self.task_queues: | |||
| q.cancel_join_thread() | |||
| q.close() | |||
| self.batch_queue.cancel_join_thread() | |||
| self.batch_queue.close() | |||
| def __del__(self): | |||
| if self.__initialized: | |||
| self._shutdown() | |||
| def _task_feeding_loop( | |||
| indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx | |||
| ): | |||
| # Feed the indices into the task queues | |||
| while True: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| batch_idx = feed_batch_idx.value | |||
| try: | |||
| indices = next(indices_iter) | |||
| except StopIteration: | |||
| break | |||
| if divide: | |||
| # make sure all task_queues are ready for put | |||
| while any(q.full() for q in task_queues): | |||
| if shutdown_flag.value == 1: | |||
| return | |||
| # divide into small pieces, feed to different workers. | |||
| sub_num = math.ceil(len(indices) / num_workers) | |||
| for worker_id in range(num_workers): | |||
| sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num] | |||
| task_queues[worker_id].put((batch_idx, sub_indices)) | |||
| else: | |||
| # distribute tasks to different workers uniformly. | |||
| target_id = batch_idx % num_workers | |||
| while task_queues[target_id].full(): | |||
| if shutdown_flag.value == 1: | |||
| return | |||
| task_queues[target_id].put((batch_idx, indices)) | |||
| with feed_batch_idx.get_lock(): | |||
| feed_batch_idx.value += 1 | |||
| def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag): | |||
| # Get dataset items and do the transform | |||
| random.seed(seed) | |||
| np.random.seed(seed) | |||
| while True: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| try: | |||
| batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT) | |||
| except queue.Empty: | |||
| continue | |||
| if len(indices) > 0: | |||
| items = [dataset[idx] for idx in indices] | |||
| trans_items = transform.apply_batch(items) | |||
| else: | |||
| # in case of incomplete last batch | |||
| trans_items = () | |||
| while True: | |||
| try: | |||
| trans_data_queue.put((batch_idx, trans_items), timeout=1) | |||
| break | |||
| except queue.Full: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| logger.debug("batch part queue is full!") | |||
| def _data_gathering_loop( | |||
| trans_data_queues, | |||
| batch_queue, | |||
| collator, | |||
| length, | |||
| num_workers, | |||
| shutdown_flag, | |||
| target_idx, | |||
| ): | |||
| # Gather the small pieces of batch data into a full batch | |||
| while True: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| target_batch_idx = target_idx.value | |||
| if target_batch_idx >= length: | |||
| break | |||
| full_trans_items = [] | |||
| for worker_id in range(num_workers): | |||
| while True: | |||
| try: | |||
| batch_idx, trans_items = trans_data_queues[worker_id].get( | |||
| timeout=MP_QUEUE_GET_TIMEOUT | |||
| ) | |||
| break | |||
| except queue.Empty: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| logger.debug( | |||
| "worker:{} data queue get timeout! target batch idx:{}".format( | |||
| worker_id, target_batch_idx | |||
| ) | |||
| ) | |||
| if batch_idx != target_batch_idx: | |||
| raise RuntimeError( | |||
| "Unexperted batch_idx in data gathering loop. worker_id:{}.".format( | |||
| worker_id | |||
| ) | |||
| ) | |||
| else: | |||
| full_trans_items.extend(trans_items) | |||
| # Merge different parts into a batch. | |||
| full_batch = collator.apply(full_trans_items) | |||
| while True: | |||
| try: | |||
| batch_queue.put(full_batch, timeout=1) | |||
| break | |||
| except queue.Full: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| logger.debug("batch queue is full!") | |||
| with target_idx.get_lock(): | |||
| target_idx.value += 1 | |||
| batch_queue.disconnect_client() | |||
| def _data_selecting_loop( | |||
| trans_data_queues, | |||
| batch_queue, | |||
| collator, | |||
| length, | |||
| num_workers, | |||
| shutdown_flag, | |||
| target_idx, | |||
| ): | |||
| # Make sure batches are generated in exactly the same order as the generated indices | |||
| while True: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| target_batch_idx = target_idx.value | |||
| if target_batch_idx >= length: | |||
| break | |||
| target_worker_id = target_batch_idx % num_workers | |||
| while True: | |||
| try: | |||
| batch_idx, trans_items = trans_data_queues[target_worker_id].get( | |||
| timeout=MP_QUEUE_GET_TIMEOUT | |||
| ) | |||
| batch_data = collator.apply(trans_items) | |||
| break | |||
| except queue.Empty: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| logger.debug( | |||
| "worker:{} data queue get timeout! target batch idx:{}".format( | |||
| target_worker_id, target_batch_idx | |||
| ) | |||
| ) | |||
| if batch_idx != target_batch_idx: | |||
| raise RuntimeError( | |||
| "batch_idx {} mismatch the target_batch_idx {}".format( | |||
| batch_idx, target_batch_idx | |||
| ) | |||
| ) | |||
| while True: | |||
| try: | |||
| batch_queue.put(batch_data, timeout=1) | |||
| break | |||
| except queue.Full: | |||
| if shutdown_flag.value == 1: | |||
| break | |||
| logger.debug("batch queue is full!") | |||
| with target_idx.get_lock(): | |||
| target_idx.value += 1 | |||
| batch_queue.disconnect_client() | |||
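| # A standalone sketch (not part of the loader) of the two index-feeding | |||
| # strategies implemented by _task_feeding_loop above; the batch size and | |||
| # worker count are made up for illustration. | |||
| import math | |||
| indices = list(range(8))  # one batch of sample indices | |||
| num_workers = 4 | |||
| # divide=True: the batch is split across all workers and gathered later | |||
| sub_num = math.ceil(len(indices) / num_workers) | |||
| parts = [indices[w * sub_num : (w + 1) * sub_num] for w in range(num_workers)] | |||
| assert parts == [[0, 1], [2, 3], [4, 5], [6, 7]] | |||
| # divide=False: each whole batch goes to a single worker, round-robin | |||
| assert [batch_idx % num_workers for batch_idx in range(6)] == [0, 1, 2, 3, 0, 1] | |||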
| @@ -1,10 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset | |||
| from .vision import * | |||
| @@ -1,73 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Tuple | |||
| class Dataset(ABC): | |||
| r""" | |||
| An abstract class for all Datasets | |||
| """ | |||
| @abstractmethod | |||
| def __init__(self): | |||
| pass | |||
| class MapDataset(Dataset): | |||
| r""" | |||
| An abstract class for map data | |||
| __getitem__ and __len__ methods are additionally required | |||
| """ | |||
| @abstractmethod | |||
| def __init__(self): | |||
| pass | |||
| @abstractmethod | |||
| def __getitem__(self, index): | |||
| pass | |||
| @abstractmethod | |||
| def __len__(self): | |||
| pass | |||
| class StreamDataset(Dataset): | |||
| r""" | |||
| An abstract class for stream data | |||
| __iter__ method is additionally required | |||
| """ | |||
| @abstractmethod | |||
| def __init__(self): | |||
| pass | |||
| @abstractmethod | |||
| def __iter__(self): | |||
| pass | |||
| class ArrayDataset(MapDataset): | |||
| def __init__(self, *arrays): | |||
| r""" | |||
| ArrayDataset is a dataset for numpy array data. One or more numpy arrays | |||
| are needed to initialize the dataset, and the dimensions representing the | |||
| sample number are expected to be the same across all arrays. | |||
| """ | |||
| super().__init__() | |||
| if not all(len(arrays[0]) == len(array) for array in arrays): | |||
| raise ValueError("lengths of input arrays are inconsistent") | |||
| self.arrays = arrays | |||
| def __getitem__(self, index: int) -> Tuple: | |||
| return tuple(array[index] for array in self.arrays) | |||
| def __len__(self) -> int: | |||
| return len(self.arrays[0]) | |||
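| # A minimal usage sketch for ArrayDataset; the array shapes are illustrative. | |||
| import numpy as np | |||
| data = np.random.random((100, 3, 32, 32)).astype(np.float32) | |||
| labels = np.random.randint(0, 10, size=(100,)).astype(np.int32) | |||
| dataset = ArrayDataset(data, labels) | |||
| image, label = dataset[0]  # one sample: a tuple sliced from each array | |||
| assert len(dataset) == 100 | |||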
| @@ -1,17 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .cifar import CIFAR10, CIFAR100 | |||
| from .cityscapes import Cityscapes | |||
| from .coco import COCO | |||
| from .folder import ImageFolder | |||
| from .imagenet import ImageNet | |||
| from .meta_vision import VisionDataset | |||
| from .mnist import MNIST | |||
| from .objects365 import Objects365 | |||
| from .voc import PascalVOC | |||
| @@ -1,171 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import pickle | |||
| import tarfile | |||
| from typing import Tuple | |||
| import numpy as np | |||
| from ....logger import get_logger | |||
| from .meta_vision import VisionDataset | |||
| from .utils import _default_dataset_root, load_raw_data_from_url | |||
| logger = get_logger(__name__) | |||
| class CIFAR10(VisionDataset): | |||
| r""" ``Dataset`` for CIFAR10 meta data | |||
| """ | |||
| url_path = "http://www.cs.utoronto.ca/~kriz/" | |||
| raw_file_name = "cifar-10-python.tar.gz" | |||
| raw_file_md5 = "c58f30108f718f92721af3b95e74349a" | |||
| raw_file_dir = "cifar-10-batches-py" | |||
| train_batch = [ | |||
| "data_batch_1", | |||
| "data_batch_2", | |||
| "data_batch_3", | |||
| "data_batch_4", | |||
| "data_batch_5", | |||
| ] | |||
| test_batch = ["test_batch"] | |||
| meta_info = {"name": "batches.meta"} | |||
| def __init__( | |||
| self, | |||
| root: str = None, | |||
| train: bool = True, | |||
| download: bool = True, | |||
| timeout: int = 500, | |||
| ): | |||
| super().__init__(root, order=("image", "image_category")) | |||
| self.timeout = timeout | |||
| # process the root path | |||
| if root is None: | |||
| self.root = self._default_root | |||
| if not os.path.exists(self.root): | |||
| os.makedirs(self.root) | |||
| else: | |||
| self.root = root | |||
| if not os.path.exists(self.root): | |||
| if download: | |||
| logger.debug( | |||
| "dir %s does not exist, will be automatically created", | |||
| self.root, | |||
| ) | |||
| os.makedirs(self.root) | |||
| else: | |||
| raise ValueError("dir %s does not exist" % self.root) | |||
| self.target_file = os.path.join(self.root, self.raw_file_dir) | |||
| # check existence of target pickle dir, if exists load the | |||
| # pickle file no matter what download is set | |||
| if os.path.exists(self.target_file): | |||
| if train: | |||
| self.arrays = self.bytes2array(self.train_batch) | |||
| else: | |||
| self.arrays = self.bytes2array(self.test_batch) | |||
| else: | |||
| if download: | |||
| self.download() | |||
| if train: | |||
| self.arrays = self.bytes2array(self.train_batch) | |||
| else: | |||
| self.arrays = self.bytes2array(self.test_batch) | |||
| else: | |||
| raise ValueError( | |||
| "dir does not contain target file %s, please set download=True" | |||
| % (self.target_file) | |||
| ) | |||
| def __getitem__(self, index: int) -> Tuple: | |||
| return tuple(array[index] for array in self.arrays) | |||
| def __len__(self) -> int: | |||
| return len(self.arrays[0]) | |||
| @property | |||
| def _default_root(self): | |||
| return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||
| @property | |||
| def meta(self): | |||
| meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"]) | |||
| with open(meta_path, "rb") as f: | |||
| meta = pickle.load(f, encoding="bytes") | |||
| return meta | |||
| def download(self): | |||
| url = self.url_path + self.raw_file_name | |||
| load_raw_data_from_url( | |||
| url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout | |||
| ) | |||
| self.process() | |||
| def untar(self, file_path, dirs): | |||
| assert file_path.endswith(".tar.gz") | |||
| logger.debug("untar file %s to %s", file_path, dirs) | |||
| with tarfile.open(file_path) as t: | |||
| t.extractall(path=dirs) | |||
| def bytes2array(self, filenames): | |||
| data = [] | |||
| label = [] | |||
| for filename in filenames: | |||
| path = os.path.join(self.root, self.raw_file_dir, filename) | |||
| logger.debug("unpickle file %s", path) | |||
| with open(path, "rb") as fo: | |||
| dic = pickle.load(fo, encoding="bytes") | |||
| batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) | |||
| data.extend(list(batch_data[..., [2, 1, 0]])) | |||
| label.extend(dic[b"labels"]) | |||
| label = np.array(label, dtype=np.int32) | |||
| return (data, label) | |||
| def process(self): | |||
| logger.info("process raw data ...") | |||
| self.untar(os.path.join(self.root, self.raw_file_name), self.root) | |||
| class CIFAR100(CIFAR10): | |||
| url_path = "http://www.cs.utoronto.ca/~kriz/" | |||
| raw_file_name = "cifar-100-python.tar.gz" | |||
| raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85" | |||
| raw_file_dir = "cifar-100-python" | |||
| train_batch = ["train"] | |||
| test_batch = ["test"] | |||
| meta_info = {"name": "meta"} | |||
| @property | |||
| def meta(self): | |||
| meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"]) | |||
| with open(meta_path, "rb") as f: | |||
| meta = pickle.load(f, encoding="bytes") | |||
| return meta | |||
| def bytes2array(self, filenames): | |||
| data = [] | |||
| fine_label = [] | |||
| coarse_label = [] | |||
| for filename in filenames: | |||
| path = os.path.join(self.root, self.raw_file_dir, filename) | |||
| logger.debug("unpickle file %s", path) | |||
| with open(path, "rb") as fo: | |||
| dic = pickle.load(fo, encoding="bytes") | |||
| batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) | |||
| data.extend(list(batch_data[..., [2, 1, 0]])) | |||
| fine_label.extend(dic[b"fine_labels"]) | |||
| coarse_label.extend(dic[b"coarse_labels"]) | |||
| fine_label = np.array(fine_label, dtype=np.int32) | |||
| coarse_label = np.array(coarse_label, dtype=np.int32) | |||
| return data, fine_label, coarse_label | |||
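| # Usage sketch; with root=None the archive is downloaded to the default | |||
| # dataset root on first use (the flags shown are illustrative): | |||
| train_set = CIFAR10(root=None, train=True, download=True) | |||
| image, label = train_set[0]  # 32x32x3 uint8 BGR image, int32 label | |||
| # CIFAR100 samples additionally carry the coarse label | |||
| test_set = CIFAR100(root=None, train=False, download=True) | |||
| image, fine_label, coarse_label = test_set[0] | |||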
| @@ -1,151 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # --------------------------------------------------------------------- | |||
| # Part of the following code in this file refs to torchvision | |||
| # BSD 3-Clause License | |||
| # | |||
| # Copyright (c) Soumith Chintala 2016, | |||
| # All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| import json | |||
| import os | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| class Cityscapes(VisionDataset): | |||
| r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset. | |||
| """ | |||
| supported_order = ( | |||
| "image", | |||
| "mask", | |||
| "info", | |||
| ) | |||
| def __init__(self, root, image_set, mode, *, order=None): | |||
| super().__init__(root, order=order, supported_order=self.supported_order) | |||
| city_root = self.root | |||
| if not os.path.isdir(city_root): | |||
| raise RuntimeError("Dataset not found or corrupted.") | |||
| self.mode = mode | |||
| self.images_dir = os.path.join(city_root, "leftImg8bit", image_set) | |||
| self.masks_dir = os.path.join(city_root, self.mode, image_set) | |||
| self.images, self.masks = [], [] | |||
| # self.target_type = ["instance", "semantic", "polygon", "color"] | |||
| # for semantic segmentation | |||
| if mode == "gtFine": | |||
| valid_modes = ("train", "test", "val") | |||
| else: | |||
| valid_modes = ("train", "train_extra", "val") | |||
| for city in os.listdir(self.images_dir): | |||
| img_dir = os.path.join(self.images_dir, city) | |||
| mask_dir = os.path.join(self.masks_dir, city) | |||
| for file_name in os.listdir(img_dir): | |||
| mask_name = "{}_{}".format( | |||
| file_name.split("_leftImg8bit")[0], | |||
| self._get_target_suffix(self.mode, "semantic"), | |||
| ) | |||
| self.images.append(os.path.join(img_dir, file_name)) | |||
| self.masks.append(os.path.join(mask_dir, mask_name)) | |||
| def __getitem__(self, index): | |||
| image = None  # "info" may need the image even when "image" is not in order | |||
| target = [] | |||
| for k in self.order: | |||
| if k == "image": | |||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
| target.append(image) | |||
| elif k == "mask": | |||
| mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE) | |||
| mask = self._trans_mask(mask) | |||
| mask = mask[:, :, np.newaxis] | |||
| target.append(mask) | |||
| elif k == "info": | |||
| if image is None: | |||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
| info = [image.shape[0], image.shape[1], self.images[index]] | |||
| target.append(info) | |||
| else: | |||
| raise NotImplementedError | |||
| return tuple(target) | |||
| def __len__(self): | |||
| return len(self.images) | |||
| def _trans_mask(self, mask): | |||
| trans_labels = [ | |||
| 7, | |||
| 8, | |||
| 11, | |||
| 12, | |||
| 13, | |||
| 17, | |||
| 19, | |||
| 20, | |||
| 21, | |||
| 22, | |||
| 23, | |||
| 24, | |||
| 25, | |||
| 26, | |||
| 27, | |||
| 28, | |||
| 31, | |||
| 32, | |||
| 33, | |||
| ] | |||
| label = np.ones(mask.shape) * 255 | |||
| for i, tl in enumerate(trans_labels): | |||
| label[mask == tl] = i | |||
| return label.astype(np.uint8) | |||
| def _get_target_suffix(self, mode, target_type): | |||
| if target_type == "instance": | |||
| return "{}_instanceIds.png".format(mode) | |||
| elif target_type == "semantic": | |||
| return "{}_labelIds.png".format(mode) | |||
| elif target_type == "color": | |||
| return "{}_color.png".format(mode) | |||
| else: | |||
| return "{}_polygons.json".format(mode) | |||
| def _load_json(self, path): | |||
| with open(path, "r") as file: | |||
| data = json.load(file) | |||
| return data | |||
| class_names = ( | |||
| "road", | |||
| "sidewalk", | |||
| "building", | |||
| "wall", | |||
| "fence", | |||
| "pole", | |||
| "traffic light", | |||
| "traffic sign", | |||
| "vegetation", | |||
| "terrain", | |||
| "sky", | |||
| "person", | |||
| "rider", | |||
| "car", | |||
| "truck", | |||
| "bus", | |||
| "train", | |||
| "motorcycle", | |||
| "bicycle", | |||
| ) | |||
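| # Usage sketch (the root path is a placeholder): | |||
| dataset = Cityscapes( | |||
|     "/path/to/cityscapes", image_set="train", mode="gtFine", | |||
|     order=("image", "mask", "info"), | |||
| ) | |||
| image, mask, info = dataset[0]  # BGR image, HxWx1 uint8 mask, [H, W, path] | |||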
| @@ -1,366 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # --------------------------------------------------------------------- | |||
| # Part of the following code in this file refs to maskrcnn-benchmark | |||
| # MIT License | |||
| # | |||
| # Copyright (c) 2018 Facebook | |||
| # --------------------------------------------------------------------- | |||
| import json | |||
| import os | |||
| from collections import defaultdict | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| min_keypoints_per_image = 10 | |||
| def _count_visible_keypoints(anno): | |||
| return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |||
| def has_valid_annotation(anno, order): | |||
| # if it"s empty, there is no annotation | |||
| if len(anno) == 0: | |||
| return False | |||
| if "boxes" in order or "boxes_category" in order: | |||
| if "bbox" not in anno[0]: | |||
| return False | |||
| if "keypoints" in order: | |||
| if "keypoints" not in anno[0]: | |||
| return False | |||
| # for keypoint detection tasks, only consider valid images those | |||
| # containing at least min_keypoints_per_image | |||
| if _count_visible_keypoints(anno) < min_keypoints_per_image: | |||
| return False | |||
| return True | |||
| class COCO(VisionDataset): | |||
| r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset. | |||
| """ | |||
| supported_order = ( | |||
| "image", | |||
| "boxes", | |||
| "boxes_category", | |||
| "keypoints", | |||
| # TODO: need to check | |||
| # "polygons", | |||
| "info", | |||
| ) | |||
| def __init__( | |||
| self, root, ann_file, remove_images_without_annotations=False, *, order=None | |||
| ): | |||
| super().__init__(root, order=order, supported_order=self.supported_order) | |||
| with open(ann_file, "r") as f: | |||
| dataset = json.load(f) | |||
| self.imgs = dict() | |||
| for img in dataset["images"]: | |||
| # for saving memory | |||
| if "license" in img: | |||
| del img["license"] | |||
| if "coco_url" in img: | |||
| del img["coco_url"] | |||
| if "date_captured" in img: | |||
| del img["date_captured"] | |||
| if "flickr_url" in img: | |||
| del img["flickr_url"] | |||
| self.imgs[img["id"]] = img | |||
| self.img_to_anns = defaultdict(list) | |||
| for ann in dataset["annotations"]: | |||
| # for saving memory | |||
| if ( | |||
| "boxes" not in self.order | |||
| and "boxes_category" not in self.order | |||
| and "bbox" in ann | |||
| ): | |||
| del ann["bbox"] | |||
| if "polygons" not in self.order and "segmentation" in ann: | |||
| del ann["segmentation"] | |||
| self.img_to_anns[ann["image_id"]].append(ann) | |||
| self.cats = dict() | |||
| for cat in dataset["categories"]: | |||
| self.cats[cat["id"]] = cat | |||
| self.ids = list(sorted(self.imgs.keys())) | |||
| # filter images without detection annotations | |||
| if remove_images_without_annotations: | |||
| ids = [] | |||
| for img_id in self.ids: | |||
| anno = self.img_to_anns[img_id] | |||
| # filter crowd annotations | |||
| anno = [obj for obj in anno if obj["iscrowd"] == 0] | |||
| anno = [ | |||
| obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0 | |||
| ] | |||
| if has_valid_annotation(anno, self.order): | |||
| ids.append(img_id) | |||
| self.img_to_anns[img_id] = anno | |||
| else: | |||
| del self.imgs[img_id] | |||
| del self.img_to_anns[img_id] | |||
| self.ids = ids | |||
| self.json_category_id_to_contiguous_id = { | |||
| v: i + 1 for i, v in enumerate(sorted(self.cats.keys())) | |||
| } | |||
| self.contiguous_category_id_to_json_id = { | |||
| v: k for k, v in self.json_category_id_to_contiguous_id.items() | |||
| } | |||
| def __getitem__(self, index): | |||
| img_id = self.ids[index] | |||
| anno = self.img_to_anns[img_id] | |||
| target = [] | |||
| for k in self.order: | |||
| if k == "image": | |||
| file_name = self.imgs[img_id]["file_name"] | |||
| path = os.path.join(self.root, file_name) | |||
| image = cv2.imread(path, cv2.IMREAD_COLOR) | |||
| target.append(image) | |||
| elif k == "boxes": | |||
| boxes = [obj["bbox"] for obj in anno] | |||
| boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||
| # transfer boxes from xywh to xyxy | |||
| boxes[:, 2:] += boxes[:, :2] | |||
| target.append(boxes) | |||
| elif k == "boxes_category": | |||
| boxes_category = [obj["category_id"] for obj in anno] | |||
| boxes_category = [ | |||
| self.json_category_id_to_contiguous_id[c] for c in boxes_category | |||
| ] | |||
| boxes_category = np.array(boxes_category, dtype=np.int32) | |||
| target.append(boxes_category) | |||
| elif k == "keypoints": | |||
| keypoints = [obj["keypoints"] for obj in anno] | |||
| keypoints = np.array(keypoints, dtype=np.float32).reshape( | |||
| -1, len(self.keypoint_names), 3 | |||
| ) | |||
| target.append(keypoints) | |||
| elif k == "polygons": | |||
| polygons = [obj["segmentation"] for obj in anno] | |||
| polygons = [ | |||
| [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps] | |||
| for ps in polygons | |||
| ] | |||
| target.append(polygons) | |||
| elif k == "info": | |||
| info = self.imgs[img_id] | |||
| info = [info["height"], info["width"], info["file_name"]] | |||
| target.append(info) | |||
| else: | |||
| raise NotImplementedError | |||
| return tuple(target) | |||
| def __len__(self): | |||
| return len(self.ids) | |||
| def get_img_info(self, index): | |||
| img_id = self.ids[index] | |||
| img_info = self.imgs[img_id] | |||
| return img_info | |||
| class_names = ( | |||
| "person", | |||
| "bicycle", | |||
| "car", | |||
| "motorcycle", | |||
| "airplane", | |||
| "bus", | |||
| "train", | |||
| "truck", | |||
| "boat", | |||
| "traffic light", | |||
| "fire hydrant", | |||
| "stop sign", | |||
| "parking meter", | |||
| "bench", | |||
| "bird", | |||
| "cat", | |||
| "dog", | |||
| "horse", | |||
| "sheep", | |||
| "cow", | |||
| "elephant", | |||
| "bear", | |||
| "zebra", | |||
| "giraffe", | |||
| "backpack", | |||
| "umbrella", | |||
| "handbag", | |||
| "tie", | |||
| "suitcase", | |||
| "frisbee", | |||
| "skis", | |||
| "snowboard", | |||
| "sports ball", | |||
| "kite", | |||
| "baseball bat", | |||
| "baseball glove", | |||
| "skateboard", | |||
| "surfboard", | |||
| "tennis racket", | |||
| "bottle", | |||
| "wine glass", | |||
| "cup", | |||
| "fork", | |||
| "knife", | |||
| "spoon", | |||
| "bowl", | |||
| "banana", | |||
| "apple", | |||
| "sandwich", | |||
| "orange", | |||
| "broccoli", | |||
| "carrot", | |||
| "hot dog", | |||
| "pizza", | |||
| "donut", | |||
| "cake", | |||
| "chair", | |||
| "couch", | |||
| "potted plant", | |||
| "bed", | |||
| "dining table", | |||
| "toilet", | |||
| "tv", | |||
| "laptop", | |||
| "mouse", | |||
| "remote", | |||
| "keyboard", | |||
| "cell phone", | |||
| "microwave", | |||
| "oven", | |||
| "toaster", | |||
| "sink", | |||
| "refrigerator", | |||
| "book", | |||
| "clock", | |||
| "vase", | |||
| "scissors", | |||
| "teddy bear", | |||
| "hair drier", | |||
| "toothbrush", | |||
| ) | |||
| classes_originID = { | |||
| "person": 1, | |||
| "bicycle": 2, | |||
| "car": 3, | |||
| "motorcycle": 4, | |||
| "airplane": 5, | |||
| "bus": 6, | |||
| "train": 7, | |||
| "truck": 8, | |||
| "boat": 9, | |||
| "traffic light": 10, | |||
| "fire hydrant": 11, | |||
| "stop sign": 13, | |||
| "parking meter": 14, | |||
| "bench": 15, | |||
| "bird": 16, | |||
| "cat": 17, | |||
| "dog": 18, | |||
| "horse": 19, | |||
| "sheep": 20, | |||
| "cow": 21, | |||
| "elephant": 22, | |||
| "bear": 23, | |||
| "zebra": 24, | |||
| "giraffe": 25, | |||
| "backpack": 27, | |||
| "umbrella": 28, | |||
| "handbag": 31, | |||
| "tie": 32, | |||
| "suitcase": 33, | |||
| "frisbee": 34, | |||
| "skis": 35, | |||
| "snowboard": 36, | |||
| "sports ball": 37, | |||
| "kite": 38, | |||
| "baseball bat": 39, | |||
| "baseball glove": 40, | |||
| "skateboard": 41, | |||
| "surfboard": 42, | |||
| "tennis racket": 43, | |||
| "bottle": 44, | |||
| "wine glass": 46, | |||
| "cup": 47, | |||
| "fork": 48, | |||
| "knife": 49, | |||
| "spoon": 50, | |||
| "bowl": 51, | |||
| "banana": 52, | |||
| "apple": 53, | |||
| "sandwich": 54, | |||
| "orange": 55, | |||
| "broccoli": 56, | |||
| "carrot": 57, | |||
| "hot dog": 58, | |||
| "pizza": 59, | |||
| "donut": 60, | |||
| "cake": 61, | |||
| "chair": 62, | |||
| "couch": 63, | |||
| "potted plant": 64, | |||
| "bed": 65, | |||
| "dining table": 67, | |||
| "toilet": 70, | |||
| "tv": 72, | |||
| "laptop": 73, | |||
| "mouse": 74, | |||
| "remote": 75, | |||
| "keyboard": 76, | |||
| "cell phone": 77, | |||
| "microwave": 78, | |||
| "oven": 79, | |||
| "toaster": 80, | |||
| "sink": 81, | |||
| "refrigerator": 82, | |||
| "book": 84, | |||
| "clock": 85, | |||
| "vase": 86, | |||
| "scissors": 87, | |||
| "teddy bear": 88, | |||
| "hair drier": 89, | |||
| "toothbrush": 90, | |||
| } | |||
| keypoint_names = ( | |||
| "nose", | |||
| "left_eye", | |||
| "right_eye", | |||
| "left_ear", | |||
| "right_ear", | |||
| "left_shoulder", | |||
| "right_shoulder", | |||
| "left_elbow", | |||
| "right_elbow", | |||
| "left_wrist", | |||
| "right_wrist", | |||
| "left_hip", | |||
| "right_hip", | |||
| "left_knee", | |||
| "right_knee", | |||
| "left_ankle", | |||
| "right_ankle", | |||
| ) | |||
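| # Usage sketch (paths are placeholders): | |||
| dataset = COCO( | |||
|     "/path/to/coco/train2017", | |||
|     "/path/to/annotations/instances_train2017.json", | |||
|     remove_images_without_annotations=True, | |||
|     order=("image", "boxes", "boxes_category", "info"), | |||
| ) | |||
| image, boxes, boxes_category, info = dataset[0] | |||
| # boxes: float32 (N, 4) in xyxy; boxes_category: contiguous ids starting at 1 | |||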
| @@ -1,90 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # BSD 3-Clause License | |||
| # Copyright (c) Soumith Chintala 2016, | |||
| # All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| import os | |||
| from typing import Dict, List, Tuple | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| from .utils import is_img | |||
| class ImageFolder(VisionDataset): | |||
| def __init__(self, root: str, check_valid_func=None, class_name: bool = False): | |||
| r""" | |||
| ImageFolder is a class for loading image data and labels from an organized folder. | |||
| The folder is expected to be organized as follows: | |||
| root/cls/xxx.img_ext | |||
| Labels are the indices of the sorted classes in the root directory. | |||
| :param root: root directory of an image folder | |||
| :param check_valid_func: a function used to check whether files in the folder | |||
| are expected image files; if ``None``, a default function | |||
| that checks file extensions will be used | |||
| :param class_name: if ``True``, return class name instead of class index | |||
| """ | |||
| super().__init__(root, order=("image", "image_category")) | |||
| self.root = root | |||
| if check_valid_func is not None: | |||
| self.check_valid = check_valid_func | |||
| else: | |||
| self.check_valid = is_img | |||
| self.class_name = class_name | |||
| self.class_dict = self.collect_class() | |||
| self.samples = self.collect_samples() | |||
| def collect_samples(self) -> List: | |||
| samples = [] | |||
| directory = os.path.expanduser(self.root) | |||
| for key in sorted(self.class_dict.keys()): | |||
| d = os.path.join(directory, key) | |||
| if not os.path.isdir(d): | |||
| continue | |||
| for r, _, filenames in sorted(os.walk(d, followlinks=True)): | |||
| for name in sorted(filenames): | |||
| path = os.path.join(r, name) | |||
| if self.check_valid(path): | |||
| if self.class_name: | |||
| samples.append((path, key)) | |||
| else: | |||
| samples.append((path, self.class_dict[key])) | |||
| return samples | |||
| def collect_class(self) -> Dict: | |||
| classes = [d.name for d in os.scandir(self.root) if d.is_dir()] | |||
| classes.sort() | |||
| return {classes[i]: np.int32(i) for i in range(len(classes))} | |||
| def __getitem__(self, index: int) -> Tuple: | |||
| path, label = self.samples[index] | |||
| img = cv2.imread(path, cv2.IMREAD_COLOR) | |||
| return img, label | |||
| def __len__(self): | |||
| return len(self.samples) | |||
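| # Usage sketch (the root path is a placeholder); expects root/<class>/<image>: | |||
| dataset = ImageFolder("/path/to/images", class_name=False) | |||
| img, label = dataset[0]  # BGR image array, np.int32 class index | |||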
| @@ -1,248 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # BSD 3-Clause License | |||
| # | |||
| # Copyright (c) Soumith Chintala 2016, | |||
| # All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| import os | |||
| import shutil | |||
| from tqdm import tqdm | |||
| from ....core.serialization import load, save | |||
| from ....distributed.util import is_distributed | |||
| from ....logger import get_logger | |||
| from .folder import ImageFolder | |||
| from .utils import _default_dataset_root, calculate_md5, untar, untargz | |||
| logger = get_logger(__name__) | |||
| class ImageNet(ImageFolder): | |||
| r""" | |||
| Load ImageNet from raw files or folder, expected folder looks like | |||
| .. code-block:: bash | |||
| ${root}/ | |||
| | [REQUIRED TAR FILES] | |||
| |- ILSVRC2012_img_train.tar | |||
| |- ILSVRC2012_img_val.tar | |||
| |- ILSVRC2012_devkit_t12.tar.gz | |||
| | [OPTIONAL IMAGE FOLDERS] | |||
| |- train/cls/xxx.${img_ext} | |||
| |- val/cls/xxx.${img_ext} | |||
| |- ILSVRC2012_devkit_t12/data/meta.mat | |||
| |- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt | |||
| If the image folders do not exist, the raw tar files are required so that they can be extracted and processed. | |||
| """ | |||
| raw_file_meta = { | |||
| "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), | |||
| "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), | |||
| "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), | |||
| } # ImageNet raw files | |||
| default_train_dir = "train" | |||
| default_val_dir = "val" | |||
| default_devkit_dir = "ILSVRC2012_devkit_t12" | |||
| def __init__(self, root: str = None, train: bool = True, **kwargs): | |||
| r""" | |||
| initialization: | |||
| * if ``root`` contains ``self.target_folder`` depending on ``train``: | |||
| * initialize ImageFolder with target_folder | |||
| * else: | |||
| * if all raw files are in ``root``: | |||
| * parse ``self.target_folder`` from raw files | |||
| * initialize ImageFolder with ``self.target_folder`` | |||
| * else: | |||
| * raise error | |||
| :param root: root directory of imagenet data; if root is ``None``, use the default dataset root | |||
| :param train: if ``True``, load the train split, otherwise load the validation split | |||
| """ | |||
| # process the root path | |||
| if root is None: | |||
| self.root = self._default_root | |||
| else: | |||
| self.root = root | |||
| if not os.path.exists(self.root): | |||
| raise FileNotFoundError("dir %s does not exist" % self.root) | |||
| self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) | |||
| if not os.path.exists(self.devkit_dir): | |||
| logger.warning("devkit directory %s does not exists", self.devkit_dir) | |||
| self._prepare_devkit() | |||
| self.train = train | |||
| if train: | |||
| self.target_folder = os.path.join(self.root, self.default_train_dir) | |||
| else: | |||
| self.target_folder = os.path.join(self.root, self.default_val_dir) | |||
| if not os.path.exists(self.target_folder): | |||
| logger.warning( | |||
| "expected image folder %s does not exist, try to load from raw file", | |||
| self.target_folder, | |||
| ) | |||
| if not self.check_raw_file(): | |||
| raise FileNotFoundError( | |||
| "expected image folder %s does not exist, and raw files do not exist in %s" | |||
| % (self.target_folder, self.root) | |||
| ) | |||
| elif is_distributed(): | |||
| raise RuntimeError( | |||
| "extracting raw file shouldn't be done in distributed mode, use single process instead" | |||
| ) | |||
| elif train: | |||
| self._prepare_train() | |||
| else: | |||
| self._prepare_val() | |||
| super().__init__(self.target_folder, **kwargs) | |||
| @property | |||
| def _default_root(self): | |||
| return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||
| @property | |||
| def valid_ground_truth(self): | |||
| ground_truth_path = os.path.join( | |||
| self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt" | |||
| ) | |||
| if os.path.exists(ground_truth_path): | |||
| with open(ground_truth_path, "r") as f: | |||
| val_labels = f.readlines() | |||
| return [int(val_label) for val_label in val_labels] | |||
| else: | |||
| raise FileNotFoundError( | |||
| "validation ground truth file %s does not exist" % ground_truth_path | |||
| ) | |||
| @property | |||
| def meta(self): | |||
| try: | |||
| return load(os.path.join(self.devkit_dir, "meta.pkl")) | |||
| except FileNotFoundError: | |||
| import scipy.io | |||
| meta_path = os.path.join(self.devkit_dir, "data", "meta.mat") | |||
| if not os.path.exists(meta_path): | |||
| raise FileNotFoundError("meta file %s does not exist" % meta_path) | |||
| meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"] | |||
| nums_children = list(zip(*meta))[4] | |||
| meta = [ | |||
| meta[idx] | |||
| for idx, num_children in enumerate(nums_children) | |||
| if num_children == 0 | |||
| ] | |||
| idcs, wnids, classes = list(zip(*meta))[:3] | |||
| classes = [tuple(clss.split(", ")) for clss in classes] | |||
| idx_to_wnid = dict(zip(idcs, wnids)) | |||
| wnid_to_classes = dict(zip(wnids, classes)) | |||
| logger.info( | |||
| "saving cached meta file to %s", | |||
| os.path.join(self.devkit_dir, "meta.pkl"), | |||
| ) | |||
| save( | |||
| (idx_to_wnid, wnid_to_classes), | |||
| os.path.join(self.devkit_dir, "meta.pkl"), | |||
| ) | |||
| return idx_to_wnid, wnid_to_classes | |||
| def check_raw_file(self) -> bool: | |||
| return all( | |||
| [ | |||
| os.path.exists(os.path.join(self.root, value[0])) | |||
| for _, value in self.raw_file_meta.items() | |||
| ] | |||
| ) | |||
| def _organize_val_data(self): | |||
| id2wnid = self.meta[0] | |||
| val_idcs = self.valid_ground_truth | |||
| val_wnids = [id2wnid[idx] for idx in val_idcs] | |||
| val_images = sorted( | |||
| [ | |||
| os.path.join(self.target_folder, image) | |||
| for image in os.listdir(self.target_folder) | |||
| ] | |||
| ) | |||
| logger.debug("mkdir for val set wnids") | |||
| for wnid in set(val_wnids): | |||
| os.makedirs(os.path.join(self.root, self.default_val_dir, wnid)) | |||
| logger.debug("mv val images into wnids dir") | |||
| for wnid, img_file in tqdm(zip(val_wnids, val_images)): | |||
| shutil.move( | |||
| img_file, | |||
| os.path.join( | |||
| self.root, self.default_val_dir, wnid, os.path.basename(img_file) | |||
| ), | |||
| ) | |||
| def _prepare_val(self): | |||
| assert not self.train | |||
| raw_filename, checksum = self.raw_file_meta["val"] | |||
| raw_file = os.path.join(self.root, raw_filename) | |||
| logger.info("checksum valid tar file %s ...", raw_file) | |||
| assert ( | |||
| calculate_md5(raw_file) == checksum | |||
| ), "checksum mismatch, {} may be damaged".format(raw_file) | |||
| logger.info("extract valid tar file... this may take 10-20 minutes") | |||
| untar(os.path.join(self.root, raw_file), self.target_folder) | |||
| self._organize_val_data() | |||
| def _prepare_train(self): | |||
| assert self.train | |||
| raw_filename, checksum = self.raw_file_meta["train"] | |||
| raw_file = os.path.join(self.root, raw_filename) | |||
| logger.info("checksum train tar file %s ...", raw_file) | |||
| assert ( | |||
| calculate_md5(raw_file) == checksum | |||
| ), "checksum mismatch, {} may be damaged".format(raw_file) | |||
| logger.info("extract train tar file.. this may take several hours") | |||
| untar( | |||
| os.path.join(self.root, raw_file), self.target_folder, | |||
| ) | |||
| paths = [ | |||
| os.path.join(self.target_folder, child_dir) | |||
| for child_dir in os.listdir(self.target_folder) | |||
| ] | |||
| for path in tqdm(paths): | |||
| untar(path, os.path.splitext(path)[0], remove=True) | |||
| def _prepare_devkit(self): | |||
| raw_filename, checksum = self.raw_file_meta["devkit"] | |||
| raw_file = os.path.join(self.root, raw_filename) | |||
| logger.info("checksum devkit tar file %s ...", raw_file) | |||
| assert ( | |||
| calculate_md5(raw_file) == checksum | |||
| ), "checksum mismatch, {} may be damaged".format(raw_file) | |||
| logger.info("extract devkit file..") | |||
| untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0])) | |||
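| # Usage sketch (the root path is a placeholder). Raw tar files, or already | |||
| # extracted train/ and val/ folders, must sit under root; extraction happens | |||
| # on first use and must not run in distributed mode. | |||
| train_set = ImageNet("/path/to/ImageNet", train=True) | |||
| img, label = train_set[0] | |||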
| @@ -1,41 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections.abc | |||
| import os | |||
| from ..meta_dataset import MapDataset | |||
| class VisionDataset(MapDataset): | |||
| _repr_indent = 4 | |||
| def __init__(self, root, *, order=None, supported_order=None): | |||
| if isinstance(root, (str, bytes)): | |||
| root = os.path.expanduser(root) | |||
| self.root = root | |||
| if order is None: | |||
| order = ("image",) | |||
| if not isinstance(order, collections.abc.Sequence): | |||
| raise ValueError( | |||
| "order should be a sequence, but got order={}".format(order) | |||
| ) | |||
| if supported_order is not None: | |||
| assert isinstance(supported_order, collections.abc.Sequence) | |||
| for k in order: | |||
| if k not in supported_order: | |||
| raise NotImplementedError("{} is unsupported data type".format(k)) | |||
| self.order = order | |||
| def __getitem__(self, index): | |||
| raise NotImplementedError | |||
| def __len__(self): | |||
| raise NotImplementedError | |||
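| # A minimal subclass sketch showing how order is validated against | |||
| # supported_order (the names here are illustrative): | |||
| class MySegDataset(VisionDataset): | |||
|     supported_order = ("image", "mask") | |||
|     def __init__(self, root, *, order=None): | |||
|         super().__init__(root, order=order, supported_order=self.supported_order) | |||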
| @@ -1,197 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import gzip | |||
| import os | |||
| import struct | |||
| from typing import Tuple | |||
| import numpy as np | |||
| from tqdm import tqdm | |||
| from ....logger import get_logger | |||
| from .meta_vision import VisionDataset | |||
| from .utils import _default_dataset_root, load_raw_data_from_url | |||
| logger = get_logger(__name__) | |||
| class MNIST(VisionDataset): | |||
| r""" ``Dataset`` for MNIST meta data | |||
| """ | |||
| url_path = "http://yann.lecun.com/exdb/mnist/" | |||
| """ | |||
| url prefix for downloading raw file | |||
| """ | |||
| raw_file_name = [ | |||
| "train-images-idx3-ubyte.gz", | |||
| "train-labels-idx1-ubyte.gz", | |||
| "t10k-images-idx3-ubyte.gz", | |||
| "t10k-labels-idx1-ubyte.gz", | |||
| ] | |||
| """ | |||
| raw file names of both training set and test set (10k) | |||
| """ | |||
| raw_file_md5 = [ | |||
| "f68b3c2dcbeaaa9fbdd348bbdeb94873", | |||
| "d53e105ee54ea40749a09fcbcd1e9432", | |||
| "9fb629c4189551a2d022fa330f9573f3", | |||
| "ec29112dd5afa0611ce80d1b7f02629c", | |||
| ] | |||
| """ | |||
| md5 for checking raw files | |||
| """ | |||
| def __init__( | |||
| self, | |||
| root: str = None, | |||
| train: bool = True, | |||
| download: bool = True, | |||
| timeout: int = 500, | |||
| ): | |||
| r""" | |||
| :param root: path for mnist dataset downloading or loading, if ``None``, | |||
| set ``root`` to the ``_default_root`` | |||
| :param train: if ``True``, load the training set, otherwise load the test set | |||
| :param download: if the raw files do not exist and ``download`` is set to ``True``, | |||
| download and process the raw files, otherwise raise ValueError; default is True | |||
| """ | |||
| super().__init__(root, order=("image", "image_category")) | |||
| self.timeout = timeout | |||
| # process the root path | |||
| if root is None: | |||
| self.root = self._default_root | |||
| if not os.path.exists(self.root): | |||
| os.makedirs(self.root) | |||
| else: | |||
| self.root = root | |||
| if not os.path.exists(self.root): | |||
| if download: | |||
| logger.debug( | |||
| "dir %s does not exist, will be automatically created", | |||
| self.root, | |||
| ) | |||
| os.makedirs(self.root) | |||
| else: | |||
| raise ValueError("dir %s does not exist" % self.root) | |||
| if self._check_raw_files(): | |||
| self.process(train) | |||
| elif download: | |||
| self.download() | |||
| self.process(train) | |||
| else: | |||
| raise ValueError( | |||
| "root does not contain valid raw files, please set download=True" | |||
| ) | |||
| def __getitem__(self, index: int) -> Tuple: | |||
| return tuple(array[index] for array in self.arrays) | |||
| def __len__(self) -> int: | |||
| return len(self.arrays[0]) | |||
| @property | |||
| def _default_root(self): | |||
| return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||
| @property | |||
| def meta(self): | |||
| return self._meta_data | |||
| def _check_raw_files(self): | |||
| return all( | |||
| [ | |||
| os.path.exists(os.path.join(self.root, path)) | |||
| for path in self.raw_file_name | |||
| ] | |||
| ) | |||
| def download(self): | |||
| for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): | |||
| url = self.url_path + file_name | |||
| load_raw_data_from_url(url, file_name, md5, self.root, self.timeout) | |||
| def process(self, train): | |||
| # load raw files and transform them into meta data and datasets Tuple(np.array) | |||
| logger.info("process the raw files of %s set...", "train" if train else "test") | |||
| if train: | |||
| meta_data_images, images = parse_idx3( | |||
| os.path.join(self.root, self.raw_file_name[0]) | |||
| ) | |||
| meta_data_labels, labels = parse_idx1( | |||
| os.path.join(self.root, self.raw_file_name[1]) | |||
| ) | |||
| else: | |||
| meta_data_images, images = parse_idx3( | |||
| os.path.join(self.root, self.raw_file_name[2]) | |||
| ) | |||
| meta_data_labels, labels = parse_idx1( | |||
| os.path.join(self.root, self.raw_file_name[3]) | |||
| ) | |||
| self._meta_data = { | |||
| "images": meta_data_images, | |||
| "labels": meta_data_labels, | |||
| } | |||
| self.arrays = (images, labels.astype(np.int32)) | |||
| def parse_idx3(idx3_file): | |||
| # parse idx3 file to meta data and data in numpy array (images) | |||
| logger.debug("parse idx3 file %s ...", idx3_file) | |||
| assert idx3_file.endswith(".gz") | |||
| with gzip.open(idx3_file, "rb") as f: | |||
| bin_data = f.read() | |||
| # parse meta data | |||
| offset = 0 | |||
| fmt_header = ">iiii" | |||
| magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset) | |||
| meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width} | |||
| # parse images | |||
| image_size = height * width | |||
| offset += struct.calcsize(fmt_header) | |||
| fmt_image = ">" + str(image_size) + "B" | |||
| images = [] | |||
| bar = tqdm(total=meta_data["imgs"], ncols=80) | |||
| for image in struct.iter_unpack(fmt_image, bin_data[offset:]): | |||
| images.append(np.array(image, dtype=np.uint8).reshape((height, width, 1))) | |||
| bar.update() | |||
| bar.close() | |||
| return meta_data, images | |||
| def parse_idx1(idx1_file): | |||
| # parse idx1 file to meta data and data in numpy array (labels) | |||
| logger.debug("parse idx1 file %s ...", idx1_file) | |||
| assert idx1_file.endswith(".gz") | |||
| with gzip.open(idx1_file, "rb") as f: | |||
| bin_data = f.read() | |||
| # parse meta data | |||
| offset = 0 | |||
| fmt_header = ">ii" | |||
| magic, imgs = struct.unpack_from(fmt_header, bin_data, offset) | |||
| meta_data = {"magic": magic, "imgs": imgs} | |||
| # parse labels | |||
| offset += struct.calcsize(fmt_header) | |||
| fmt_image = ">B" | |||
| labels = np.empty(imgs, dtype=int) | |||
| bar = tqdm(total=meta_data["imgs"], ncols=80) | |||
| for i, label in enumerate(struct.iter_unpack(fmt_image, bin_data[offset:])): | |||
| labels[i] = label[0] | |||
| bar.update() | |||
| bar.close() | |||
| return meta_data, labels | |||
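| # Usage sketch; with root=None the raw idx files are downloaded to the | |||
| # default dataset root and parsed on first use: | |||
| test_set = MNIST(root=None, train=False, download=True) | |||
| image, label = test_set[0]      # 28x28x1 uint8 image, int32 label | |||
| print(test_set.meta["images"])  # idx3 header: magic, imgs, height, width | |||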
| @@ -1,498 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # --------------------------------------------------------------------- | |||
| # Part of the following code in this file refs to maskrcnn-benchmark | |||
| # MIT License | |||
| # | |||
| # Copyright (c) 2018 Facebook | |||
| # --------------------------------------------------------------------- | |||
| import json | |||
| import os | |||
| from collections import defaultdict | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| class Objects365(VisionDataset): | |||
| r"""`Objects365 <https://www.objects365.org/overview.html>`_ Dataset. | |||
| """ | |||
| supported_order = ( | |||
| "image", | |||
| "boxes", | |||
| "boxes_category", | |||
| "info", | |||
| ) | |||
| def __init__( | |||
| self, root, ann_file, remove_images_without_annotations=False, *, order=None | |||
| ): | |||
| super().__init__(root, order=order, supported_order=self.supported_order) | |||
| with open(ann_file, "r") as f: | |||
| dataset = json.load(f) | |||
| self.imgs = dict() | |||
| for img in dataset["images"]: | |||
| self.imgs[img["id"]] = img | |||
| self.img_to_anns = defaultdict(list) | |||
| for ann in dataset["annotations"]: | |||
| # for saving memory | |||
| if ( | |||
| "boxes" not in self.order | |||
| and "boxes_category" not in self.order | |||
| and "bbox" in ann | |||
| ): | |||
| del ann["bbox"] | |||
| self.img_to_anns[ann["image_id"]].append(ann) | |||
| self.cats = dict() | |||
| for cat in dataset["categories"]: | |||
| self.cats[cat["id"]] = cat | |||
| self.ids = list(sorted(self.imgs.keys())) | |||
| # filter images without detection annotations | |||
| if remove_images_without_annotations: | |||
| ids = [] | |||
| for img_id in self.ids: | |||
| anno = self.img_to_anns[img_id] | |||
| # filter crowd annotations | |||
| anno = [obj for obj in anno if obj["iscrowd"] == 0] | |||
| anno = [ | |||
| obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0 | |||
| ] | |||
| if len(anno) > 0: | |||
| ids.append(img_id) | |||
| self.img_to_anns[img_id] = anno | |||
| else: | |||
| del self.imgs[img_id] | |||
| del self.img_to_anns[img_id] | |||
| self.ids = ids | |||
| self.json_category_id_to_contiguous_id = { | |||
| v: i + 1 for i, v in enumerate(sorted(self.cats.keys())) | |||
| } | |||
| self.contiguous_category_id_to_json_id = { | |||
| v: k for k, v in self.json_category_id_to_contiguous_id.items() | |||
| } | |||
| def __getitem__(self, index): | |||
| img_id = self.ids[index] | |||
| anno = self.img_to_anns[img_id] | |||
| target = [] | |||
| for k in self.order: | |||
| if k == "image": | |||
| file_name = self.imgs[img_id]["file_name"] | |||
| path = os.path.join(self.root, file_name) | |||
| image = cv2.imread(path, cv2.IMREAD_COLOR) | |||
| target.append(image) | |||
| elif k == "boxes": | |||
| boxes = [obj["bbox"] for obj in anno] | |||
| boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||
| # transfer boxes from xywh to xyxy | |||
| boxes[:, 2:] += boxes[:, :2] | |||
| target.append(boxes) | |||
| elif k == "boxes_category": | |||
| boxes_category = [obj["category_id"] for obj in anno] | |||
| boxes_category = [ | |||
| self.json_category_id_to_contiguous_id[c] for c in boxes_category | |||
| ] | |||
| boxes_category = np.array(boxes_category, dtype=np.int32) | |||
| target.append(boxes_category) | |||
| elif k == "info": | |||
| info = self.imgs[img_id] | |||
| info = [info["height"], info["width"], info["file_name"]] | |||
| target.append(info) | |||
| else: | |||
| raise NotImplementedError | |||
| return tuple(target) | |||
| def __len__(self): | |||
| return len(self.ids) | |||
| def get_img_info(self, index): | |||
| img_id = self.ids[index] | |||
| img_info = self.imgs[img_id] | |||
| return img_info | |||
| class_names = ( | |||
| "person", | |||
| "sneakers", | |||
| "chair", | |||
| "hat", | |||
| "lamp", | |||
| "bottle", | |||
| "cabinet/shelf", | |||
| "cup", | |||
| "car", | |||
| "glasses", | |||
| "picture/frame", | |||
| "desk", | |||
| "handbag", | |||
| "street lights", | |||
| "book", | |||
| "plate", | |||
| "helmet", | |||
| "leather shoes", | |||
| "pillow", | |||
| "glove", | |||
| "potted plant", | |||
| "bracelet", | |||
| "flower", | |||
| "tv", | |||
| "storage box", | |||
| "vase", | |||
| "bench", | |||
| "wine glass", | |||
| "boots", | |||
| "bowl", | |||
| "dining table", | |||
| "umbrella", | |||
| "boat", | |||
| "flag", | |||
| "speaker", | |||
| "trash bin/can", | |||
| "stool", | |||
| "backpack", | |||
| "couch", | |||
| "belt", | |||
| "carpet", | |||
| "basket", | |||
| "towel/napkin", | |||
| "slippers", | |||
| "barrel/bucket", | |||
| "coffee table", | |||
| "suv", | |||
| "toy", | |||
| "tie", | |||
| "bed", | |||
| "traffic light", | |||
| "pen/pencil", | |||
| "microphone", | |||
| "sandals", | |||
| "canned", | |||
| "necklace", | |||
| "mirror", | |||
| "faucet", | |||
| "bicycle", | |||
| "bread", | |||
| "high heels", | |||
| "ring", | |||
| "van", | |||
| "watch", | |||
| "sink", | |||
| "horse", | |||
| "fish", | |||
| "apple", | |||
| "camera", | |||
| "candle", | |||
| "teddy bear", | |||
| "cake", | |||
| "motorcycle", | |||
| "wild bird", | |||
| "laptop", | |||
| "knife", | |||
| "traffic sign", | |||
| "cell phone", | |||
| "paddle", | |||
| "truck", | |||
| "cow", | |||
| "power outlet", | |||
| "clock", | |||
| "drum", | |||
| "fork", | |||
| "bus", | |||
| "hanger", | |||
| "nightstand", | |||
| "pot/pan", | |||
| "sheep", | |||
| "guitar", | |||
| "traffic cone", | |||
| "tea pot", | |||
| "keyboard", | |||
| "tripod", | |||
| "hockey", | |||
| "fan", | |||
| "dog", | |||
| "spoon", | |||
| "blackboard/whiteboard", | |||
| "balloon", | |||
| "air conditioner", | |||
| "cymbal", | |||
| "mouse", | |||
| "telephone", | |||
| "pickup truck", | |||
| "orange", | |||
| "banana", | |||
| "airplane", | |||
| "luggage", | |||
| "skis", | |||
| "soccer", | |||
| "trolley", | |||
| "oven", | |||
| "remote", | |||
| "baseball glove", | |||
| "paper towel", | |||
| "refrigerator", | |||
| "train", | |||
| "tomato", | |||
| "machinery vehicle", | |||
| "tent", | |||
| "shampoo/shower gel", | |||
| "head phone", | |||
| "lantern", | |||
| "donut", | |||
| "cleaning products", | |||
| "sailboat", | |||
| "tangerine", | |||
| "pizza", | |||
| "kite", | |||
| "computer box", | |||
| "elephant", | |||
| "toiletries", | |||
| "gas stove", | |||
| "broccoli", | |||
| "toilet", | |||
| "stroller", | |||
| "shovel", | |||
| "baseball bat", | |||
| "microwave", | |||
| "skateboard", | |||
| "surfboard", | |||
| "surveillance camera", | |||
| "gun", | |||
| "life saver", | |||
| "cat", | |||
| "lemon", | |||
| "liquid soap", | |||
| "zebra", | |||
| "duck", | |||
| "sports car", | |||
| "giraffe", | |||
| "pumpkin", | |||
| "piano", | |||
| "stop sign", | |||
| "radiator", | |||
| "converter", | |||
| "tissue ", | |||
| "carrot", | |||
| "washing machine", | |||
| "vent", | |||
| "cookies", | |||
| "cutting/chopping board", | |||
| "tennis racket", | |||
| "candy", | |||
| "skating and skiing shoes", | |||
| "scissors", | |||
| "folder", | |||
| "baseball", | |||
| "strawberry", | |||
| "bow tie", | |||
| "pigeon", | |||
| "pepper", | |||
| "coffee machine", | |||
| "bathtub", | |||
| "snowboard", | |||
| "suitcase", | |||
| "grapes", | |||
| "ladder", | |||
| "pear", | |||
| "american football", | |||
| "basketball", | |||
| "potato", | |||
| "paint brush", | |||
| "printer", | |||
| "billiards", | |||
| "fire hydrant", | |||
| "goose", | |||
| "projector", | |||
| "sausage", | |||
| "fire extinguisher", | |||
| "extension cord", | |||
| "facial mask", | |||
| "tennis ball", | |||
| "chopsticks", | |||
| "electronic stove and gas stove", | |||
| "pie", | |||
| "frisbee", | |||
| "kettle", | |||
| "hamburger", | |||
| "golf club", | |||
| "cucumber", | |||
| "clutch", | |||
| "blender", | |||
| "tong", | |||
| "slide", | |||
| "hot dog", | |||
| "toothbrush", | |||
| "facial cleanser", | |||
| "mango", | |||
| "deer", | |||
| "egg", | |||
| "violin", | |||
| "marker", | |||
| "ship", | |||
| "chicken", | |||
| "onion", | |||
| "ice cream", | |||
| "tape", | |||
| "wheelchair", | |||
| "plum", | |||
| "bar soap", | |||
| "scale", | |||
| "watermelon", | |||
| "cabbage", | |||
| "router/modem", | |||
| "golf ball", | |||
| "pine apple", | |||
| "crane", | |||
| "fire truck", | |||
| "peach", | |||
| "cello", | |||
| "notepaper", | |||
| "tricycle", | |||
| "toaster", | |||
| "helicopter", | |||
| "green beans", | |||
| "brush", | |||
| "carriage", | |||
| "cigar", | |||
| "earphone", | |||
| "penguin", | |||
| "hurdle", | |||
| "swing", | |||
| "radio", | |||
| "CD", | |||
| "parking meter", | |||
| "swan", | |||
| "garlic", | |||
| "french fries", | |||
| "horn", | |||
| "avocado", | |||
| "saxophone", | |||
| "trumpet", | |||
| "sandwich", | |||
| "cue", | |||
| "kiwi fruit", | |||
| "bear", | |||
| "fishing rod", | |||
| "cherry", | |||
| "tablet", | |||
| "green vegetables", | |||
| "nuts", | |||
| "corn", | |||
| "key", | |||
| "screwdriver", | |||
| "globe", | |||
| "broom", | |||
| "pliers", | |||
| "volleyball", | |||
| "hammer", | |||
| "eggplant", | |||
| "trophy", | |||
| "dates", | |||
| "board eraser", | |||
| "rice", | |||
| "tape measure/ruler", | |||
| "dumbbell", | |||
| "hamimelon", | |||
| "stapler", | |||
| "camel", | |||
| "lettuce", | |||
| "goldfish", | |||
| "meat balls", | |||
| "medal", | |||
| "toothpaste", | |||
| "antelope", | |||
| "shrimp", | |||
| "rickshaw", | |||
| "trombone", | |||
| "pomegranate", | |||
| "coconut", | |||
| "jellyfish", | |||
| "mushroom", | |||
| "calculator", | |||
| "treadmill", | |||
| "butterfly", | |||
| "egg tart", | |||
| "cheese", | |||
| "pig", | |||
| "pomelo", | |||
| "race car", | |||
| "rice cooker", | |||
| "tuba", | |||
| "crosswalk sign", | |||
| "papaya", | |||
| "hair drier", | |||
| "green onion", | |||
| "chips", | |||
| "dolphin", | |||
| "sushi", | |||
| "urinal", | |||
| "donkey", | |||
| "electric drill", | |||
| "spring rolls", | |||
| "tortoise/turtle", | |||
| "parrot", | |||
| "flute", | |||
| "measuring cup", | |||
| "shark", | |||
| "steak", | |||
| "poker card", | |||
| "binoculars", | |||
| "llama", | |||
| "radish", | |||
| "noodles", | |||
| "yak", | |||
| "mop", | |||
| "crab", | |||
| "microscope", | |||
| "barbell", | |||
| "bread/bun", | |||
| "baozi", | |||
| "lion", | |||
| "red cabbage", | |||
| "polar bear", | |||
| "lighter", | |||
| "seal", | |||
| "mangosteen", | |||
| "comb", | |||
| "eraser", | |||
| "pitaya", | |||
| "scallop", | |||
| "pencil case", | |||
| "saw", | |||
| "table tennis paddle", | |||
| "okra", | |||
| "starfish", | |||
| "eagle", | |||
| "monkey", | |||
| "durian", | |||
| "game board", | |||
| "rabbit", | |||
| "french horn", | |||
| "ambulance", | |||
| "asparagus", | |||
| "hoverboard", | |||
| "pasta", | |||
| "target", | |||
| "hotair balloon", | |||
| "chainsaw", | |||
| "lobster", | |||
| "iron", | |||
| "flashlight", | |||
| ) | |||
| @@ -1,89 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import hashlib | |||
| import os | |||
| import tarfile | |||
| from ....distributed.util import is_distributed | |||
| from ....logger import get_logger | |||
| from ....utils.http_download import download_from_url | |||
| IMG_EXT = (".jpg", ".png", ".jpeg", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") | |||
| logger = get_logger(__name__) | |||
| def _default_dataset_root(): | |||
| default_dataset_root = os.path.expanduser( | |||
| os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "megengine") | |||
| ) | |||
| return default_dataset_root | |||
| def load_raw_data_from_url( | |||
| url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int | |||
| ): | |||
| cached_file = os.path.join(raw_data_dir, filename) | |||
| logger.debug( | |||
| "load_raw_data_from_url: downloading to or using cached %s ...", cached_file | |||
| ) | |||
| if not os.path.exists(cached_file): | |||
| if is_distributed(): | |||
| logger.warning( | |||
| "Downloading raw data in DISTRIBUTED mode\n" | |||
| " File may be downloaded multiple times. We recommend\n" | |||
| " users to download in single process first." | |||
| ) | |||
| md5 = download_from_url(url, cached_file, http_read_timeout=timeout) | |||
| else: | |||
| md5 = calculate_md5(cached_file) | |||
| if target_md5 == md5: | |||
| logger.debug("%s exists with correct md5: %s", filename, target_md5) | |||
| else: | |||
| os.remove(cached_file) | |||
| raise RuntimeError("{} exists but fail to match md5".format(filename)) | |||
| def calculate_md5(filename): | |||
| m = hashlib.md5() | |||
| with open(filename, "rb") as f: | |||
| while True: | |||
| data = f.read(4096) | |||
| if not data: | |||
| break | |||
| m.update(data) | |||
| return m.hexdigest() | |||
| def is_img(filename): | |||
| return filename.lower().endswith(IMG_EXT) | |||
| def untar(path, to=None, remove=False): | |||
| if to is None: | |||
| to = os.path.dirname(path) | |||
| with tarfile.open(path, "r") as tar: | |||
| tar.extractall(path=to) | |||
| if remove: | |||
| os.remove(path) | |||
| def untargz(path, to=None, remove=False): | |||
| if path.endswith(".tar.gz"): | |||
| if to is None: | |||
| to = os.path.dirname(path) | |||
| with tarfile.open(path, "r:gz") as tar: | |||
| tar.extractall(path=to) | |||
| else: | |||
| raise ValueError("path %s does not end with .tar" % path) | |||
| if remove: | |||
| os.remove(path) | |||
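Taken together, the helpers above form a download-verify-extract flow. A minimal usage sketch, assuming a hypothetical archive URL and md5 digest (not a real dataset endpoint):

import os

url = "https://example.com/dataset.tar.gz"      # hypothetical URL
md5 = "d41d8cd98f00b204e9800998ecf8427e"        # hypothetical digest
raw_dir = _default_dataset_root()
os.makedirs(raw_dir, exist_ok=True)
load_raw_data_from_url(url, "dataset.tar.gz", md5, raw_dir, timeout=300)
untargz(os.path.join(raw_dir, "dataset.tar.gz"))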
| @@ -1,185 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # --------------------------------------------------------------------- | |||
| # Part of the following code in this file references torchvision | |||
| # BSD 3-Clause License | |||
| # | |||
| # Copyright (c) Soumith Chintala 2016, | |||
| # All rights reserved. | |||
| # --------------------------------------------------------------------- | |||
| import collections.abc | |||
| import os | |||
| import xml.etree.ElementTree as ET | |||
| import cv2 | |||
| import numpy as np | |||
| from .meta_vision import VisionDataset | |||
| class PascalVOC(VisionDataset): | |||
| r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset. | |||
| """ | |||
| supported_order = ( | |||
| "image", | |||
| "boxes", | |||
| "boxes_category", | |||
| "mask", | |||
| "info", | |||
| ) | |||
| def __init__(self, root, image_set, *, order=None): | |||
| if ("boxes" in order or "boxes_category" in order) and "mask" in order: | |||
| raise ValueError( | |||
| "PascalVOC only supports boxes & boxes_category or mask, not both." | |||
| ) | |||
| super().__init__(root, order=order, supported_order=self.supported_order) | |||
| if not os.path.isdir(self.root): | |||
| raise RuntimeError("Dataset not found or corrupted.") | |||
| self.image_set = image_set | |||
| image_dir = os.path.join(self.root, "JPEGImages") | |||
| if "boxes" in order or "boxes_category" in order: | |||
| annotation_dir = os.path.join(self.root, "Annotations") | |||
| splitdet_dir = os.path.join(self.root, "ImageSets/Main") | |||
| split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt") | |||
| with open(os.path.join(split_f), "r") as f: | |||
| self.file_names = [x.strip() for x in f.readlines()] | |||
| self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names] | |||
| self.annotations = [ | |||
| os.path.join(annotation_dir, x + ".xml") for x in self.file_names | |||
| ] | |||
| assert len(self.images) == len(self.annotations) | |||
| elif "mask" in order: | |||
| if "aug" in image_set: | |||
| mask_dir = os.path.join(self.root, "SegmentationClass_aug") | |||
| else: | |||
| mask_dir = os.path.join(self.root, "SegmentationClass") | |||
| splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation") | |||
| split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt") | |||
| with open(os.path.join(split_f), "r") as f: | |||
| self.file_names = [x.strip() for x in f.readlines()] | |||
| self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names] | |||
| self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names] | |||
| assert len(self.images) == len(self.masks) | |||
| else: | |||
| raise NotImplementedError | |||
| self.img_infos = dict() | |||
| def __getitem__(self, index): | |||
| target = [] | |||
| image = None  # "info" may reference it even when "image" is not requested | |||
| for k in self.order: | |||
| if k == "image": | |||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
| target.append(image) | |||
| elif k == "boxes": | |||
| anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot()) | |||
| boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]] | |||
| # boxes type xyxy | |||
| boxes = [ | |||
| (bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"]) for bb in boxes | |||
| ] | |||
| boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||
| target.append(boxes) | |||
| elif k == "boxes_category": | |||
| anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot()) | |||
| boxes_category = [obj["name"] for obj in anno["annotation"]["object"]] | |||
| boxes_category = [ | |||
| self.class_names.index(bc) + 1 for bc in boxes_category | |||
| ] | |||
| boxes_category = np.array(boxes_category, dtype=np.int32) | |||
| target.append(boxes_category) | |||
| elif k == "mask": | |||
| if "aug" in self.image_set: | |||
| mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE) | |||
| else: | |||
| mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR) | |||
| mask = self._trans_mask(mask) | |||
| mask = mask[:, :, np.newaxis] | |||
| target.append(mask) | |||
| elif k == "info": | |||
| info = self.get_img_info(index, image) | |||
| info = [info["height"], info["width"], info["file_name"]] | |||
| target.append(info) | |||
| else: | |||
| raise NotImplementedError | |||
| return tuple(target) | |||
| def __len__(self): | |||
| return len(self.images) | |||
| def get_img_info(self, index, image=None): | |||
| if index not in self.img_infos: | |||
| if image is None: | |||
| image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
| self.img_infos[index] = dict( | |||
| height=image.shape[0], | |||
| width=image.shape[1], | |||
| file_name=self.file_names[index], | |||
| ) | |||
| return self.img_infos[index] | |||
| def _trans_mask(self, mask): | |||
| label = np.ones(mask.shape[:2]) * 255 | |||
| for i in range(len(self.class_colors)): | |||
| b, g, r = self.class_colors[i] | |||
| label[ | |||
| (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r) | |||
| ] = i | |||
| return label.astype(np.uint8) | |||
| def parse_voc_xml(self, node): | |||
| voc_dict = {} | |||
| children = list(node) | |||
| if children: | |||
| def_dic = collections.defaultdict(list) | |||
| for dc in map(self.parse_voc_xml, children): | |||
| for ind, v in dc.items(): | |||
| def_dic[ind].append(v) | |||
| if node.tag == "annotation": | |||
| def_dic["object"] = [def_dic["object"]] | |||
| voc_dict = { | |||
| node.tag: { | |||
| ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items() | |||
| } | |||
| } | |||
| if node.text: | |||
| text = node.text.strip() | |||
| if not children: | |||
| voc_dict[node.tag] = text | |||
| return voc_dict | |||
| class_names = ( | |||
| "aeroplane", | |||
| "bicycle", | |||
| "bird", | |||
| "boat", | |||
| "bottle", | |||
| "bus", | |||
| "car", | |||
| "cat", | |||
| "chair", | |||
| "cow", | |||
| "diningtable", | |||
| "dog", | |||
| "horse", | |||
| "motorbike", | |||
| "person", | |||
| "pottedplant", | |||
| "sheep", | |||
| "sofa", | |||
| "train", | |||
| "tvmonitor", | |||
| ) | |||
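A hedged usage sketch for the dataset above; the VOCdevkit path is a hypothetical local directory:

voc = PascalVOC(
    "/data/VOCdevkit/VOC2012",                  # hypothetical dataset root
    "train",
    order=["image", "boxes", "boxes_category", "info"],
)
image, boxes, boxes_category, info = voc[0]
print(boxes.shape)      # (N, 4), xyxy format
print(info)             # [height, width, file_name]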
| @@ -1,274 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections.abc | |||
| import math | |||
| from abc import ABC | |||
| from typing import Any, Generator, Iterator, List, Union | |||
| import numpy as np | |||
| import megengine.distributed as dist | |||
| class Sampler(ABC): | |||
| def __init__( | |||
| self, | |||
| dataset, | |||
| batch_size=1, | |||
| drop_last=False, | |||
| num_samples=None, | |||
| world_size=None, | |||
| rank=None, | |||
| seed=None, | |||
| ): | |||
| r""" | |||
| An abstract class for all samplers. | |||
| :type dataset: `dataset` | |||
| :param dataset: dataset to sample from | |||
| :type batch_size: positive integer | |||
| :param batch_size: batch size for batch method | |||
| :type drop_last: bool | |||
| :param drop_last: set ``True`` to drop the last incomplete batch, | |||
| if the dataset size is not divisible by the batch size. If ``False`` and | |||
| the size of dataset is not divisible by the batch_size, then the last batch will | |||
| be smaller. (default: ``False``) | |||
| :type num_samples: positive integer | |||
| :param num_samples: number of samples assigned to one rank | |||
| :type world_size: positive integer | |||
| :param world_size: number of ranks | |||
| :type rank: non-negative integer less than world_size | |||
| :param rank: rank id, a non-negative integer in [0, ``world_size``) | |||
| :type seed: non-negative integer | |||
| :param seed: seed for random operators | |||
| """ | |||
| if ( | |||
| not isinstance(batch_size, int) | |||
| or isinstance(batch_size, bool) | |||
| or batch_size <= 0 | |||
| ): | |||
| raise ValueError( | |||
| "batch_size should be a positive integer value, " | |||
| "but got batch_size={}".format(batch_size) | |||
| ) | |||
| if not isinstance(drop_last, bool): | |||
| raise ValueError( | |||
| "drop_last should be a boolean value, but got " | |||
| "drop_last={}".format(drop_last) | |||
| ) | |||
| if num_samples is not None and ( | |||
| not isinstance(num_samples, int) | |||
| or isinstance(num_samples, bool) | |||
| or num_samples <= 0 | |||
| ): | |||
| raise ValueError( | |||
| "num_samples should be a positive integer " | |||
| "value, but got num_samples={}".format(num_samples) | |||
| ) | |||
| self.batch_size = batch_size | |||
| self.dataset = dataset | |||
| self.drop_last = drop_last | |||
| if world_size is None: | |||
| world_size = dist.get_world_size() if dist.is_distributed() else 1 | |||
| self.world_size = world_size | |||
| if rank is None: | |||
| rank = dist.get_rank() if dist.is_distributed() else 0 | |||
| self.rank = rank | |||
| if num_samples is None: | |||
| num_samples = len(self.dataset) | |||
| self.num_samples = int(math.ceil(num_samples / self.world_size)) | |||
| # Make sure seeds are the same at each rank | |||
| if seed is None and self.world_size > 1: | |||
| seed = 0 | |||
| self.rng = np.random.RandomState(seed) | |||
| def __iter__(self) -> Union[Generator, Iterator]: | |||
| return self.batch() | |||
| def __len__(self) -> int: | |||
| if self.drop_last: | |||
| return self.num_samples // self.batch_size | |||
| else: | |||
| return int(math.ceil(self.num_samples / self.batch_size)) | |||
| def sample(self): | |||
| """ | |||
| Return a list containing all sample indices. | |||
| """ | |||
| raise NotImplementedError | |||
| def scatter(self, indices) -> List: | |||
| r""" | |||
| The scatter method splits indices into subsets; each subset | |||
| is assigned to one rank. Indices are split evenly by default. | |||
| If a customized assignment is needed, override this method. | |||
| """ | |||
| total_size = self.num_samples * self.world_size | |||
| # add extra indices to make it evenly divisible | |||
| indices += indices[: (total_size - len(indices))] | |||
| assert len(indices) == total_size | |||
| # subsample | |||
| indices = indices[self.rank : total_size : self.world_size] | |||
| assert len(indices) == self.num_samples | |||
| return indices | |||
| def batch(self) -> Iterator[List[Any]]: | |||
| r""" | |||
| batch method provides a batch indices generator | |||
| """ | |||
| indices = list(self.sample()) | |||
| # user might pass the world_size parameter without dist, | |||
| # so dist.is_distributed() should not be used | |||
| if self.world_size > 1: | |||
| indices = self.scatter(indices) | |||
| step, length = self.batch_size, len(indices) | |||
| batch_index = [indices[i : i + step] for i in range(0, length, step)] | |||
| if self.drop_last and len(batch_index[-1]) < self.batch_size: | |||
| batch_index.pop() | |||
| return iter(batch_index) | |||
| class SequentialSampler(Sampler): | |||
| def __init__( | |||
| self, | |||
| dataset, | |||
| batch_size=1, | |||
| drop_last=False, | |||
| indices=None, | |||
| world_size=None, | |||
| rank=None, | |||
| ): | |||
| r""" | |||
| Sample elements sequentially | |||
| """ | |||
| super().__init__(dataset, batch_size, drop_last, None, world_size, rank) | |||
| if indices is not None and not isinstance(indices, collections.abc.Sequence): | |||
| raise ValueError( | |||
| "indices should be None or a sequence, " | |||
| "but got indices={}".format(indices) | |||
| ) | |||
| self.indices = indices | |||
| def sample(self) -> Iterator[Any]: | |||
| r""" | |||
| return a generator | |||
| """ | |||
| if self.indices is None: | |||
| return iter(range(len(self.dataset))) | |||
| else: | |||
| return self.indices | |||
| class RandomSampler(Sampler): | |||
| def __init__( | |||
| self, | |||
| dataset, | |||
| batch_size=1, | |||
| drop_last=False, | |||
| indices=None, | |||
| world_size=None, | |||
| rank=None, | |||
| seed=None, | |||
| ): | |||
| r""" | |||
| Sample elements randomly without replacement | |||
| """ | |||
| super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed) | |||
| if indices is not None and not isinstance(indices, collections.abc.Sequence): | |||
| raise ValueError( | |||
| "indices should be None or a sequence, " | |||
| "but got indices={}".format(indices) | |||
| ) | |||
| self.indices = indices | |||
| def sample(self) -> List: | |||
| if self.indices is None: | |||
| return self.rng.permutation(len(self.dataset)).tolist() | |||
| else: | |||
| return self.rng.permutation(self.indices).tolist() | |||
| class ReplacementSampler(Sampler): | |||
| def __init__( | |||
| self, | |||
| dataset, | |||
| batch_size=1, | |||
| drop_last=False, | |||
| num_samples=None, | |||
| weights=None, | |||
| world_size=None, | |||
| rank=None, | |||
| seed=None, | |||
| ): | |||
| r""" | |||
| Sample elements randomly with replacement | |||
| :type weights: List | |||
| :param weights: weights for sampling indices; they may be unnormalized | |||
| """ | |||
| super().__init__( | |||
| dataset, batch_size, drop_last, num_samples, world_size, rank, seed | |||
| ) | |||
| if weights is not None: | |||
| if not isinstance(weights, collections.abc.Sequence): | |||
| raise ValueError( | |||
| "weights should be None or a sequence, " | |||
| "but got weights={}".format(weights) | |||
| ) | |||
| if len(weights) != len(dataset): | |||
| raise ValueError( | |||
| "len(dataset)={} should be equal to" | |||
| "len(weights)={}".format(len(dataset), len(weights)) | |||
| ) | |||
| self.weights = weights | |||
| if self.weights is not None: | |||
| self.weights = np.array(weights) / sum(weights) | |||
| def sample(self) -> List: | |||
| n = len(self.dataset) | |||
| if self.weights is None: | |||
| return self.rng.randint(n, size=self.num_samples).tolist() | |||
| else: | |||
| # multinomial returns draw counts, not indices; sample indices with choice | |||
| return self.rng.choice(n, size=self.num_samples, p=self.weights).tolist() | |||
| class Infinite(Sampler): | |||
| r"""Infinite Sampler warper for basic sampler""" | |||
| def sample(self): | |||
| raise NotImplementedError("sample method not supported in Infinite") | |||
| def __init__(self, sampler): | |||
| self.sampler = sampler | |||
| self.sampler_iter = iter(self.sampler) | |||
| def __iter__(self): | |||
| return self | |||
| def __next__(self): | |||
| try: | |||
| index = next(self.sampler_iter) | |||
| except StopIteration: | |||
| self.sampler_iter = iter(self.sampler) | |||
| index = next(self.sampler_iter) | |||
| return index | |||
| def __len__(self): | |||
| return np.iinfo(np.int64).max | |||
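A small sketch of how the samplers above batch indices, with a plain list standing in for a dataset:

dataset = list(range(10))

seq = SequentialSampler(dataset, batch_size=4)
print(list(seq))        # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

rnd = RandomSampler(dataset, batch_size=4, drop_last=True, seed=0)
print(list(rnd))        # two shuffled batches of 4; the remainder is dropped

inf = Infinite(seq)     # restarts the wrapped sampler on StopIteration
first_batch = next(inf)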
| @@ -1,10 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .meta_transform import PseudoTransform, Transform | |||
| from .vision import * | |||
| @@ -1,31 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Sequence, Tuple | |||
| class Transform(ABC): | |||
| """ | |||
| Rewrite the apply method in subclasses. | |||
| """ | |||
| def apply_batch(self, inputs: Sequence[Tuple]): | |||
| return tuple(self.apply(input) for input in inputs) | |||
| @abstractmethod | |||
| def apply(self, input: Tuple): | |||
| pass | |||
| def __repr__(self): | |||
| return self.__class__.__name__ | |||
| class PseudoTransform(Transform): | |||
| def apply(self, input: Tuple): | |||
| return input | |||
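A minimal sketch of a concrete Transform subclass, assuming each sample is a tuple whose first element is the image:

class ScaleImage(Transform):
    """Hypothetical transform: multiplies the image by a constant factor."""

    def __init__(self, factor=2.0):
        self.factor = factor

    def apply(self, input):
        image, *rest = input
        return (image * self.factor, *rest)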
| @@ -1,9 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .transform import * | |||
| @@ -1,111 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections.abc | |||
| import functools | |||
| import random | |||
| import cv2 | |||
| import numpy as np | |||
| def wrap_keepdims(func): | |||
| """Wraper to keep the dimension of input images unchanged""" | |||
| @functools.wraps(func) | |||
| def wrapper(image, *args, **kwargs): | |||
| if len(image.shape) != 3: | |||
| raise ValueError( | |||
| "image must have 3 dims, but got {} dims".format(len(image.shape)) | |||
| ) | |||
| ret = func(image, *args, **kwargs) | |||
| if len(ret.shape) == 2: | |||
| ret = ret[:, :, np.newaxis] | |||
| return ret | |||
| return wrapper | |||
| @wrap_keepdims | |||
| def to_gray(image): | |||
| r""" | |||
| Convert a BGR image to grayscale. | |||
| :param image: Input BGR format image, with (H, W, C) shape | |||
| :return: Gray format image, with (H, W, C) shape | |||
| """ | |||
| return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) | |||
| @wrap_keepdims | |||
| def to_bgr(image): | |||
| r""" | |||
| Convert a grayscale image to BGR. | |||
| :param image: input Gray format image, with (H, W, C) shape | |||
| :return: BGR format image, with (H, W, C) shape | |||
| """ | |||
| return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) | |||
| @wrap_keepdims | |||
| def pad(input, size, value): | |||
| r""" | |||
| Pad input data with *value* and given *size* | |||
| :param input: Input data, with (H, W, C) shape | |||
| :param size: Padding size of input data; it can be an integer or a sequence. | |||
| If it is an integer, the input data will be padded in all four directions. | |||
| If it is a sequence of two integers, the bottom and right sides | |||
| of the input data will be padded. | |||
| If it is a sequence of four integers, the top, bottom, left and right | |||
| sides of the input data will be padded with the given sizes. | |||
| :param value: Padding value of data; it can be a sequence of int or float. | |||
| If it is a float value, the dtype of the image will also be cast to float32. | |||
| :return: Padded image | |||
| """ | |||
| if isinstance(size, int): | |||
| size = (size, size, size, size) | |||
| elif isinstance(size, collections.abc.Sequence) and len(size) == 2: | |||
| size = (0, size[0], 0, size[1]) | |||
| if np.array(value).dtype == float: | |||
| input = input.astype(np.float32) | |||
| return cv2.copyMakeBorder(input, *size, cv2.BORDER_CONSTANT, value=value) | |||
| @wrap_keepdims | |||
| def flip(image, flipCode): | |||
| r""" | |||
| According to flipCode (the type of flip), flip the input image. | |||
| :param image: Input image, with (H, W, C) shape | |||
| :param flipCode: code that indicates the type of flip. | |||
| 1 : Flip horizontally | |||
| 0 : Flip vertically | |||
| -1 : Flip horizontally and vertically | |||
| :return: Flipped image, with (H, W, C) shape | |||
| """ | |||
| return cv2.flip(image, flipCode=flipCode) | |||
| @wrap_keepdims | |||
| def resize(input, size, interpolation=cv2.INTER_LINEAR): | |||
| r""" | |||
| Resize the input data to the given size. | |||
| :param input: Input data, could be image or masks, with (H, W, C) shape | |||
| :param size: Target size of input data, with (height, width) shape. | |||
| :param interpolation: Interpolation method. | |||
| :return: Resized data, with (H, W, C) shape | |||
| """ | |||
| if len(size) != 2: | |||
| raise ValueError("resize needs (h, w), but got {}".format(size)) | |||
| if isinstance(interpolation, collections.abc.Sequence): | |||
| interpolation = random.choice(interpolation) | |||
| return cv2.resize(input, size[::-1], interpolation=interpolation) | |||
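These helpers compose on (H, W, C) arrays; a short sketch with illustrative values:

import numpy as np

img = np.zeros((32, 48, 3), dtype=np.uint8)   # dummy BGR image
img = pad(img, (2, 2), value=0)               # pad bottom and right by 2
img = flip(img, flipCode=1)                   # horizontal flip
img = resize(img, (64, 96))                   # to height 64, width 96
gray = to_gray(img)                           # (64, 96, 1) thanks to wrap_keepdims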
| @@ -1,33 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .functional import ( | |||
| all_gather, | |||
| all_reduce_max, | |||
| all_reduce_min, | |||
| all_reduce_sum, | |||
| all_to_all, | |||
| bcast_param, | |||
| broadcast, | |||
| gather, | |||
| reduce_scatter_sum, | |||
| reduce_sum, | |||
| scatter, | |||
| ) | |||
| from .util import ( | |||
| get_backend, | |||
| get_free_ports, | |||
| get_master_ip, | |||
| get_master_port, | |||
| get_rank, | |||
| get_world_size, | |||
| group_barrier, | |||
| init_process_group, | |||
| is_distributed, | |||
| synchronized, | |||
| ) | |||
| @@ -1,302 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Optional, Union | |||
| import megengine._internal as mgb | |||
| from megengine._internal.opr_param_defs import CollectiveComm as Param | |||
| from ..core import Buffer, Parameter, Tensor, wrap_io_tensor | |||
| from ..functional import add_update | |||
| from .helper import collective_comm_symvar | |||
| from .util import get_rank, is_distributed | |||
| @wrap_io_tensor | |||
| def _collective_comm(*args, **kargs): | |||
| return collective_comm_symvar(*args, **kargs) | |||
| def _group_check(*args): | |||
| """Return True when arguments are all None or all not None | |||
| """ | |||
| l = [val is None for val in args] | |||
| return len(set(l)) <= 1 | |||
| def reduce_sum( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| ) -> Tensor: | |||
| """Create reduce_sum operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this is a root node | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, is_root | |||
| ), "key, nr_ranks, is_root should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device, | |||
| ) | |||
| def gather( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| rank: Optional[int] = None, | |||
| ) -> Tensor: | |||
| """Create gather operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this is a root node | |||
| :param rank: rank of this node | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, is_root, rank | |||
| ), "key, nr_ranks, is_root, rank should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device, | |||
| ) | |||
| def broadcast( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| ) -> Tensor: | |||
| """Create broadcast operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this is a root node | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, is_root | |||
| ), "key, nr_ranks, is_root should be set at the same time" | |||
| if is_root is None: | |||
| is_root = get_rank() == 0 | |||
| if is_root: | |||
| inp = tensor | |||
| else: | |||
| inp = tensor._symvar.owner_graph | |||
| return _collective_comm( | |||
| inp, | |||
| key, | |||
| Param.Mode.BROADCAST, | |||
| nr_ranks, | |||
| is_root, | |||
| dtype=tensor.dtype, | |||
| device=tensor.device, | |||
| ) | |||
| def scatter( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| rank: Optional[int] = None, | |||
| ) -> Tensor: | |||
| """Create scatter operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this is a root node | |||
| :param rank: rank of this node | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, is_root, rank | |||
| ), "key, nr_ranks, is_root, rank should be set at the same time" | |||
| if key is None: | |||
| key = tensor._symvar.name | |||
| if is_root is None: | |||
| is_root = get_rank() == 0 | |||
| if is_root: | |||
| inp = tensor | |||
| else: | |||
| inp = tensor._symvar.owner_graph | |||
| return _collective_comm( | |||
| inp, | |||
| key, | |||
| Param.Mode.SCATTER, | |||
| nr_ranks, | |||
| is_root, | |||
| rank, | |||
| dtype=tensor.dtype, | |||
| device=tensor.device, | |||
| ) | |||
| def all_to_all( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| rank: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_to_all operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param rank: rank of this node | |||
| :param local_grad: whether to use local gradient | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, rank | |||
| ), "key, nr_ranks, rank should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad, | |||
| ) | |||
| def all_gather( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| rank: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_gather operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param rank: rank of this node | |||
| :param local_grad: whether to use local gradient | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, rank | |||
| ), "key, nr_ranks, rank should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad | |||
| ) | |||
| def reduce_scatter_sum( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| rank: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create reduce_scatter_sum operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param rank: rank of this node | |||
| :param local_grad: whether to use local gradient | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, rank | |||
| ), "key, nr_ranks, rank should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, | |||
| key, | |||
| Param.Mode.REDUCE_SCATTER_SUM, | |||
| nr_ranks, | |||
| rank=rank, | |||
| local_grad=local_grad, | |||
| ) | |||
| def all_reduce_sum( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_reduce_sum operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param local_grad: whether to use local gradient | |||
| """ | |||
| assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad | |||
| ) | |||
| def all_reduce_max( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_reduce_max operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param local_grad: whether to use local gradient | |||
| """ | |||
| assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad | |||
| ) | |||
| def all_reduce_min( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_reduce_min operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param local_grad: whether to use local gradient | |||
| """ | |||
| assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad | |||
| ) | |||
| def bcast_param( | |||
| inp: Union[Buffer, Parameter], | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| ) -> None: | |||
| """Broadcast parameters among devices | |||
| :param inp: input Buffer or Parameter to be synchronized | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this is a root node | |||
| """ | |||
| if not is_distributed(): | |||
| return | |||
| assert _group_check( | |||
| key, nr_ranks, is_root | |||
| ), "key, nr_ranks, is_root should be set at the same time" | |||
| assert isinstance(inp, (Buffer, Parameter)) | |||
| bcast_res = broadcast(inp, key, nr_ranks, is_root) | |||
| add_update(inp, bcast_res, alpha=0) | |||
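A sketch of the usual data-parallel pattern built from these operators, assuming the process group is already initialized and that `grad` and `model` exist on every rank:

import megengine.distributed as dist

grad_sum = all_reduce_sum(grad)               # sum gradients across ranks
grad_avg = grad_sum / dist.get_world_size()   # then average them

for p in model.parameters():                  # keep parameters in sync
    bcast_param(p)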
| @@ -1,63 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Optional, Union | |||
| import megengine._internal as mgb | |||
| from megengine._internal.opr_param_defs import CollectiveComm as CollParam | |||
| from .util import ( | |||
| get_backend, | |||
| get_group_id, | |||
| get_master_ip, | |||
| get_master_port, | |||
| get_rank, | |||
| get_world_size, | |||
| ) | |||
| def collective_comm_symvar( | |||
| inp: Union[mgb.SymbolVar, mgb.CompGraph], | |||
| key: Optional[str] = None, | |||
| op: CollParam.Mode = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| rank: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| dtype: Optional[type] = None, | |||
| device: Optional[mgb.CompNode] = None, | |||
| comp_graph: Optional[mgb.CompGraph] = None, | |||
| ) -> mgb.SymbolVar: | |||
| """Helper function for creating collective_comm operators | |||
| :param inp: tensor or comp_graph | |||
| :param key: unique identifier for collective communication | |||
| :param op: mode of collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this node is the root node | |||
| :param rank: rank of this node | |||
| :param local_grad: whether to use local gradient | |||
| :param dtype: output data type, use dtype of inp as default | |||
| :param device: output comp node, use comp node of inp as default | |||
| :param comp_graph: output comp graph, use comp graph of inp as default | |||
| """ | |||
| return mgb.opr.collective_comm( | |||
| inp, | |||
| key=key if key is not None else ("collective_comm_" + str(get_group_id())), | |||
| nr_devices=nr_ranks if nr_ranks is not None else get_world_size(), | |||
| is_root=is_root if is_root is not None else (get_rank() == 0), | |||
| rank=rank if rank is not None else get_rank(), | |||
| local_grad=local_grad, | |||
| server_addr=get_master_ip(), | |||
| port=get_master_port(), | |||
| param=CollParam(mode=op), | |||
| dtype=dtype, | |||
| backend=get_backend(), | |||
| comp_node=device, | |||
| comp_graph=comp_graph, | |||
| ) | |||
| @@ -1,146 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import socket | |||
| from typing import Callable, List, Optional | |||
| import megengine._internal as mgb | |||
| from ..core import set_default_device | |||
| _master_ip = None | |||
| _master_port = 0 | |||
| _world_size = 0 | |||
| _rank = 0 | |||
| _backend = None | |||
| _group_id = 0 | |||
| def init_process_group( | |||
| master_ip: str, | |||
| master_port: int, | |||
| world_size: int, | |||
| rank: int, | |||
| dev: int, | |||
| backend: Optional[str] = "nccl", | |||
| ) -> None: | |||
| """Initialize the distributed process group, and also specify the device used in the current process. | |||
| :param master_ip: IP address of the master node. | |||
| :param master_port: Port available for all processes to communicate. | |||
| :param world_size: Total number of processes participating in the job. | |||
| :param rank: Rank of the current process. | |||
| :param dev: The GPU device id to bind this process to. | |||
| :param backend: Communicator backend; currently supports 'nccl' and 'ucx' | |||
| """ | |||
| global _master_ip # pylint: disable=global-statement | |||
| global _master_port # pylint: disable=global-statement | |||
| global _world_size # pylint: disable=global-statement | |||
| global _rank # pylint: disable=global-statement | |||
| global _backend # pylint: disable=global-statement | |||
| global _group_id # pylint: disable=global-statement | |||
| if not isinstance(master_ip, str): | |||
| raise TypeError("Expect type str but got {}".format(type(master_ip))) | |||
| if not isinstance(master_port, int): | |||
| raise TypeError("Expect type int but got {}".format(type(master_port))) | |||
| if not isinstance(world_size, int): | |||
| raise TypeError("Expect type int but got {}".format(type(world_size))) | |||
| if not isinstance(rank, int): | |||
| raise TypeError("Expect type int but got {}".format(type(rank))) | |||
| if not isinstance(backend, str): | |||
| raise TypeError("Expect type str but got {}".format(type(backend))) | |||
| _master_ip = master_ip | |||
| _master_port = master_port | |||
| _world_size = world_size | |||
| _rank = rank | |||
| _backend = backend | |||
| _group_id = 0 | |||
| set_default_device(mgb.comp_node("gpu" + str(dev))) | |||
| if rank == 0: | |||
| _master_port = mgb.config.create_mm_server("0.0.0.0", master_port) | |||
| if _master_port == -1: | |||
| raise Exception("Failed to start server on port {}".format(master_port)) | |||
| else: | |||
| assert master_port > 0, "master_port must be specified for non-zero rank" | |||
| def is_distributed() -> bool: | |||
| """Return True if the distributed process group has been initialized""" | |||
| return _world_size is not None and _world_size > 1 | |||
| def get_master_ip() -> str: | |||
| """Get the IP address of the master node""" | |||
| return str(_master_ip) | |||
| def get_master_port() -> int: | |||
| """Get the port of the rpc server on the master node""" | |||
| return _master_port | |||
| def get_world_size() -> int: | |||
| """Get the total number of processes participating in the job""" | |||
| return _world_size | |||
| def get_rank() -> int: | |||
| """Get the rank of the current process""" | |||
| return _rank | |||
| def get_backend() -> str: | |||
| """Get the backend str""" | |||
| return str(_backend) | |||
| def get_group_id() -> int: | |||
| """Get group id for collective communication""" | |||
| global _group_id | |||
| _group_id += 1 | |||
| return _group_id | |||
| def group_barrier() -> None: | |||
| """Block until all ranks in the group reach this barrier""" | |||
| mgb.config.group_barrier(_master_ip, _master_port, _world_size, _rank) | |||
| def synchronized(func: Callable): | |||
| """Decorator. Decorated function will synchronize when finished. | |||
| Specifically, we use this to prevent data race during hub.load""" | |||
| @functools.wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| if not is_distributed(): | |||
| return func(*args, **kwargs) | |||
| ret = func(*args, **kwargs) | |||
| group_barrier() | |||
| return ret | |||
| return wrapper | |||
| def get_free_ports(num: int) -> List[int]: | |||
| """Get one or more free ports. | |||
| """ | |||
| socks, ports = [], [] | |||
| for i in range(num): | |||
| sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||
| sock.bind(("", 0)) | |||
| socks.append(sock) | |||
| ports.append(sock.getsockname()[1]) | |||
| for sock in socks: | |||
| sock.close() | |||
| return ports | |||
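A single-machine launch sketch tying these utilities together; the worker body is a placeholder:

import multiprocessing as mp

def worker(rank, world_size, port):
    # one GPU per process; dev=rank is an assumption for a single machine
    init_process_group("localhost", port, world_size, rank, dev=rank)
    # ... build the model, bcast_param, run the training loop ...

if __name__ == "__main__":
    world_size = 2
    port = get_free_ports(1)[0]
    procs = [
        mp.Process(target=worker, args=(r, world_size, port))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()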
| @@ -1,118 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=redefined-builtin | |||
| from .elemwise import ( | |||
| abs, | |||
| add, | |||
| arccos, | |||
| arcsin, | |||
| ceil, | |||
| clamp, | |||
| cos, | |||
| divide, | |||
| equal, | |||
| exp, | |||
| floor, | |||
| greater, | |||
| greater_equal, | |||
| isinf, | |||
| isnan, | |||
| less, | |||
| less_equal, | |||
| log, | |||
| maximum, | |||
| minimum, | |||
| mod, | |||
| multiply, | |||
| power, | |||
| relu, | |||
| round, | |||
| sigmoid, | |||
| sin, | |||
| subtract, | |||
| tanh, | |||
| ) | |||
| from .graph import add_extra_vardep, add_update, grad | |||
| from .loss import ( | |||
| binary_cross_entropy, | |||
| cross_entropy, | |||
| cross_entropy_with_softmax, | |||
| hinge_loss, | |||
| l1_loss, | |||
| nll_loss, | |||
| smooth_l1_loss, | |||
| square_loss, | |||
| triplet_margin_loss, | |||
| ) | |||
| from .math import ( | |||
| argmax, | |||
| argmin, | |||
| logsumexp, | |||
| max, | |||
| mean, | |||
| min, | |||
| norm, | |||
| normalize, | |||
| prod, | |||
| sqrt, | |||
| sum, | |||
| ) | |||
| from .nn import ( | |||
| assert_equal, | |||
| avg_pool2d, | |||
| batch_norm2d, | |||
| batched_matrix_mul, | |||
| conv2d, | |||
| conv_transpose2d, | |||
| dropout, | |||
| embedding, | |||
| eye, | |||
| flatten, | |||
| identity, | |||
| indexing_one_hot, | |||
| interpolate, | |||
| leaky_relu, | |||
| linear, | |||
| local_conv2d, | |||
| matrix_mul, | |||
| max_pool2d, | |||
| one_hot, | |||
| prelu, | |||
| remap, | |||
| roi_align, | |||
| roi_pooling, | |||
| softmax, | |||
| softplus, | |||
| sync_batch_norm, | |||
| warp_perspective, | |||
| ) | |||
| from .quantized import conv_bias_activation | |||
| from .sort import argsort, sort, top_k | |||
| from .tensor import ( | |||
| add_axis, | |||
| arange, | |||
| broadcast_to, | |||
| concat, | |||
| cond_take, | |||
| dimshuffle, | |||
| gather, | |||
| linspace, | |||
| remove_axis, | |||
| reshape, | |||
| scatter, | |||
| shapeof, | |||
| transpose, | |||
| where, | |||
| zeros_like, | |||
| ) | |||
| from .utils import accuracy, zero_grad | |||
| # delete namespace | |||
| # pylint: disable=undefined-variable | |||
| del elemwise, graph, loss, math, nn, quantized, sort, tensor, utils # type: ignore[name-defined] | |||
| @@ -1,49 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| _conv_execution_strategy = os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY", "HEURISTIC") | |||
| def get_conv_execution_strategy() -> str: | |||
| """Returns the execuation strategy of :class:`~.Conv2d`. | |||
| See :func:`~.set_conv_execution_strategy` for possible return values | |||
| """ | |||
| return _conv_execution_strategy | |||
| def set_conv_execution_strategy(option: str): | |||
| """Sets the execuation strategy of :class:`~.Conv2d`. | |||
| :param option: Decides how :class:`~.Conv2d` algorithm is chosen. | |||
| Available values: | |||
| * 'HEURISTIC' uses heuristic to choose the fastest algorithm. | |||
| * 'PROFILE' runs possible algorithms on real device to find the best. | |||
| * 'PROFILE_HEURISTIC' uses profile result and heuristic to choose the fastest algorithm. | |||
| * 'PROFILE_REPRODUCIBLE' uses the fastest of profile result that is also reproducible. | |||
| * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. | |||
| The default strategy is 'HEURISTIC'. | |||
| It can also be set through the environmental variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'. | |||
| """ | |||
| valid_option = ( | |||
| "HEURISTIC", | |||
| "PROFILE", | |||
| "PROFILE_HEURISTIC", | |||
| "PROFILE_REPRODUCIBLE", | |||
| "HEURISTIC_REPRODUCIBLE", | |||
| ) | |||
| if option not in valid_option: | |||
| raise ValueError("Valid option can only be one of {}".format(valid_option)) | |||
| global _conv_execution_strategy # pylint: disable=global-statement | |||
| _conv_execution_strategy = option | |||
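Usage is a single call before building the network; a short sketch:

set_conv_execution_strategy("PROFILE")            # profile once on the real device
assert get_conv_execution_strategy() == "PROFILE"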
| @@ -1,299 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
| import functools | |||
| import megengine._internal as mgb | |||
| from ..core.graph import _use_default_if_none | |||
| from ..core.tensor import Tensor, wrap_io_tensor | |||
| __all__ = [ | |||
| "abs", | |||
| "arccos", | |||
| "add", | |||
| "arcsin", | |||
| "clamp", | |||
| "ceil", | |||
| "cos", | |||
| "divide", | |||
| "equal", | |||
| "exp", | |||
| "greater", | |||
| "greater_equal", | |||
| "floor", | |||
| "isinf", | |||
| "isnan", | |||
| "less", | |||
| "less_equal", | |||
| "log", | |||
| "maximum", | |||
| "minimum", | |||
| "mod", | |||
| "multiply", | |||
| "power", | |||
| "relu", | |||
| "round", | |||
| "sigmoid", | |||
| "sin", | |||
| "subtract", | |||
| "tanh", | |||
| ] | |||
| def _elemwise(mode): # DONT export | |||
| """Decorator helps to wrap megbrain element-wise oprs""" | |||
| def elemwise_decorator(func): | |||
| @functools.wraps(func) | |||
| @wrap_io_tensor | |||
| def elemwise_func(*inputs) -> Tensor: | |||
| if all(isinstance(i, (int, float)) for i in inputs): | |||
| device, comp_graph = _use_default_if_none(None, None) | |||
| ret = mgb.opr.elemwise( | |||
| *inputs, mode=mode, comp_node=device, comp_graph=comp_graph | |||
| ) | |||
| return ret.inferred_value[0] | |||
| return mgb.opr.elemwise(*inputs, mode=mode) | |||
| return elemwise_func | |||
| return elemwise_decorator | |||
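As a sketch of how the decorator is used, a hypothetical wrapper ("NEGATE" is used for illustration and is not confirmed against the megbrain mode list):

@_elemwise("NEGATE")    # hypothetical mode name
def negate(x):
    """Return -x element-wise."""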
| @_elemwise("ABS") | |||
| def abs(x): | |||
| """Calculate the absolute value element-wise.""" | |||
| @_elemwise("ACOS") | |||
| def arccos(x): | |||
| """Inverse cosine, element-wise.""" | |||
| @_elemwise("ADD") | |||
| def add(x, y): | |||
| """Element-wise addition.""" | |||
| @_elemwise("ASIN") | |||
| def arcsin(x): | |||
| """Inverse sine, element-wise.""" | |||
| @_elemwise("CEIL") | |||
| def ceil(x): | |||
| """Return the ceil of the input, element-wise.""" | |||
| @_elemwise("COS") | |||
| def cos(x): | |||
| """Cosine, element-wise.""" | |||
| @_elemwise("TRUE_DIV") | |||
| def divide(x, y): | |||
| """Return (x / y) element-wise.""" | |||
| @_elemwise("EQ") | |||
| def equal(x, y): | |||
| """Return (x == y) element-wise.""" | |||
| @_elemwise("EXP") | |||
| def exp(x): | |||
| """Calculate the exponential element-wise""" | |||
| @_elemwise("FLOOR") | |||
| def floor(x): | |||
| """Return the floor of the input, element-wise""" | |||
| def greater(x, y): | |||
| """Return (x > y) element-wise.""" | |||
| return less(y, x) | |||
| def greater_equal(x, y): | |||
| """Return (x >= y) element-wise""" | |||
| return less_equal(y, x) | |||
| @_elemwise("LT") | |||
| def less(x, y): | |||
| """Return (x < y) element-wise.""" | |||
| @_elemwise("LEQ") | |||
| def less_equal(x, y): | |||
| """Return (x =< y) element-wise.""" | |||
| @_elemwise("LOG") | |||
| def log(x): | |||
| """Natural logarithm (base `e`), element-wise.""" | |||
| @_elemwise("MAX") | |||
| def maximum(x, y): | |||
| """Element-wise maximum of array elements.""" | |||
| @_elemwise("MIN") | |||
| def minimum(x, y): | |||
| """Element-wise minimum of array elements.""" | |||
| @_elemwise("MOD") | |||
| def mod(x, y): | |||
| """Return element-wise remainder of division.""" | |||
| @_elemwise("MUL") | |||
| def multiply(x, y): | |||
| """Element-wise multiplication.""" | |||
| @_elemwise("POW") | |||
| def power(x, y): | |||
| """First tensor elements raised to powers from second tensor (x ** y), element-wise.""" | |||
| @_elemwise("RELU") | |||
| def relu(x): | |||
| """Return `max(x, 0)` element-wise.""" | |||
| @_elemwise("ROUND") | |||
| def round(x): | |||
| """Round tensor to int element-wise.""" | |||
| @_elemwise("SIGMOID") | |||
| def sigmoid(x): | |||
| """Return 1 / ( 1 + exp( -x ) ) element-wise.""" | |||
| @_elemwise("SIN") | |||
| def sin(x): | |||
| """Sine, element-wise.""" | |||
| @_elemwise("SUB") | |||
| def subtract(x, y): | |||
| """Subtract arguments element-wise""" | |||
| @_elemwise("TANH") | |||
| def tanh(x): | |||
| """Compute hyperbolic tangent element-wise.""" | |||
| @wrap_io_tensor | |||
| def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: | |||
| r""" | |||
| Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return | |||
| a resulting tensor: | |||
| .. math:: | |||
| y_i = \begin{cases} | |||
| \text{lower} & \text{if } x_i < \text{lower} \\ | |||
| x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\ | |||
| \text{upper} & \text{if } x_i > \text{upper} | |||
| \end{cases} | |||
| :param inp: the input tensor. | |||
| :param lower: lower-bound of the range to be clamped to | |||
| :param upper: upper-bound of the range to be clamped to | |||
| Example: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| a = tensor(np.arange(5).astype(np.int32)) | |||
| print(F.clamp(a, 2, 4).numpy()) | |||
| print(F.clamp(a, lower=3).numpy()) | |||
| print(F.clamp(a, upper=3).numpy()) | |||
| .. testoutput:: | |||
| [2 2 2 3 4] | |||
| [3 3 3 3 4] | |||
| [0 1 2 3 3] | |||
| """ | |||
| assert ( | |||
| lower is not None or upper is not None | |||
| ), "At least one of 'lower' or 'upper' must not be None" | |||
| if lower is not None: | |||
| if upper is not None: | |||
| assert lower <= upper, "clamp lower bound is bigger that upper bound" | |||
| return minimum(maximum(inp, lower), upper) | |||
| else: | |||
| return maximum(inp, lower) | |||
| else: | |||
| return minimum(inp, upper) | |||
| def isnan(inp: Tensor) -> Tensor: | |||
| r"""Returns a new tensor representing if each element is NaN or not. | |||
| :param inp: input tensor | |||
| :return: a new tensor representing if each element in :attr:`inp` is NaN or not. | |||
| Examples: | |||
| .. testcode:: | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor([1, float("nan"), 0]) | |||
| print(F.isnan(x)) | |||
| .. testoutput:: | |||
| Tensor([0 1 0], dtype=uint8) | |||
| """ | |||
| return (inp != inp).astype("uint8") | |||
| def isinf(inp: Tensor) -> Tensor: | |||
| r"""Returns a new tensor representing if each element is Inf or not. | |||
| :param inp: the input tensor. | |||
| :return: a new tensor representing if each element in :attr:`inp` is Inf or not. | |||
| Examples: | |||
| .. testcode:: | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor([1, float("inf"), 0]) | |||
| print(F.isinf(x)) | |||
| .. testoutput:: | |||
| Tensor([0 1 0], dtype=uint8) | |||
| """ | |||
| return (abs(inp) == float("inf")).astype("uint8") | |||
| @@ -1,65 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=too-many-lines | |||
| from typing import List | |||
| import megengine._internal as mgb | |||
| from ..core import Tensor, wrap_io_tensor | |||
| @wrap_io_tensor | |||
| def cambricon_subgraph( | |||
| inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool, | |||
| ) -> List[Tensor]: | |||
| """Load a serialized Cambricon subgraph (i.e. cnrtModel_t) and | |||
| execute the operations defined in the subgraph. | |||
| :param inputs: List of input tensors of the subgraph. | |||
| :param data: The serialized subgraph. | |||
| :param symbol: The name of the function in the subgraph. | |||
| The function corresponds to a cnmlFusionOp | |||
| which is added to the cnmlModel_t/cnrtModel_t. | |||
| :param tensor_dim_mutable: Whether the input tensors' shapes are mutable | |||
| in cnrtModel_t. | |||
| """ | |||
| return mgb.opr.cambricon_runtime( | |||
| data, symbol, tuple(map(lambda x: x._symvar, inputs)), tensor_dim_mutable | |||
| ) | |||
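| # A minimal usage sketch (the file name and symbol below are hypothetical; | |||
| # both come from the Cambricon toolchain that produced the model): | |||
| # | |||
| #   with open("model.cambricon", "rb") as f: | |||
| #       data = f.read() | |||
| #   outs = cambricon_subgraph([inp0, inp1], data, "subnet0", tensor_dim_mutable=True) | |||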
| @wrap_io_tensor | |||
| def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]: | |||
| """Load a serialized Atlas subgraph (i.e. om model) and | |||
| execute the operations defined in the subgraph. | |||
| :param inputs: List of input tensors of the subgraph. | |||
| :param data: The serialized subgraph. | |||
| """ | |||
| return mgb.opr.atlas_runtime(tuple(map(lambda x: x._symvar, inputs)), data) | |||
| @wrap_io_tensor | |||
| def extern_opr_subgraph( | |||
| inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, | |||
| ) -> List[Tensor]: | |||
| """Load a serialized extern opr subgraph and fake execute the operator | |||
| :param inputs: Tensor or list of input tensors. | |||
| :param output_shapes: The output shapes. | |||
| :param dump_name: The serialized subgraph name. | |||
| :param dump_data: The serialized subgraph. | |||
| :return: List of tensors | |||
| """ | |||
| if not isinstance(inputs, list): | |||
| inputs = [inputs] | |||
| return mgb.opr.extern_c_opr_placeholder( | |||
| inputs, output_shapes, dump_name=dump_name, dump_data=dump_data, | |||
| ) | |||
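| # A usage sketch, assuming ``dump_data`` holds an extern opr serialized | |||
| # elsewhere (the opr name and output shape are illustrative only): | |||
| #   outs = extern_opr_subgraph(inp, [(1, 1000)], "my_extern_opr", dump_data) | |||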
| @@ -1,125 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections.abc | |||
| from typing import Iterable, Optional, Union | |||
| import megengine._internal as mgb | |||
| from ..core.graph import get_default_graph | |||
| from ..core.tensor import Tensor, wrap_io_tensor | |||
| from ..jit import barrier, mark_impure, trace | |||
| @wrap_io_tensor | |||
| def grad( | |||
| target: Tensor, | |||
| wrt: Union[Tensor, Iterable[Tensor]], | |||
| warn_mid_wrt: bool = True, | |||
| use_virtual_grad: bool = None, | |||
| return_zero_for_nodep: bool = True, | |||
| ) -> Union[Tensor, Iterable[Optional[Tensor]], None]: | |||
| r"""Compute the symbolic gradient of ``target`` with repect to ``wrt``. | |||
| ``wrt`` can either be a single tensor or a sequence of tensors. | |||
| :param target: ``grad`` target tensor | |||
| :param wrt: with respect to which to compute the gradient | |||
| :param warn_mid_wrt: whether to give warning if ``wrt`` is not endpoint | |||
| :param use_virtual_grad: whether to use virtual ``grad`` opr, so fwd graph can | |||
| be optimized before applying ``grad``; if ``None`` is given, then virtual | |||
| ``grad`` would be used if ``graph_opt_level >= 2`` | |||
| :param return_zero_for_nodep: if ``target`` does not depend on ``wrt``, set to True to return | |||
| a zero-valued :class:`~.Tensor` rather than ``None``; can't be set to False when using | |||
| virtual ``grad`` opr. | |||
| :return: :math:`\partial\text{target} / \partial\text{wrt}` | |||
| """ | |||
| if not isinstance(wrt, mgb.SymbolVar): | |||
| assert isinstance(wrt, collections.abc.Iterable) | |||
| wrt = [w._symvar for w in wrt] | |||
| return mgb.grad(target, wrt, warn_mid_wrt, use_virtual_grad, return_zero_for_nodep) | |||
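| # A usage sketch, assuming ``loss`` and the parameters are tensors built in | |||
| # the same graph: | |||
| #   dw = grad(loss, w)          # dloss/dw, a single tensor | |||
| #   dws = grad(loss, [w1, w2])  # one gradient per tensor in the sequence | |||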
| _add_update_cache = {} # type: dict | |||
| _dummy = mgb.SharedScalar(0) | |||
| def add_update( | |||
| dest: Tensor, | |||
| delta: Tensor, | |||
| *, | |||
| alpha: Union[Tensor, float, int] = 1.0, | |||
| beta: Union[Tensor, float, int] = 1.0, | |||
| bias: Union[Tensor, float, int] = 0.0 | |||
| ): | |||
| r"""Inplace modify ``dest`` as follows: | |||
| .. math:: | |||
| dest = alpha * dest + beta * delta + bias | |||
| :param dest: input data that will be inplace modified. | |||
| :param delta: update value that will be added to ``dest``. | |||
| :param alpha: weight ratio of ``dest``. Default: 1.0 | |||
| :param beta: weight ratio of ``delta``. Default: 1.0 | |||
| :param bias: bias value appended to the result. Default: 0.0 | |||
| """ | |||
| if isinstance(beta, Tensor) or isinstance(alpha, Tensor): | |||
| delta *= beta | |||
| beta = 1.0 | |||
| if isinstance(alpha, Tensor): | |||
| delta += (alpha - 1.0) * dest | |||
| alpha = 1.0 | |||
| if isinstance(bias, Tensor): | |||
| delta += bias | |||
| bias = 0.0 | |||
| comp_graph = dest._comp_graph or get_default_graph() | |||
| comp_node = dest._comp_node | |||
| if not isinstance(delta, Tensor): | |||
| _delta = mgb.make_immutable( | |||
| value=delta, comp_node=comp_node, comp_graph=comp_graph | |||
| ) | |||
| else: | |||
| _delta = delta._attach(comp_graph) | |||
| _dest = dest._attach(comp_graph) | |||
| # use (dest, delta) as the key, so that the same delta is not added to dest twice in a static graph | |||
| key = (comp_graph._id(), _dest.id, _delta.id) | |||
| if key in _add_update_cache: | |||
| _alpha, _beta, _bias, config = _add_update_cache[key] | |||
| mgb.mgb._mgb.SharedScalar__set(_alpha, alpha) | |||
| mgb.mgb._mgb.SharedScalar__set(_beta, beta) | |||
| mgb.mgb._mgb.SharedScalar__set(_bias, bias) | |||
| else: | |||
| _alpha = mgb.SharedScalar(alpha) | |||
| _beta = mgb.SharedScalar(beta) | |||
| _bias = mgb.SharedScalar(bias) | |||
| config = mgb.helper.gen_config(None, comp_node, None) | |||
| _add_update_cache[key] = (_alpha, _beta, _bias, config) | |||
| u = mgb.mgb._Opr.add_update( | |||
| _dest, barrier(_delta), _alpha, _beta, _bias, _dummy, config | |||
| ) | |||
| mark_impure(u) | |||
| if trace._active_instance: | |||
| dest._override_symvar_during_trace(trace._active_instance, u) | |||
| return Tensor(u) | |||
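| # A usage sketch of the rule dest = alpha * dest + beta * delta + bias, e.g. | |||
| # a plain SGD step with a (hypothetical) learning rate of 0.01: | |||
| #   add_update(param, param_grad, alpha=1.0, beta=-0.01)  # param -= 0.01 * grad | |||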
| @wrap_io_tensor | |||
| def add_extra_vardep(oup: Tensor, dep: Tensor): | |||
| r"""Explicitly set the dependency that tensor ``oup`` depends on tensor ``dep``. | |||
| """ | |||
| return mgb.config.add_extra_vardep(oup, dep) | |||
| @@ -1,391 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import megengine._internal as mgb | |||
| from ..core.tensor import Tensor | |||
| from .elemwise import abs, equal, log, maximum, power, relu | |||
| from .nn import assert_equal, indexing_one_hot | |||
| from .tensor import where | |||
| from .utils import zero_grad | |||
| def l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
| r""" | |||
| Calculates the mean absolute error (MAE) between | |||
| each element in the pred :math:`x` and label :math:`y`. | |||
| The mean absolute error can be described as: | |||
| .. math:: \ell(x,y) = mean\left(L \right) | |||
| where | |||
| .. math:: | |||
| L = \{l_1,\dots,l_N\}, \quad | |||
| l_n = \left| x_n - y_n \right|, | |||
| :math:`x` and :math:`y` are tensors of arbitrary shapes with a total | |||
| of :math:`N` elements each. :math:`N` is the batch size. | |||
| :param pred: The predicted result from model. | |||
| :param label: The ground truth to compare. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32)) | |||
| tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32)) | |||
| loss = F.l1_loss(ipt,tgt) | |||
| print(loss.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [2.75] | |||
| """ | |||
| diff = pred - label | |||
| return abs(diff).mean() | |||
| def square_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
| r""" | |||
| Calculates the mean squared error (squared L2 norm) between | |||
| each element in the pred :math:`x` and label :math:`y`. | |||
| The mean squared error can be described as: | |||
| .. math:: \ell(x, y) = mean\left( L \right) | |||
| where | |||
| .. math:: | |||
| L = \{l_1,\dots,l_N\}, \quad | |||
| l_n = \left( x_n - y_n \right)^2, | |||
| :math:`x` and :math:`y` are tensors of arbitrary shapes with a total | |||
| of :math:`N` elements each. :math:`N` is the batch size. | |||
| :param pred: The predicted result from model. | |||
| :param label: The ground truth to compare. | |||
| Shape: | |||
| - pred: :math:`(N, *)` where :math:`*` means any number of additional | |||
| dimensions | |||
| - label: :math:`(N, *)`. Same shape as ``pred`` | |||
| """ | |||
| diff = pred - label | |||
| return (diff ** 2).mean() | |||
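| # A small worked check: for pred = [3, 3, 3, 3] and label = [2, 8, 6, 1] the | |||
| # squared differences are [1, 25, 9, 4], so the loss is their mean, 9.75: | |||
| #   loss = square_loss(mge.tensor([3., 3., 3., 3.]), mge.tensor([2., 8., 6., 1.])) | |||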
| def cross_entropy( | |||
| inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1 | |||
| ) -> Tensor: | |||
| r""" | |||
| Returns the cross entropy loss in a classification problem. | |||
| .. math:: \textrm{CrossEntropy}(x, y) = - \sum_{i} y_i\log(x_i) | |||
| :param inp: The input tensor representing the predicted probability. | |||
| :param target: The input tensor representing the classification label. | |||
| :param axis: An axis along which cross_entropy will be applied. Default: 1 | |||
| :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1 | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data_shape = (1, 2) | |||
| label_shape = (1, ) | |||
| pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape)) | |||
| label = tensor(np.ones(label_shape, dtype=np.int32)) | |||
| loss = F.cross_entropy(pred, label) | |||
| print(loss.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [0.69] | |||
| """ | |||
| n0 = inp.ndim | |||
| n1 = target.ndim | |||
| assert n0 == n1 + 1, ( | |||
| "target ndim must be one less than input ndim; input_ndim={} " | |||
| "target_ndim={}".format(n0, n1) | |||
| ) | |||
| if ignore_index != -1: | |||
| mask = 1 - equal(target, ignore_index) | |||
| target = target * mask | |||
| loss = -log(indexing_one_hot(inp, target, axis)) * mask | |||
| return loss.sum() / maximum(mask.sum(), 1.0) | |||
| else: | |||
| return -log(indexing_one_hot(inp, target, axis)).mean() | |||
| def cross_entropy_with_softmax( | |||
| pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0 | |||
| ) -> Tensor: | |||
| r""" | |||
| Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`. | |||
| It has better numerical stability compared with sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`. | |||
| When using label smoothing, the label distribution is as follows: | |||
| .. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K | |||
| where :math:`y^{LS}` and :math:`y` are new label distribution and origin label distribution respectively. | |||
| k is the index of label distribution. :math:`\alpha` is label_smooth and :math:`K` is the number of classes. | |||
| :param pred: The input tensor representing the predicted probability. | |||
| :param label: The input tensor representing the classification label. | |||
| :param axis: An axis along which softmax will be applied. Default: 1. | |||
| :param label_smooth: A label smoothing of parameter that can re-distribute target distribution. Default: 0. | |||
| """ | |||
| n0 = pred.ndim | |||
| n1 = label.ndim | |||
| assert n0 == n1 + 1, ( | |||
| "target ndim must be one less than input ndim; input_ndim={} " | |||
| "target_ndim={}".format(n0, n1) | |||
| ) | |||
| num_classes = pred.shapeof(axis) | |||
| # Denominator of the softmax | |||
| offset = zero_grad(pred.max(axis=axis, keepdims=True)) | |||
| pred = pred - offset | |||
| down = mgb.opr.elem.exp(pred).sum(axis=axis, keepdims=True) | |||
| up = indexing_one_hot(pred, label, axis) | |||
| if label_smooth != 0: | |||
| factor = label_smooth / num_classes | |||
| up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor | |||
| return (log(down) - up).mean() | |||
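| # A worked example of the smoothing rule above: with K = 4 classes and | |||
| # label_smooth = 0.1, a one-hot target [0, 1, 0, 0] is redistributed to | |||
| # [0.025, 0.925, 0.025, 0.025], i.e. y * (1 - 0.1) + 0.1 / 4 per entry. | |||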
| def triplet_margin_loss( | |||
| anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2 | |||
| ) -> Tensor: | |||
| r""" | |||
| Creates a criterion that measures the triplet loss given input tensors. | |||
| .. math:: | |||
| L(a, p, n) = max\left\{d\left(a_{i},p_{i}\right)-d\left(a_{i}, n_{i}\right)+margin, 0\right\},\ | |||
| d\left(x_{i},y_{i}\right)=\left\|x_{i}-y_{i}\right\|_{p} | |||
| :param anchor: The input tensor representing the anchor samples. | |||
| :param positive: The input tensor representing the positive samples. | |||
| :param negative: The input tensor representing the negative samples. | |||
| :param margin: Default: 1.0 | |||
| :param p: The norm degree for pairwise distance. Default: 2 | |||
| """ | |||
| s0 = anchor.shapeof() | |||
| s1 = positive.shapeof() | |||
| s2 = negative.shapeof() | |||
| assert_equal(s0, s1) | |||
| assert_equal(s1, s2) | |||
| n0 = anchor.ndim | |||
| n1 = positive.ndim | |||
| n2 = negative.ndim | |||
| assert n0 == 2 and n1 == 2 and n2 == 2, ( | |||
| "anchor ndim, positive ndim, and negative ndim must be 2; " | |||
| "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2) | |||
| ) | |||
| assert p > 0, "the norm degree p must be greater than 0; p={}".format(p) | |||
| diff0 = abs(anchor - positive) | |||
| diff1 = abs(anchor - negative) | |||
| d1 = power(power(diff0, p).sum(axis=1, keepdims=True), 1 / p) | |||
| d2 = power(power(diff1, p).sum(axis=1, keepdims=True), 1 / p) | |||
| loss = maximum(d1 - d2 + margin, 0) | |||
| return loss.mean() | |||
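| # A usage sketch with (N, D) embeddings (shapes and margin are illustrative): | |||
| #   a, p, n = (tensor(np.random.randn(8, 128).astype(np.float32)) for _ in range(3)) | |||
| #   loss = triplet_margin_loss(a, p, n, margin=0.5) | |||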
| def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor: | |||
| r"""Function that measures the Binary Cross Entropy between the target and the prediction. | |||
| :param pred: (N,*) where * means, any number of additional dimensions. | |||
| :param label: (N,*), same shape as the input. | |||
| """ | |||
| s0 = pred.shapeof() | |||
| s1 = label.shapeof() | |||
| assert_equal(s0, s1) | |||
| return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean() | |||
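| # A usage sketch; ``pred`` must hold probabilities in (0, 1), e.g. sigmoid | |||
| # outputs, since both log(pred) and log(1 - pred) are evaluated: | |||
| #   loss = binary_cross_entropy(F.sigmoid(logits), labels) | |||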
| def nll_loss( | |||
| pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1 | |||
| ) -> Tensor: | |||
| r""" | |||
| The negative log likelihood loss. | |||
| :param pred: The predicted result from model. | |||
| :param label: The ground truth to compare. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data_shape = (2, 2) | |||
| label_shape = (2, ) | |||
| data = tensor( | |||
| np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape), | |||
| ) | |||
| label = tensor( | |||
| np.ones(label_shape, dtype=np.int32) | |||
| ) | |||
| pred = F.log(F.softmax(data)) | |||
| loss1 = F.nll_loss(pred, label) | |||
| loss2 = F.cross_entropy_with_softmax(data, label) | |||
| print(loss1.numpy(), loss2.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [0.6576154] [0.6576154] | |||
| """ | |||
| n0 = pred.ndim | |||
| n1 = label.ndim | |||
| assert n0 == n1 + 1, ( | |||
| "target ndim must be one less than input ndim; input_ndim={} " | |||
| "target_ndim={}".format(n0, n1) | |||
| ) | |||
| mask = 1.0 - equal(label, ignore_index) | |||
| label = label * mask | |||
| loss = indexing_one_hot(pred, label, axis) * mask | |||
| return -1.0 * loss.sum() / maximum(mask.sum(), 1.0) | |||
| def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor: | |||
| r""" | |||
| Calculate the hinge loss, which is often used in SVMs. | |||
| The hinge loss can be described as: | |||
| .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij}*y_{ij})) | |||
| :param pred: The input tensor representing the predicted probability, shape is (N, C). | |||
| :param label: The input tensor representing the binary classification label, shape is (N, C). | |||
| :param norm: Specify the norm to calculate the loss, should be "L1" or "L2". | |||
| Examples: | |||
| .. testcode:: | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]]) | |||
| label = tensor([[1, -1, -1], [-1, 1, 1]]) | |||
| loss = F.hinge_loss(pred, label) | |||
| print(loss.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [1.5] | |||
| """ | |||
| assert norm in ["L1", "L2"], "norm must be L1 or L2" | |||
| # Hinge term: max(0, 1 - pred * label), assuming labels take values in {-1, 1}. | |||
| loss = relu(1.0 - pred * label) | |||
| if norm == "L1": | |||
| return loss.sum(axis=1).mean() | |||
| else: | |||
| return (loss ** 2).sum(axis=1).mean() | |||
| def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
| r""" | |||
| Calculate the smooth L1 loss proposed in the `Fast R-CNN` paper by Ross Girshick. | |||
| The smooth l1 loss can be described as: | |||
| .. math:: | |||
| \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i} | |||
| where :math:`l_{i}` is given by: | |||
| .. math:: | |||
| l_{i} = | |||
| \begin{cases} | |||
| 0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\ | |||
| |x_i - y_i| - 0.5, & \text{otherwise } | |||
| \end{cases} | |||
| :param pred: The predicted result from model. | |||
| :param label: The ground truth to compare. | |||
| Examples: | |||
| .. testcode:: | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]]) | |||
| label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]]) | |||
| loss = F.smooth_l1_loss(pred, label) | |||
| print(loss.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [0.5608334] | |||
| """ | |||
| diff = abs(pred - label) | |||
| l2_loss = 0.5 * (diff ** 2) | |||
| l1_loss = diff - 0.5 | |||
| mask = diff < 1 | |||
| loss = where(mask, l2_loss, l1_loss) | |||
| return loss.mean() | |||
| @@ -1,333 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import math | |||
| import numbers | |||
| from typing import Optional, Sequence, Union | |||
| import megengine._internal as mgb | |||
| from ..core import Tensor, wrap_io_tensor | |||
| from .elemwise import clamp, exp, isinf, log | |||
| from .tensor import remove_axis, where, zeros_like | |||
| @wrap_io_tensor | |||
| def sum(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor: | |||
| r"""Returns the sum of each row of the ``inp`` tensor in the given ``axis``. | |||
| :param inp: The input tensor. | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. | |||
| Default: None | |||
| :param keepdims: Whether the output tensor has ``axis`` retained or not. | |||
| Default: False | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
| out = F.sum(data) | |||
| print(out.numpy()) | |||
| .. testoutput:: | |||
| [21] | |||
| """ | |||
| return mgb.opr.reduce_(inp, "SUM", axis, keepdims) | |||
| @wrap_io_tensor | |||
| def prod(inp: Tensor, axis: Optional[int] = None, keepdims=False) -> Tensor: | |||
| r""" | |||
| Returns the element product of input tensor along given *axis*. | |||
| :param inp: The input tensor | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None`` | |||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False`` | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
| out = F.prod(data) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [720] | |||
| """ | |||
| return mgb.opr.reduce_(inp, "PRODUCT", axis, keepdims) | |||
| @wrap_io_tensor | |||
| def mean(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor: | |||
| """Returns the mean value of each row of the ``inp`` tensor in | |||
| the given ``axis``. If axis is a list of dimensions, | |||
| reduce over all of them. | |||
| :param inp: The input tensor | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
| :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
| out = F.mean(data) | |||
| print(out.numpy()) | |||
| .. testoutput:: | |||
| [3.5] | |||
| """ | |||
| return mgb.opr.mean(inp, axis, keepdims) | |||
| @wrap_io_tensor | |||
| def min(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor: | |||
| r""" | |||
| Returns the min value of input tensor along given *axis*. | |||
| :param inp: The input tensor | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
| y = F.min(x) | |||
| print(y.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [1] | |||
| """ | |||
| return mgb.opr.reduce_(inp, "MIN", axis, keepdims) | |||
| @wrap_io_tensor | |||
| def max(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor: | |||
| r"""Returns the max value of the input tensor along given *axis*. | |||
| :param inp: The input tensor | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
| y = F.max(x) | |||
| print(y.numpy()) | |||
| .. testoutput:: | |||
| [6] | |||
| """ | |||
| return mgb.opr.reduce_(inp, "MAX", axis, keepdims) | |||
| @wrap_io_tensor | |||
| def sqrt(inp: Tensor) -> Tensor: | |||
| """ | |||
| Return a new tensor with the square-root of the elements of ``inp`` | |||
| :param inp: The input tensor | |||
| :return: The computed tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
| out = F.sqrt(data) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[0. 1. 1.4142] | |||
| [1.7321 2. 2.2361 ]] | |||
| """ | |||
| return mgb.opr.sqrt(inp) | |||
| def norm(inp: Tensor, p: int = 2, axis: Optional[int] = None, keepdims=False): | |||
| """Calculate ``p``-norm of input tensor along certain axis. | |||
| :param inp: The input tensor | |||
| :param p: power of value ``p`` applied to ``inp``. Default: 2 | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
| :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False | |||
| :return: The output tensor | |||
| """ | |||
| if axis is None: | |||
| inp = inp.reshape(-1) | |||
| return (inp ** p).sum(axis=axis, keepdims=keepdims) ** (1.0 / p) | |||
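| # A small worked check: the L2 norm of [3, 4] is sqrt(3**2 + 4**2) = 5: | |||
| #   n = norm(tensor(np.array([3., 4.], dtype=np.float32)))  # -> [5.] | |||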
| @wrap_io_tensor | |||
| def argmin(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor: | |||
| r"""Returns the indices of the minimum values along an axis | |||
| :param inp: The input tensor | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
| y = F.argmin(x) | |||
| print(y.numpy()) | |||
| .. testoutput:: | |||
| [0] | |||
| """ | |||
| return mgb.opr.argmin(inp, axis, keepdims) | |||
| @wrap_io_tensor | |||
| def argmax(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor: | |||
| r"""Returns the indices of the maximum values along an axis | |||
| :param inp: The input tensor | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
| :param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
| y = F.argmax(x) | |||
| print(y.numpy()) | |||
| .. testoutput:: | |||
| [5] | |||
| """ | |||
| return mgb.opr.argmax(inp, axis, keepdims) | |||
| def normalize( | |||
| inp: Tensor, p: int = 2, axis: Optional[int] = None, eps: float = 1e-12 | |||
| ) -> Tensor: | |||
| r"""Perform :math:`L_p` normalization of input tensor along certain axis. | |||
| For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each | |||
| :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as: | |||
| .. math:: | |||
| v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. | |||
| :param inp: the input tensor | |||
| :param p: power of value ``p`` applied to ``inp``. Default: 2 | |||
| :param axis: The dimension to reduce. If None, all the dimensions will be reduced | |||
| to calculate the norm. Default: None | |||
| :param eps: a small value to avoid division by zero. Default: 1e-12 | |||
| :return: the normalized output tensor | |||
| """ | |||
| if axis is None: | |||
| return inp / clamp(norm(inp, p), lower=eps) | |||
| else: | |||
| return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps) | |||
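| # A usage sketch: row-wise L2 normalization of an (N, D) matrix, so every | |||
| # row ends up with unit L2 norm (up to eps): | |||
| #   y = normalize(x, p=2, axis=1) | |||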
| def logsumexp(inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False): | |||
| r""" | |||
| Compute the log of the sum of exponentials of inputs along the given :attr:`axis`. The computation is numerically stabilized. | |||
| .. math:: | |||
| \mathsf{logsumexp}(x_1, \dots, x_n) = \log(\exp(x_1) + \cdots + \exp(x_n)) | |||
| :param inp: The input tensor. | |||
| :param axis: Axis over which the sum is taken. It can be a single axis or a list of axes. | |||
| :param keepdims: whether to retain :attr:`axis` or not for the output tensor. | |||
| """ | |||
| if isinstance(axis, numbers.Integral): | |||
| axis = (axis,) | |||
| max_value = inp | |||
| for dim in axis: | |||
| max_value = max_value.max(axis=dim, keepdims=True) | |||
| max_value = where( | |||
| isinf(max_value).astype("int32"), zeros_like(max_value), max_value | |||
| ) | |||
| x = exp(inp - max_value) | |||
| for dim in axis: | |||
| x = x.sum(axis=dim, keepdims=True) | |||
| x = max_value + log(x) | |||
| if not keepdims: | |||
| axis = sorted(axis, reverse=True) | |||
| for i in axis: | |||
| x = remove_axis(x, axis=i) | |||
| return x | |||
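| # The max-subtraction above is the standard log-sum-exp trick: exp() on large | |||
| # inputs overflows, while log(sum(exp(x - m))) + m stays finite. A quick check | |||
| # against plain numpy (illustrative): | |||
| #   x = np.array([1000., 1000.]) | |||
| #   np.log(np.sum(np.exp(x)))                        # inf (overflow) | |||
| #   m = x.max(); np.log(np.sum(np.exp(x - m))) + m   # ~1000.6931 | |||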
| @@ -1,80 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=too-many-lines | |||
| from typing import Tuple, Union | |||
| from .. import _internal as mgb | |||
| from ..core import Tensor, wrap_io_tensor | |||
| from ..utils.types import _pair, _pair_nonzero | |||
| from .debug_param import get_conv_execution_strategy | |||
| @wrap_io_tensor | |||
| def conv_bias_activation( | |||
| inp: Tensor, | |||
| weight: Tensor, | |||
| bias: Tensor, | |||
| dtype=None, | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| nonlinear_mode="IDENTITY", | |||
| conv_mode="CROSS_CORRELATION", | |||
| compute_mode="DEFAULT", | |||
| ) -> Tensor: | |||
| """ convolution bias with activation operation, only for inference. | |||
| :param inp: The feature map of the convolution operation | |||
| :param weight: The convolution kernel | |||
| :param bias: The bias added to the result of convolution | |||
| :param stride: Stride of the 2D convolution operation. Default: 1 | |||
| :param padding: Size of the paddings added to the input on both sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param dilation: Dilation of the 2D convolution operation. Default: 1 | |||
| :param groups: number of groups to divide input and output channels into, | |||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
| and the shape of weight should be ``(groups, out_channel // groups, | |||
| in_channels // groups, height, width)``. | |||
| :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode` | |||
| :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
| 'CROSS_CORRELATION'. | |||
| :param dtype: The output data type (np.dtype). Default: | |||
| np.int8. | |||
| :type compute_mode: string or | |||
| :class:`mgb.opr_param_defs.Convolution.ComputeMode` | |||
| :param compute_mode: When set to 'DEFAULT', no special requirements will be | |||
| placed on the precision of intermediate results. When set to 'FLOAT32', | |||
| Float32 would be used for accumulator and intermediate result, but only | |||
| effective when input and output are of Float16 dtype. | |||
| """ | |||
| ph, pw = _pair(padding) | |||
| sh, sw = _pair_nonzero(stride) | |||
| dh, dw = _pair_nonzero(dilation) | |||
| sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
| res = mgb.opr.conv_bias_activation( | |||
| inp, | |||
| weight, | |||
| bias, | |||
| compute_mode=compute_mode, | |||
| dtype=dtype, | |||
| strategy=get_conv_execution_strategy(), | |||
| nonlineMode=nonlinear_mode, | |||
| sparse=sparse_type, | |||
| format="NCHW", | |||
| pad_h=ph, | |||
| pad_w=pw, | |||
| stride_h=sh, | |||
| stride_w=sw, | |||
| dilate_h=dh, | |||
| dilate_w=dw, | |||
| mode=conv_mode, | |||
| ) | |||
| return res | |||
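| # A usage sketch for a quantized inference path; the quantized dtype helper | |||
| # and scale below are assumptions, not a prescribed configuration: | |||
| #   out = conv_bias_activation( | |||
| #       inp_q, weight_q, bias_q, | |||
| #       dtype=mgb.dtype.qint8(out_scale),  # hypothetical quantized dtype helper | |||
| #       stride=1, padding=1, nonlinear_mode="RELU", | |||
| #   ) | |||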
| @@ -1,123 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| from typing import Optional, Tuple, Union | |||
| import megengine._internal as mgb | |||
| from ..core.tensor import Tensor, wrap_io_tensor | |||
| __all__ = ["argsort", "sort", "top_k"] | |||
| @wrap_io_tensor | |||
| def argsort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: | |||
| r""" | |||
| Sort each row of the input 2d matrix, returning both the sorted tensor and the indices. | |||
| :param inp: The input tensor, if 2d, each row will be sorted | |||
| :param descending: Sort in descending order, where the largest comes first. Default: ``False`` | |||
| :return: Tuple of two tensors (sorted_tensor, indices_of_int32) | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data = tensor(np.array([1,2], dtype=np.float32)) | |||
| sorted, indices = F.argsort(data) | |||
| print(sorted.numpy(), indices.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [1. 2.] [0 1] | |||
| """ | |||
| assert len(inp.imm_shape) <= 2, "Input should be 1d or 2d" | |||
| if descending: | |||
| order = mgb.opr_param_defs.Argsort.Order.DESCENDING | |||
| else: | |||
| order = mgb.opr_param_defs.Argsort.Order.ASCENDING | |||
| if len(inp.imm_shape) == 1: | |||
| inp = inp.reshape(1, -1) | |||
| tns, ind = mgb.opr.argsort(inp, order=order) | |||
| return tns[0], ind[0] | |||
| return mgb.opr.argsort(inp, order=order) | |||
| @functools.wraps(argsort) | |||
| def sort(*args, **kwargs): | |||
| return argsort(*args, **kwargs) | |||
| @wrap_io_tensor | |||
| def top_k( | |||
| inp: Tensor, | |||
| k: int, | |||
| descending: bool = False, | |||
| kth_only: bool = False, | |||
| no_sort: bool = False, | |||
| ) -> Tuple[Tensor, Tensor]: | |||
| r""" | |||
| Select the Top-K (by default) smallest elements of a 2d matrix by row. | |||
| :param inp: The input tensor, if 2d, each row will be sorted | |||
| :param k: The number of elements needed | |||
| :param descending: If true, return the largest elements instead. Default: ``False`` | |||
| :param kth_only: If true, only the k-th element will be returned. Default: ``False`` | |||
| :param no_sort: If true, the returned elements can be unordered. Default: ``False`` | |||
| :return: Tuple of two tensors (topk_tensor, indices_of_int32) | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | |||
| top, indices = F.top_k(data, 5) | |||
| print(top.numpy(), indices.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [1. 2. 3. 4. 5.] [7 0 6 1 5] | |||
| """ | |||
| assert len(inp.imm_shape) <= 2, "Input should be 1d or 2d" | |||
| if kth_only: | |||
| raise NotImplementedError( | |||
| "TODO: would enconter:" | |||
| "NotImplementedError: SymbolVar var could not be itered" | |||
| ) | |||
| if descending: | |||
| inp = -inp | |||
| Mode = mgb.opr_param_defs.TopK.Mode | |||
| if kth_only: | |||
| mode = Mode.KTH_ONLY | |||
| elif no_sort: | |||
| mode = Mode.VALUE_IDX_NOSORT | |||
| else: | |||
| mode = Mode.VALUE_IDX_SORTED | |||
| if len(inp.imm_shape) == 1: | |||
| inp = inp.reshape(1, -1) | |||
| tns, ind = mgb.opr.top_k(inp, k, mode=mode) | |||
| tns = tns[0] | |||
| ind = ind[0] | |||
| else: | |||
| tns, ind = mgb.opr.top_k(inp, k, mode=mode) | |||
| if descending: | |||
| tns = -tns | |||
| return tns, ind | |||
| @@ -1,667 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| from typing import Iterable, List, Optional, Union | |||
| import numpy as np | |||
| import megengine._internal as mgb | |||
| from megengine._internal import CompGraph, CompNode | |||
| from ..core import zeros | |||
| from ..core.graph import _use_default_if_none | |||
| from ..core.tensor import Tensor, wrap_io_tensor | |||
| from .elemwise import ceil | |||
| from .utils import _decide_comp_node_and_comp_graph | |||
| @wrap_io_tensor | |||
| def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | |||
| """ | |||
| Broadcast a tensor to ``shape`` | |||
| :param inp: The input tensor | |||
| :param shape: The target shape | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
| out = F.broadcast_to(data, (4, 2, 3)) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[[0. 1. 2.] | |||
| [3. 4. 5.]] | |||
| [[0. 1. 2.] | |||
| [3. 4. 5.]] | |||
| [[0. 1. 2.] | |||
| [3. 4. 5.]] | |||
| [[0. 1. 2.] | |||
| [3. 4. 5.]]] | |||
| """ | |||
| if isinstance(shape, int): | |||
| shape = (shape,) | |||
| return mgb.opr.broadcast(inp, shape) | |||
| def _get_idx(index, axis): | |||
| index_dims = len(index.imm_shape) | |||
| idx = [] | |||
| comp_node, comp_graph = _decide_comp_node_and_comp_graph(index) | |||
| for i in range(index_dims): | |||
| if i != axis: | |||
| shape = [1] * index_dims | |||
| shape[i] = index.axis_shape(i) | |||
| arange = mgb.opr.linspace( | |||
| 0, | |||
| index.axis_shape(i) - 1, | |||
| index.axis_shape(i), | |||
| comp_node=comp_node, | |||
| comp_graph=comp_graph, | |||
| ) | |||
| arange = ( | |||
| arange.reshape(*shape) | |||
| .broadcast(index.shape) | |||
| .reshape(-1) | |||
| .astype(np.int32) | |||
| ) | |||
| idx.append(arange) | |||
| else: | |||
| idx.append(index.reshape(-1)) | |||
| return tuple(idx) | |||
| @wrap_io_tensor | |||
| def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: | |||
| r""" | |||
| Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`. | |||
| For a 3-D tensor, the output is specified by:: | |||
| out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0 | |||
| out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1 | |||
| out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2 | |||
| if :attr:`inp` is an n-dimensional tensor with size | |||
| :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i, | |||
| then :attr:`index` must be an n-dimensional tensor with size | |||
| :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and | |||
| output will have the same size as :attr:`index`. | |||
| :param inp: the source tensor | |||
| :param axis: the axis along which to index | |||
| :param index: the indices of elements to gather | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine.functional as F | |||
| from megengine.core import tensor | |||
| inp = tensor([ | |||
| [1,2], [3,4], [5,6], | |||
| ]) | |||
| index = tensor([[0,2], [1,0]]) | |||
| oup = F.gather(inp, 0, index) | |||
| print(oup.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[1 6] | |||
| [3 2]] | |||
| """ | |||
| input_shape = inp.imm_shape | |||
| index_shape = index.imm_shape | |||
| input_dims = len(input_shape) | |||
| index_dims = len(index_shape) | |||
| if input_dims != index_dims: | |||
| raise ValueError( | |||
| "The index tensor must have same dimensions as input tensor, " | |||
| "But the input dims:{}, the index dims:{}".format(input_dims, index_dims) | |||
| ) | |||
| if axis < 0 or axis >= input_dims: | |||
| raise ValueError( | |||
| "Index axis {} is output of bounds, should in range [0 {})".format( | |||
| axis, input_dims | |||
| ) | |||
| ) | |||
| for i in range(input_dims): | |||
| if i != axis and input_shape[i] != index_shape[i]: | |||
| raise ValueError( | |||
| "The input {} and index {} must have the same size apart from axis {}".format( | |||
| input_shape, index_shape, axis | |||
| ) | |||
| ) | |||
| idx = _get_idx(index, axis) | |||
| return mgb.opr.advanced_indexing(inp)[idx].reshape( | |||
| index.shape | |||
| ) # pylint: disable=no-member | |||
| @wrap_io_tensor | |||
| def concat( | |||
| inps: Iterable[Tensor], | |||
| axis: int = 0, | |||
| device: Optional[CompNode] = None, | |||
| comp_graph: Optional[CompGraph] = None, | |||
| ) -> Tensor: | |||
| r""" | |||
| Concat some tensors | |||
| :param inps: Input tensors to concat | |||
| :param axis: the dimension over which the tensors are concatenated. Default: 0 | |||
| :param device: The comp node output on. Default: None | |||
| :param comp_graph: The graph in which output is. Default: None | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3))) | |||
| data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3))) | |||
| out = F.concat([data1, data2]) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[ 0. 1. 2.] | |||
| [ 3. 4. 5.] | |||
| [ 6. 7. 8.] | |||
| [ 9. 10. 11.]] | |||
| """ | |||
| # Output buffer not supported | |||
| return mgb.opr.concat( | |||
| *list(inps), axis=axis, comp_node=device, comp_graph=comp_graph | |||
| ) | |||
| @wrap_io_tensor | |||
| def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: | |||
| r""" | |||
| Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor. | |||
| For each value in :attr:`source`, its output index is specified by its index | |||
| in :attr:`source` for ``axis != dimension`` and by the corresponding value in | |||
| :attr:`index` for ``axis = dimension``. | |||
| For a 3-D tensor, :attr:`inp` is updated as:: | |||
| inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0 | |||
| inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1 | |||
| inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2 | |||
| :attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions. | |||
| It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)`` | |||
| for all dimensions ``d``. | |||
| Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive. | |||
| .. note:: | |||
| Please note that, due to performance issues, the result is undefined on the GPU device | |||
| if the index tensor scatters several source positions to the same destination position. | |||
| In the example below, if index[1][2] were changed from 1 to 0, oup[0][2] could come | |||
| either from source[0][2] (value 0.2256) or from source[1][2] (value 0.5939). | |||
| :param inp: the inp tensor which to be scattered | |||
| :param axis: the axis along which to index | |||
| :param index: the indices of elements to scatter | |||
| :param source: the source element(s) to scatter | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| from megengine.core import tensor | |||
| inp = tensor(np.zeros(shape=(3,5),dtype=np.float32)) | |||
| source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]]) | |||
| index = tensor([[0,2,0,2,1],[2,0,1,1,2]]) | |||
| oup = F.scatter(inp, 0, index,source) | |||
| print(oup.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[0.9935 0.0718 0.2256 0. 0. ] | |||
| [0. 0. 0.5939 0.357 0.4396] | |||
| [0.7723 0.9465 0. 0.8926 0.4576]] | |||
| """ | |||
| input_shape = inp.imm_shape | |||
| index_shape = index.imm_shape | |||
| source_shape = source.imm_shape | |||
| input_dims = len(input_shape) | |||
| index_dims = len(index_shape) | |||
| source_dims = len(source_shape) | |||
| if input_dims != index_dims or input_dims != source_dims: | |||
| raise ValueError("The input, source and index tensor must have same dimensions") | |||
| if axis < 0 or axis >= input_dims: | |||
| raise ValueError( | |||
| "Index axis {} is output of bounds, should in range [0 {})".format( | |||
| axis, input_dims | |||
| ) | |||
| ) | |||
| for i in range(source_dims): | |||
| if source_shape[i] > input_shape[i]: | |||
| raise ValueError( | |||
| "The each shape size for source {} must be less than or equal to input {} ".format( | |||
| source_shape, input_shape | |||
| ) | |||
| ) | |||
| for i in range(index_dims): | |||
| if index_shape[i] != source_shape[i]: | |||
| raise ValueError( | |||
| "The each shape size for index {} must be equal to source {} ".format( | |||
| index_shape, source_shape | |||
| ) | |||
| ) | |||
| for i in range(index_dims): | |||
| if i != axis and index_shape[i] > input_shape[i]: | |||
| raise ValueError( | |||
| "The index {} must be less than or equal to input {} size apart from axis {}".format( | |||
| index_shape, input_shape, axis | |||
| ) | |||
| ) | |||
| idx = _get_idx(index, axis) | |||
| return mgb.opr.set_advanced_indexing(inp, source.flatten())[idx] | |||
| @wrap_io_tensor | |||
| def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: | |||
| r""" | |||
| Select elements either from Tensor x or Tensor y, according to mask. | |||
| .. math:: | |||
| \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i | |||
| :param mask: a mask used for choosing x or y | |||
| :param x: the first choice | |||
| :param y: the second choice | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32)) | |||
| x = tensor(np.array([[1, np.inf], [np.nan, 4]], | |||
| dtype=np.float32)) | |||
| y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)) | |||
| out = F.where(mask, x, y) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[1. 6.] | |||
| [7. 4.]] | |||
| """ | |||
| v0, index0 = mgb.opr.cond_take( | |||
| x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=1 | |||
| ) | |||
| v1, index1 = mgb.opr.cond_take( | |||
| y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0 | |||
| ) | |||
| out = x.flatten() | |||
| index = mgb.opr.concat(index0, index1, axis=0) | |||
| v = mgb.opr.concat(v0, v1, axis=0) | |||
| out = mgb.opr.set_advanced_indexing(out, v)[index] | |||
| out = out.reshape(x.shape) | |||
| return out | |||
| @wrap_io_tensor | |||
| def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor: | |||
| r""" | |||
| Take elements from data if specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; they are both 1-dimensional. High-dimension input would first be flattened. | |||
| :param mask: condition param; must be the same shape with data | |||
| :param x: input tensor from which to take elements | |||
| :param val: value to be compared to by mode | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32)) | |||
| x = tensor(np.array([[1, np.inf], [np.nan, 4]], | |||
| dtype=np.float32)) | |||
| v, index = F.cond_take(mask, x, 1) | |||
| print(v, index) | |||
| Outputs: | |||
| .. testoutput:: | |||
| Tensor([1. 4.]) Tensor([0 3], dtype=int32) | |||
| """ | |||
| v, index = mgb.opr.cond_take( | |||
| x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val | |||
| ) | |||
| return v, index | |||
| def shapeof(x: Tensor, axis=None): | |||
| r""" | |||
| The shape of input tensor. | |||
| """ | |||
| return x.shapeof(axis=axis) | |||
| @wrap_io_tensor | |||
| def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
| r""" | |||
| Swap shapes and strides according to given pattern | |||
| :param inp: Input tensor | |||
| :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For example: | |||
| * (``'x'``) -> make a 0d (scalar) into a 1d vector | |||
| * (0, 1) -> identity for 2d vectors | |||
| * (1, 0) -> inverts the first and second dimensions | |||
| * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN) | |||
| * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1) | |||
| * (2, 0, 1) -> AxBxC to CxAxB | |||
| * (0, ``'x'``, 1) -> AxB to Ax1xB | |||
| * (1, ``'x'``, 0) -> AxB to Bx1xA | |||
| * (1,) -> removes dimension 0; it must be a broadcastable dimension (1xA to A) | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32)) | |||
| out = F.dimshuffle(x, (1, 0)) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[1 0] | |||
| [1 0]] | |||
| """ | |||
| return mgb.opr.dimshuffle(inp, pattern) | |||
| @wrap_io_tensor | |||
| def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | |||
| r""" | |||
| Reshape a tensor to given target shape; total number of logical elements must | |||
| remain unchanged | |||
| :param inp: Input tensor | |||
| :param target_shape: target shape, the components would be concatenated to form the | |||
| target shape, and it can contain an element of -1 representing unspec_axis. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor(np.arange(12, dtype=np.int32)) | |||
| out = F.reshape(x, (3, 2, 2)) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[[ 0 1] | |||
| [ 2 3]] | |||
| [[ 4 5] | |||
| [ 6 7]] | |||
| [[ 8 9] | |||
| [10 11]]] | |||
| """ | |||
| return mgb.opr.reshape(inp, target_shape) | |||
| def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
| r"""Equivalent to :func:`dimshuffle` | |||
| """ | |||
| return dimshuffle(inp, pattern) | |||
| @wrap_io_tensor | |||
| def add_axis(inp: Tensor, axis: int) -> Tensor: | |||
| r""" | |||
| Add a dimension before the given axis. | |||
| :param inp: Input tensor | |||
| :param axis: Place of the new axis | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor([1, 2]) | |||
| out = F.add_axis(x, 0) | |||
| print(out.shape) | |||
| Outputs: | |||
| .. testoutput:: | |||
| (1, 2) | |||
| """ | |||
| if not isinstance(axis, int): | |||
| raise ValueError("axis must be int, but got type:{}".format(type(axis))) | |||
| return mgb.opr.add_axis(inp, axis) | |||
| @wrap_io_tensor | |||
| def remove_axis(inp: Tensor, axis: int) -> Tensor: | |||
| r""" | |||
| Remove dimension of shape 1. | |||
| :param inp: Input tensor | |||
| :param axis: Place of axis to be removed | |||
| :return: The output tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) | |||
| out = F.remove_axis(x, 3) | |||
| print(out.shape) | |||
| Outputs: | |||
| .. testoutput:: | |||
| (1, 1, 2) | |||
| """ | |||
| if not isinstance(axis, int): | |||
| raise ValueError("axis must be int, but got type:{}".format(type(axis))) | |||
| return mgb.opr.remove_axis(inp, axis) | |||
| def linspace( | |||
| start: Union[int, float, Tensor], | |||
| stop: Union[int, float, Tensor], | |||
| num: Union[int, Tensor], | |||
| dtype=np.float32, | |||
| device: Optional[CompNode] = None, | |||
| comp_graph: Optional[CompGraph] = None, | |||
| ) -> Tensor: | |||
| r""" | |||
| Return equally spaced numbers over a specified interval | |||
| :param start: Starting value of the sequence, should be a scalar | |||
| :param stop: The last value of the sequence, should be a scalar | |||
| :param num: number of values to generate | |||
| :param dtype: result data type | |||
| :return: The generated tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| a = F.linspace(3,10,5) | |||
| print(a.numpy()) | |||
| .. testoutput:: | |||
| [ 3. 4.75 6.5 8.25 10. ] | |||
| """ | |||
| if dtype is not np.float32: | |||
| raise ValueError("linspace is only implemented for float32") | |||
| device, comp_graph = _use_default_if_none(device, comp_graph) | |||
| ret = Tensor( | |||
| mgb.opr.linspace(start, stop, num, comp_node=device, comp_graph=comp_graph) | |||
| ) | |||
| return ret.astype(dtype) | |||
| def arange( | |||
| start: Union[int, float, Tensor], | |||
| end: Union[int, float, Tensor], | |||
| step: Union[int, float, Tensor] = 1, | |||
| dtype=np.float32, | |||
| device: Optional[CompNode] = None, | |||
| comp_graph: Optional[CompGraph] = None, | |||
| ) -> Tensor: | |||
| r""" | |||
| Returns a tensor of evenly spaced values from `start` to `end` (exclusive) with interval `step` | |||
| :param start: starting value of the sequence, should be a scalar | |||
| :param end: ending value of the sequence, should be a scalar | |||
| :param step: the gap between each pair of adjacent values. Default 1 | |||
| :param dtype: result data type | |||
| :return: The generated tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| a = F.arange(1, 5, 1) | |||
| print(a.numpy()) | |||
| .. testoutput:: | |||
| [1. 2. 3. 4.] | |||
| """ | |||
| if dtype is not np.float32: | |||
| raise ValueError("arange is only implemented for float32") | |||
| num = ceil((end - start) / step) | |||
| stop = start + step * (num - 1) | |||
| ret = linspace(start, stop, num, device=device, comp_graph=comp_graph) | |||
| return ret | |||
| def zeros_like(inp: Tensor) -> Tensor: | |||
| r""" | |||
| Returns a zero tensor with the same shape as input tensor | |||
| :param inp: input tensor | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
| out = F.zeros_like(inp) | |||
| print(out.numpy()) | |||
| .. testoutput:: | |||
| [[0 0 0] | |||
| [0 0 0]] | |||
| """ | |||
| return zeros(inp.shapeof()).astype(inp.dtype) | |||
| @@ -1,81 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Union | |||
| import megengine._internal as mgb | |||
| from ..core.graph import _use_default_if_none | |||
| from ..core.tensor import Tensor, wrap_io_tensor | |||
| from .elemwise import equal | |||
| from .sort import top_k | |||
| def _decide_comp_node_and_comp_graph(*args: mgb.SymbolVar): | |||
| for i in args: | |||
| if isinstance(i, mgb.SymbolVar): | |||
| return i.comp_node, i.owner_graph | |||
| return _use_default_if_none(None, None) | |||
| def accuracy( | |||
| logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 | |||
| ) -> Union[Tensor, Iterable[Tensor]]: | |||
| r""" | |||
| Calculate the classification accuracy given predicted logits and ground-truth labels. | |||
| :param logits: Model predictions of shape [batch_size, num_classes], | |||
| representing the probability (likelihood) of each class. | |||
| :param target: Ground-truth labels, 1d tensor of int32 | |||
| :param topk: Specifies the topk values, could be an int or tuple of ints. Default: 1 | |||
| :return: Tensor(s) of classification accuracy between 0.0 and 1.0 | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10)) | |||
| target = tensor(np.arange(8, dtype=np.int32)) | |||
| top1, top5 = F.accuracy(logits, target, (1, 5)) | |||
| print(top1.numpy(), top5.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [0.] [0.375] | |||
| """ | |||
| if isinstance(topk, int): | |||
| topk = (topk,) | |||
| _, pred = top_k(logits, k=max(topk), descending=True) | |||
| accs = [] | |||
| for k in topk: | |||
| correct = equal( | |||
| pred[:, :k], target.dimshuffle(0, "x").broadcast(target.shapeof(0), k) | |||
| ) | |||
| accs.append(correct.sum() / target.shapeof(0)) | |||
| if len(topk) == 1: # type: ignore[arg-type] | |||
| accs = accs[0] | |||
| return accs | |||
| @wrap_io_tensor | |||
| def zero_grad(inp: Tensor) -> Tensor: | |||
| r""" | |||
| Returns a tensor which is treated as constant during backward gradient calculation, | |||
| i.e. its gradient is zero. | |||
| :param inp: Input tensor. | |||
| See implementation of :func:`~.softmax` for example. | |||
| """ | |||
| return mgb.opr.zero_grad(inp) | |||
| @@ -1,16 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .hub import ( | |||
| help, | |||
| import_module, | |||
| list, | |||
| load, | |||
| load_serialized_obj_from_url, | |||
| pretrained, | |||
| ) | |||
| @@ -1,17 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| DEFAULT_BRANCH_NAME = "master" | |||
| HUBCONF = "hubconf.py" | |||
| HUBDEPENDENCY = "dependencies" | |||
| DEFAULT_GIT_HOST = "github.com" | |||
| ENV_MGE_HOME = "MGE_HOME" | |||
| ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" | |||
| DEFAULT_CACHE_DIR = "~/.cache" | |||
| DEFAULT_PROTOCOL = "HTTPS" | |||
| HTTP_READ_TIMEOUT = 120 | |||
| @@ -1,30 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| class FetcherError(Exception): | |||
| """Base class for fetch related error.""" | |||
| class InvalidRepo(FetcherError): | |||
| """The repo provided was somehow invalid.""" | |||
| class InvalidGitHost(FetcherError): | |||
| """The git host provided was somehow invalid.""" | |||
| class GitPullError(FetcherError): | |||
| """A git pull error occurred.""" | |||
| class GitCheckoutError(FetcherError): | |||
| """A git checkout error occurred.""" | |||
| class InvalidProtocol(FetcherError): | |||
| """The protocol provided was somehow invalid.""" | |||
| @@ -1,300 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import hashlib | |||
| import os | |||
| import re | |||
| import shutil | |||
| import subprocess | |||
| from tempfile import NamedTemporaryFile | |||
| from typing import Tuple | |||
| from zipfile import ZipFile | |||
| import requests | |||
| from tqdm import tqdm | |||
| from megengine.utils.http_download import ( | |||
| CHUNK_SIZE, | |||
| HTTP_CONNECTION_TIMEOUT, | |||
| HTTPDownloadError, | |||
| ) | |||
| from ..distributed.util import is_distributed, synchronized | |||
| from ..logger import get_logger | |||
| from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT | |||
| from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo | |||
| from .tools import cd | |||
| logger = get_logger(__name__) | |||
| HTTP_TIMEOUT = (HTTP_CONNECTION_TIMEOUT, HTTP_READ_TIMEOUT) | |||
| pattern = re.compile( | |||
| r"^(?:[a-z0-9]" # First character of the domain | |||
| r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)" # Sub domain + hostname | |||
| r"+[a-z0-9][a-z0-9-_]{0,61}" # First 61 characters of the gTLD | |||
| r"[a-z]$" # Last character of the gTLD | |||
| ) | |||
| class RepoFetcherBase: | |||
| @classmethod | |||
| def fetch( | |||
| cls, | |||
| git_host: str, | |||
| repo_info: str, | |||
| use_cache: bool = False, | |||
| commit: str = None, | |||
| silent: bool = True, | |||
| ) -> str: | |||
| raise NotImplementedError() | |||
| @classmethod | |||
| def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]: | |||
| try: | |||
| branch_info = DEFAULT_BRANCH_NAME | |||
| if ":" in repo_info: | |||
| prefix_info, branch_info = repo_info.split(":") | |||
| else: | |||
| prefix_info = repo_info | |||
| repo_owner, repo_name = prefix_info.split("/") | |||
| return repo_owner, repo_name, branch_info | |||
| except ValueError: | |||
| raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info)) | |||
| @classmethod | |||
| def _check_git_host(cls, git_host): | |||
| return cls._is_valid_domain(git_host) or cls._is_valid_host(git_host) | |||
| @classmethod | |||
| def _is_valid_domain(cls, s): | |||
| try: | |||
| return pattern.match(s.encode("idna").decode("ascii")) | |||
| except UnicodeError: | |||
| return False | |||
| @classmethod | |||
| def _is_valid_host(cls, s): | |||
| nums = s.split(".") | |||
| if len(nums) != 4 or any(not _.isdigit() for _ in nums): | |||
| return False | |||
| return all(0 <= int(_) < 256 for _ in nums) | |||
| @classmethod | |||
| def _gen_repo_dir(cls, repo_dir: str) -> str: | |||
| return hashlib.sha1(repo_dir.encode()).hexdigest()[:16] | |||
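| # Illustrative sketch (not in the original file): how the helpers above parse | |||
| # repo_info and derive a cache directory name; "some_owner/some_repo" is a | |||
| # hypothetical repository. | |||
| # | |||
| #     owner, name, branch = RepoFetcherBase._parse_repo_info("some_owner/some_repo:dev") | |||
| #     # -> ("some_owner", "some_repo", "dev"); branch defaults to "master" | |||
| #     RepoFetcherBase._gen_repo_dir("some_owner_some_repo_dev")  # 16-char sha1 prefix | |||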
| class GitSSHFetcher(RepoFetcherBase): | |||
| @classmethod | |||
| @synchronized | |||
| def fetch( | |||
| cls, | |||
| git_host: str, | |||
| repo_info: str, | |||
| use_cache: bool = False, | |||
| commit: str = None, | |||
| silent: bool = True, | |||
| ) -> str: | |||
| """ | |||
| Fetches git repo by SSH protocol | |||
| :param git_host: | |||
| host address of git repo. | |||
| example: github.com | |||
| :param repo_info: | |||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
| tag/branch. The default branch is ``master`` if not specified. | |||
| example: ``"brain_sdk/MegBrain[:hub]"`` | |||
| :param use_cache: | |||
| whether to use locally fetched code or completely re-fetch | |||
| :param commit: | |||
| commit id on github or gitlab | |||
| :param silent: | |||
| whether to capture the stdout and stderr of the subprocess via PIPE instead of | |||
| displaying them on the screen | |||
| :return: | |||
| directory where the repo code is stored | |||
| """ | |||
| if not cls._check_git_host(git_host): | |||
| raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) | |||
| repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info) | |||
| normalized_branch_info = branch_info.replace("/", "_") | |||
| repo_dir_raw = "{}_{}_{}".format( | |||
| repo_owner, repo_name, normalized_branch_info | |||
| ) + ("_{}".format(commit) if commit else "") | |||
| repo_dir = cls._gen_repo_dir(repo_dir_raw) | |||
| git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name) | |||
| if use_cache and os.path.exists(repo_dir): # use cache | |||
| logger.debug("Cache Found in %s", repo_dir) | |||
| return repo_dir | |||
| if is_distributed(): | |||
| logger.warning( | |||
| "When using `hub.load` or `hub.list` to fetch git repositories\n" | |||
| " in DISTRIBUTED mode for the first time, processes are synchronized to\n" | |||
| " ensure that target repository is ready to use for each process.\n" | |||
| " Users are expected to see this warning no more than ONCE, otherwise\n" | |||
| " (very little chance) you may need to remove corrupt cache\n" | |||
| " `%s` and fetch again.", | |||
| repo_dir, | |||
| ) | |||
| shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache | |||
| logger.debug( | |||
| "Git Clone from Repo:%s Branch: %s to %s", | |||
| git_url, | |||
| normalized_branch_info, | |||
| repo_dir, | |||
| ) | |||
| kwargs = ( | |||
| {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {} | |||
| ) | |||
| if commit is None: | |||
| # shallow clone repo by branch/tag | |||
| p = subprocess.Popen( | |||
| [ | |||
| "git", | |||
| "clone", | |||
| "-b", | |||
| normalized_branch_info, | |||
| git_url, | |||
| repo_dir, | |||
| "--depth=1", | |||
| ], | |||
| **kwargs, | |||
| ) | |||
| cls._check_clone_pipe(p) | |||
| else: | |||
| # clone repo and checkout to commit_id | |||
| p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs) | |||
| cls._check_clone_pipe(p) | |||
| with cd(repo_dir): | |||
| logger.debug("git checkout to %s", commit) | |||
| p = subprocess.Popen(["git", "checkout", commit], **kwargs) | |||
| _, err = p.communicate() | |||
| if p.returncode: | |||
| shutil.rmtree(repo_dir, ignore_errors=True) | |||
| raise GitCheckoutError( | |||
| "Git checkout error, please check the commit id.\n" | |||
| + err.decode() | |||
| ) | |||
| with cd(repo_dir): | |||
| shutil.rmtree(".git") | |||
| return repo_dir | |||
| @classmethod | |||
| def _check_clone_pipe(cls, p): | |||
| _, err = p.communicate() | |||
| if p.returncode: | |||
| raise GitPullError( | |||
| "Repo pull error, please check repo info.\n" + err.decode() | |||
| ) | |||
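| # A hedged usage sketch (not part of the original file); assumes SSH access to a | |||
| # hypothetical repository: | |||
| # | |||
| #     repo_dir = GitSSHFetcher.fetch( | |||
| #         "github.com", "some_owner/some_repo:master", use_cache=True | |||
| #     )  # shallow-clones via git@github.com:some_owner/some_repo.git | |||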
| class GitHTTPSFetcher(RepoFetcherBase): | |||
| @classmethod | |||
| @synchronized | |||
| def fetch( | |||
| cls, | |||
| git_host: str, | |||
| repo_info: str, | |||
| use_cache: bool = False, | |||
| commit: str = None, | |||
| silent: bool = True, | |||
| ) -> str: | |||
| """ | |||
| Fetches git repo by HTTPS protocol | |||
| :param git_host: | |||
| host address of git repo | |||
| example: github.com | |||
| :param repo_info: | |||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
| tag/branch. The default branch is ``master`` if not specified. | |||
| example: ``"brain_sdk/MegBrain[:hub]"`` | |||
| :param use_cache: | |||
| whether to use locally cached code or completely re-fetch | |||
| :param commit: | |||
| commit id on github or gitlab | |||
| :param silent: | |||
| whether to capture the stdout and stderr of the subprocess via PIPE instead of | |||
| displaying them on the screen | |||
| :return: | |||
| directory where the repo code is stored | |||
| """ | |||
| if not cls._check_git_host(git_host): | |||
| raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) | |||
| repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info) | |||
| normalized_branch_info = branch_info.replace("/", "_") | |||
| repo_dir_raw = "{}_{}_{}".format( | |||
| repo_owner, repo_name, normalized_branch_info | |||
| ) + ("_{}".format(commit) if commit else "") | |||
| repo_dir = cls._gen_repo_dir(repo_dir_raw) | |||
| archive_url = cls._git_archive_link( | |||
| git_host, repo_owner, repo_name, branch_info, commit | |||
| ) | |||
| if use_cache and os.path.exists(repo_dir): # use cache | |||
| logger.debug("Cache Found in %s", repo_dir) | |||
| return repo_dir | |||
| if is_distributed(): | |||
| logger.warning( | |||
| "When using `hub.load` or `hub.list` to fetch git repositories " | |||
| "in DISTRIBUTED mode for the first time, processes are synchronized to " | |||
| "ensure that target repository is ready to use for each process.\n" | |||
| "Users are expected to see this warning no more than ONCE, otherwise " | |||
| "(very little chance) you may need to remove corrupt hub cache %s and fetch again.", | |||
| repo_dir, | |||
| ) | |||
| shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache | |||
| logger.debug("Downloading from %s to %s", archive_url, repo_dir) | |||
| cls._download_zip_and_extract(archive_url, repo_dir) | |||
| return repo_dir | |||
| @classmethod | |||
| def _download_zip_and_extract(cls, url, target_dir): | |||
| resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True) | |||
| if resp.status_code != 200: | |||
| raise HTTPDownloadError( | |||
| "An error occurred when downloading from {}".format(url) | |||
| ) | |||
| total_size = int(resp.headers.get("Content-Length", 0)) | |||
| _bar = tqdm(total=total_size, unit="iB", unit_scale=True) | |||
| with NamedTemporaryFile("w+b") as f: | |||
| for chunk in resp.iter_content(CHUNK_SIZE): | |||
| if not chunk: | |||
| break | |||
| _bar.update(len(chunk)) | |||
| f.write(chunk) | |||
| _bar.close() | |||
| f.seek(0) | |||
| with ZipFile(f) as temp_zip_f: | |||
| zip_dir_name = temp_zip_f.namelist()[0].split("/")[0] | |||
| temp_zip_f.extractall(".") | |||
| shutil.move(zip_dir_name, target_dir) | |||
| @classmethod | |||
| def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit): | |||
| archive_link = "https://{}/{}/{}/archive/{}.zip".format( | |||
| git_host, repo_owner, repo_name, commit or branch_info | |||
| ) | |||
| return archive_link | |||
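| # For reference (sketch, not in the original file): the archive link built above | |||
| # for a hypothetical repo resolves to | |||
| # | |||
| #     GitHTTPSFetcher._git_archive_link( | |||
| #         "github.com", "some_owner", "some_repo", "master", None | |||
| #     )  # -> "https://github.com/some_owner/some_repo/archive/master.zip" | |||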
| @@ -1,333 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import hashlib | |||
| import os | |||
| import sys | |||
| import types | |||
| from typing import Any, List | |||
| from urllib.parse import urlparse | |||
| from megengine.utils.http_download import download_from_url | |||
| from ..core.serialization import load as _mge_load_serialized | |||
| from ..distributed import is_distributed | |||
| from ..logger import get_logger | |||
| from .const import ( | |||
| DEFAULT_CACHE_DIR, | |||
| DEFAULT_GIT_HOST, | |||
| DEFAULT_PROTOCOL, | |||
| ENV_MGE_HOME, | |||
| ENV_XDG_CACHE_HOME, | |||
| HTTP_READ_TIMEOUT, | |||
| HUBCONF, | |||
| HUBDEPENDENCY, | |||
| ) | |||
| from .exceptions import InvalidProtocol | |||
| from .fetcher import GitHTTPSFetcher, GitSSHFetcher | |||
| from .tools import cd, check_module_exists, load_module | |||
| logger = get_logger(__name__) | |||
| PROTOCOLS = { | |||
| "HTTPS": GitHTTPSFetcher, | |||
| "SSH": GitSSHFetcher, | |||
| } | |||
| def _get_megengine_home() -> str: | |||
| """MGE_HOME setting complies with the XDG Base Directory Specification | |||
| """ | |||
| megengine_home = os.path.expanduser( | |||
| os.getenv( | |||
| ENV_MGE_HOME, | |||
| os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "megengine"), | |||
| ) | |||
| ) | |||
| return megengine_home | |||
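| # Resolution-order sketch (not in the original file): MGE_HOME wins, then | |||
| # XDG_CACHE_HOME, then the "~/.cache" default; the paths below are hypothetical. | |||
| # | |||
| #     os.environ["MGE_HOME"] = "/data/mge"   # highest priority | |||
| #     _get_megengine_home()                  # -> "/data/mge" | |||
| #     del os.environ["MGE_HOME"] | |||
| #     _get_megengine_home()                  # -> $XDG_CACHE_HOME/megengine or ~/.cache/megengine | |||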
| def _get_repo( | |||
| git_host: str, | |||
| repo_info: str, | |||
| use_cache: bool = False, | |||
| commit: str = None, | |||
| protocol: str = DEFAULT_PROTOCOL, | |||
| ) -> str: | |||
| if protocol not in PROTOCOLS: | |||
| raise InvalidProtocol( | |||
| "Invalid protocol, the value should be one of {}.".format( | |||
| ", ".join(PROTOCOLS.keys()) | |||
| ) | |||
| ) | |||
| cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub")) | |||
| with cd(cache_dir): | |||
| fetcher = PROTOCOLS[protocol] | |||
| repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit) | |||
| return os.path.join(cache_dir, repo_dir) | |||
| def _check_dependencies(module: types.ModuleType) -> None: | |||
| if not hasattr(module, HUBDEPENDENCY): | |||
| return | |||
| dependencies = getattr(module, HUBDEPENDENCY) | |||
| if not dependencies: | |||
| return | |||
| missing_deps = [m for m in dependencies if not check_module_exists(m)] | |||
| if len(missing_deps): | |||
| raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) | |||
| def _init_hub( | |||
| repo_info: str, | |||
| git_host: str, | |||
| use_cache: bool = True, | |||
| commit: str = None, | |||
| protocol: str = DEFAULT_PROTOCOL, | |||
| ): | |||
| """Imports the repo's hubconf.py as a Python module, similar to a normal import | |||
| :param repo_info: | |||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
| tag/branch. The default branch is ``master`` if not specified. | |||
| Example: ``"brain_sdk/MegBrain[:hub]"`` | |||
| :param git_host: | |||
| host address of git repo | |||
| Example: github.com | |||
| :param use_cache: | |||
| whether to use locally cached code or completely re-fetch | |||
| :param commit: | |||
| commit id on github or gitlab | |||
| :param protocol: | |||
| which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||
| The value should be one of HTTPS, SSH. | |||
| :return: | |||
| hubconf.py as a python module | |||
| """ | |||
| cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub")) | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| absolute_repo_dir = _get_repo( | |||
| git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol | |||
| ) | |||
| sys.path.insert(0, absolute_repo_dir) | |||
| hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF)) | |||
| sys.path.remove(absolute_repo_dir) | |||
| return hubmodule | |||
| @functools.wraps(_init_hub) | |||
| def import_module(*args, **kwargs): | |||
| return _init_hub(*args, **kwargs) | |||
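| # Sketch (not in the original file): import_module simply forwards to _init_hub, | |||
| # so a hypothetical repo's hubconf can be used like a module: | |||
| # | |||
| #     hubconf = import_module("some_owner/some_models:master") | |||
| #     net = hubconf.resnet18()  # assumes hubconf.py defines resnet18 | |||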
| def list( | |||
| repo_info: str, | |||
| git_host: str = DEFAULT_GIT_HOST, | |||
| use_cache: bool = True, | |||
| commit: str = None, | |||
| protocol: str = DEFAULT_PROTOCOL, | |||
| ) -> List[str]: | |||
| """Lists all entrypoints available in repo hubconf | |||
| :param repo_info: | |||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
| tag/branch. The default branch is ``master`` if not specified. | |||
| Example: ``"brain_sdk/MegBrain[:hub]"`` | |||
| :param git_host: | |||
| host address of git repo | |||
| Example: github.com | |||
| :param use_cache: | |||
| whether to use locally cached code or completely re-fetch | |||
| :param commit: | |||
| commit id on github or gitlab | |||
| :param protocol: | |||
| which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||
| The value should be one of HTTPS, SSH. | |||
| :return: | |||
| all entrypoint names of the model | |||
| """ | |||
| hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||
| return [ | |||
| _ | |||
| for _ in dir(hubmodule) | |||
| if not _.startswith("__") and callable(getattr(hubmodule, _)) | |||
| ] | |||
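| # Hedged usage sketch (not in the original file); repo and entry names are | |||
| # hypothetical: | |||
| # | |||
| #     import megengine.hub as hub | |||
| #     hub.list("some_owner/some_models:master")  # -> e.g. ["resnet18", "resnet50"] | |||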
| def load( | |||
| repo_info: str, | |||
| entry: str, | |||
| *args, | |||
| git_host: str = DEFAULT_GIT_HOST, | |||
| use_cache: bool = True, | |||
| commit: str = None, | |||
| protocol: str = DEFAULT_PROTOCOL, | |||
| **kwargs | |||
| ) -> Any: | |||
| """Loads model from github or gitlab repo, with pretrained weights. | |||
| :param repo_info: | |||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
| tag/branch. The default branch is ``master`` if not specified. | |||
| Example: ``"brain_sdk/MegBrain[:hub]"`` | |||
| :param entry: | |||
| an entrypoint defined in hubconf | |||
| :param git_host: | |||
| host address of git repo | |||
| Example: github.com | |||
| :param use_cache: | |||
| whether to use locally cached code or completely re-fetch | |||
| :param commit: | |||
| commit id on github or gitlab | |||
| :param protocol: | |||
| which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||
| The value should be one of HTTPS, SSH. | |||
| :return: | |||
| a single model with corresponding pretrained weights. | |||
| """ | |||
| hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||
| if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)): | |||
| raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry)) | |||
| _check_dependencies(hubmodule) | |||
| module = getattr(hubmodule, entry)(*args, **kwargs) | |||
| return module | |||
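| # Hedged usage sketch (not in the original file); "some_owner/some_models" and | |||
| # "resnet18" are hypothetical names: | |||
| # | |||
| #     import megengine.hub as hub | |||
| #     model = hub.load("some_owner/some_models", "resnet18", pretrained=True) | |||
| #     model.eval() | |||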
| def help( | |||
| repo_info: str, | |||
| entry: str, | |||
| git_host: str = DEFAULT_GIT_HOST, | |||
| use_cache: bool = True, | |||
| commit: str = None, | |||
| protocol: str = DEFAULT_PROTOCOL, | |||
| ) -> str: | |||
| """This function returns the docstring of entrypoint ``entry`` by the following steps: | |||
| 1. Pull the repo code specified by git_host and repo_info. | |||
| 2. Load the entry defined in the repo's hubconf.py. | |||
| 3. Return the docstring of the entry function. | |||
| :param repo_info: | |||
| a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
| tag/branch. The default branch is ``master`` if not specified. | |||
| Example: ``"brain_sdk/MegBrain[:hub]"`` | |||
| :param entry: | |||
| an entrypoint defined in hubconf.py | |||
| :param git_host: | |||
| host address of git repo | |||
| Example: github.com | |||
| :param use_cache: | |||
| whether to use locally cached code or completely re-fetch | |||
| :param commit: | |||
| commit id on github or gitlab | |||
| :param protocol: | |||
| which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||
| The value should be one of HTTPS, SSH. | |||
| :return: | |||
| docstring of entrypoint ``entry`` | |||
| """ | |||
| hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||
| if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)): | |||
| raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry)) | |||
| doc = getattr(hubmodule, entry).__doc__ | |||
| return doc | |||
| def load_serialized_obj_from_url(url: str, model_dir=None) -> Any: | |||
| """Loads MegEngine serialized object from the given URL. | |||
| If the object is already present in ``model_dir``, it's deserialized and | |||
| returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``. | |||
| :param url: url to serialized object | |||
| :param model_dir: dir to cache target serialized file | |||
| :return: loaded object | |||
| """ | |||
| if model_dir is None: | |||
| model_dir = os.path.join(_get_megengine_home(), "serialized") | |||
| os.makedirs(model_dir, exist_ok=True) | |||
| parts = urlparse(url) | |||
| filename = os.path.basename(parts.path) | |||
| # use hash as prefix to avoid filename conflict from different urls | |||
| sha256 = hashlib.sha256() | |||
| sha256.update(url.encode()) | |||
| digest = sha256.hexdigest()[:6] | |||
| filename = digest + "_" + filename | |||
| cached_file = os.path.join(model_dir, filename) | |||
| logger.info( | |||
| "load_serialized_obj_from_url: download to or using cached %s", cached_file | |||
| ) | |||
| if not os.path.exists(cached_file): | |||
| if is_distributed(): | |||
| logger.warning( | |||
| "Downloading serialized object in DISTRIBUTED mode\n" | |||
| " File may be downloaded multiple times. We recommend\n" | |||
| " users to download in single process first." | |||
| ) | |||
| download_from_url(url, cached_file, HTTP_READ_TIMEOUT) | |||
| state_dict = _mge_load_serialized(cached_file) | |||
| return state_dict | |||
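| # Hedged usage sketch (not in the original file); the URL and `model` are | |||
| # hypothetical: | |||
| # | |||
| #     state_dict = load_serialized_obj_from_url( | |||
| #         "https://example.com/models/resnet18.pkl" | |||
| #     ) | |||
| #     model.load_state_dict(state_dict) | |||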
| class pretrained: | |||
| r""" | |||
| Decorator which helps to download pretrained weights from the given url. | |||
| For example, we can decorate a resnet18 function as follows | |||
| .. code-block:: | |||
| @hub.pretrained("https://url/to/pretrained_resnet18.pkl") | |||
| def resnet18(**kwargs): | |||
| return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) | |||
| When the decorated function is called with ``pretrained=True``, MegEngine will automatically | |||
| download and fill the returned model with pretrained weights. | |||
| """ | |||
| def __init__(self, url): | |||
| self.url = url | |||
| def __call__(self, func): | |||
| @functools.wraps(func) | |||
| def pretrained_model_func( | |||
| pretrained=False, **kwargs | |||
| ): # pylint: disable=redefined-outer-name | |||
| model = func(**kwargs) | |||
| if pretrained: | |||
| weights = load_serialized_obj_from_url(self.url) | |||
| model.load_state_dict(weights) | |||
| return model | |||
| return pretrained_model_func | |||
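| # Call-side sketch (not in the original file), complementing the docstring above: | |||
| # | |||
| #     model = resnet18(pretrained=True)   # downloads weights, then load_state_dict | |||
| #     model = resnet18(pretrained=False)  # plain randomly-initialized model | |||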
| __all__ = [ | |||
| "list", | |||
| "load", | |||
| "help", | |||
| "load_serialized_obj_from_url", | |||
| "pretrained", | |||
| "import_module", | |||
| ] | |||
| @@ -1,48 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import importlib.util | |||
| import os | |||
| import types | |||
| from contextlib import contextmanager | |||
| from typing import Iterator | |||
| def load_module(name: str, path: str) -> types.ModuleType: | |||
| """ | |||
| Loads module specified by name and path | |||
| :param name: module name | |||
| :param path: module path | |||
| """ | |||
| spec = importlib.util.spec_from_file_location(name, path) | |||
| module = importlib.util.module_from_spec(spec) | |||
| spec.loader.exec_module(module) | |||
| return module | |||
| def check_module_exists(module: str) -> bool: | |||
| """Checks whether python module exists or not | |||
| :param module: name of module | |||
| """ | |||
| return importlib.util.find_spec(module) is not None | |||
| @contextmanager | |||
| def cd(target: str) -> Iterator[None]: | |||
| """Changes current directory to target | |||
| :param target: target directory | |||
| """ | |||
| prev = os.getcwd() | |||
| os.chdir(os.path.expanduser(target)) | |||
| try: | |||
| yield | |||
| finally: | |||
| os.chdir(prev) | |||
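| # Hedged usage sketch (not in the original file); the path is hypothetical: | |||
| # | |||
| #     with cd("~/some_repo"):                      # temporarily chdir | |||
| #         mod = load_module("hubconf", "hubconf.py") | |||
| #     # cwd is restored here even if an exception was raised | |||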
| @@ -1,570 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import contextlib | |||
| import functools | |||
| import itertools | |||
| import os | |||
| from typing import Callable, Tuple, Union | |||
| import numpy as np | |||
| import megengine._internal as mgb | |||
| from megengine._internal.plugin import CompGraphProfiler | |||
| from ..core import Tensor, graph, tensor | |||
| from .sublinear_memory_config import SublinearMemoryConfig | |||
| def sideeffect(f): | |||
| # during eager tracing, wrapped function is called with proxy inputs | |||
| # during static tracing, wrapped function will not be called at all | |||
| @functools.wraps(f) | |||
| def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements | |||
| if not trace._active_instance: | |||
| return f(*args, **kwargs) | |||
| tensors = {} | |||
| for i, x in itertools.chain(enumerate(args), kwargs.items()): | |||
| if isinstance(x, Tensor): | |||
| tensors[i] = x | |||
| if tensors: | |||
| _keys, tensors = zip(*tensors.items()) | |||
| else: | |||
| _keys, tensors = (), () | |||
| def callback(*tensors, f=f, keys=_keys, args=args, kwargs=kwargs): | |||
| replace = dict(zip(keys, tensors)) | |||
| args = tuple(replace.get(i, x) for i, x in enumerate(args)) | |||
| kwargs = {i: replace.get(i, x) for i, x in kwargs.items()} | |||
| if f(*args, **kwargs) is not None: | |||
| raise TypeError("a sideeffect function should return None") | |||
| # TODO: clear memory | |||
| trace._active_instance._register_callback(callback, tensors) | |||
| return wrapper | |||
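| # Sketch of intent (not in the original file): a function decorated with | |||
| # @sideeffect must return None; under an active trace its tensor arguments are | |||
| # replaced by proxies and the call is deferred via callback_injector, e.g. | |||
| # | |||
| #     @sideeffect | |||
| #     def log_stats(x):              # hypothetical helper | |||
| #         print(x.numpy().mean())    # runs when the traced graph executes | |||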
| def mark_impure(x): | |||
| if not trace._active_instance: | |||
| return x | |||
| return trace._active_instance._mark_impure(x) | |||
| def barrier(x): | |||
| if not trace._active_instance: | |||
| return x | |||
| return trace._active_instance._insert_barrier(x) | |||
| def _dummy(): | |||
| return mgb.make_immutable(*graph._use_default_if_none(None, None), 0) | |||
| class unset: | |||
| pass | |||
| class trace: | |||
| """ | |||
| Wrap a callable and provide: | |||
| * tracing via :meth:`.trace` and :meth:`.dump` | |||
| * accelerated evaluation via :meth:`.__call__` | |||
| :param func: Positional only argument. | |||
| :param symbolic: Whether to use symbolic tensor. Default: False | |||
| :param opt_level: Optimization level for compiling trace. | |||
| :param log_level: Log level. | |||
| :param sublinear_memory_config: Configuration for sublinear memory optimization. | |||
| If not None, it enables sublinear memory optimization with given setting. | |||
| :param allreduce_pack_max_size: Maximum size of an allreduce pack in MB. | |||
| If not None, multiple gradients will be packed and synchronized together | |||
| :param profiling: Whether to profile compiled trace. Default: False | |||
| """ | |||
| _active_instance = None | |||
| enabled = not os.getenv("MGE_DISABLE_TRACE") | |||
| _UNSTARTED = "unstarted" | |||
| _STARTED = "started" | |||
| _FINISHED = "finished" | |||
| def __new__(cls, *args, **kwargs): | |||
| if not args: | |||
| return functools.partial(cls, **kwargs) | |||
| return super().__new__(cls) | |||
| def __init__( | |||
| self, | |||
| func: Callable[..., Union[None, Tensor, Tuple[Tensor]]], | |||
| *, | |||
| symbolic: bool = False, | |||
| opt_level: int = None, | |||
| log_level: int = None, | |||
| sublinear_memory_config: SublinearMemoryConfig = None, | |||
| allreduce_pack_max_size: int = None, | |||
| profiling: bool = False | |||
| ): | |||
| self.__wrapped__ = func | |||
| self._symbolic = symbolic | |||
| self._graph_opt_level = opt_level | |||
| self._log_level = log_level | |||
| self._sublinear_memory_config = sublinear_memory_config | |||
| self._allreduce_pack_max_size = allreduce_pack_max_size | |||
| self._status = self._UNSTARTED | |||
| self._args = None | |||
| self._kwargs = None | |||
| self._outputs = unset | |||
| self._sym_outputs = unset | |||
| self._outspec = None | |||
| self._checkpoint = None | |||
| self._compiled_func = None | |||
| self._profiling = profiling | |||
| self._profiler = None | |||
| @property | |||
| def _active(self): | |||
| c1 = self._status == self._STARTED | |||
| c2 = type(self)._active_instance is self | |||
| assert c1 == c2 | |||
| return c1 | |||
| def _register_callback(self, f, args=()): | |||
| assert self._active | |||
| assert isinstance(args, (tuple, list)) | |||
| proxies = self._make_proxies(args) | |||
| self._forward(args, proxies, checkpoint=True) | |||
| # NOTE: under an eager graph the callback will fire immediately | |||
| job = mgb.opr.callback_injector( | |||
| self._insert_barrier(_dummy()), lambda _: f(*proxies) | |||
| ) | |||
| self._insert_checkpoint(job) | |||
| self._outspec.append(job) | |||
| def _insert_barrier(self, x): | |||
| assert self._active | |||
| if self._checkpoint is None: | |||
| return x | |||
| if isinstance(x, Tensor): | |||
| x = x._symvar | |||
| wrap = True | |||
| else: | |||
| wrap = False | |||
| if not isinstance(x, mgb.SymbolVar): | |||
| raise TypeError | |||
| x = mgb.opr.virtual_dep([x, self._checkpoint]) | |||
| if wrap: | |||
| x = Tensor(x) | |||
| return x | |||
| def _insert_checkpoint(self, *args, no_barrier=False): | |||
| assert self._active | |||
| if not args: | |||
| return | |||
| args = tuple(x._symvar if isinstance(x, Tensor) else x for x in args) | |||
| for x in args: | |||
| if not isinstance(x, mgb.SymbolVar): | |||
| raise TypeError | |||
| if not no_barrier and self._checkpoint is not None: | |||
| # normally no need to _insert_barrier here, but if | |||
| # someone forgets to call _insert_barrier beforehand, | |||
| # this can make things less broken | |||
| args += (self._checkpoint,) | |||
| if len(args) == 1: | |||
| self._checkpoint = args[0] | |||
| else: | |||
| self._checkpoint = mgb.opr.virtual_dep(args) | |||
| def _mark_impure(self, x): | |||
| assert self._active | |||
| ret = x | |||
| if isinstance(x, Tensor): | |||
| x = x._symvar | |||
| if not isinstance(x, mgb.SymbolVar): | |||
| raise TypeError | |||
| self._outspec.append(x) | |||
| self._insert_checkpoint(x) | |||
| return ret | |||
| def _make_proxies(self, args): | |||
| assert isinstance(args, (tuple, list)) | |||
| for x in args: | |||
| assert isinstance(x, Tensor) | |||
| return tuple(tensor(dtype=x.dtype, device=x.device) for x in args) | |||
| def _forward(self, srcs, dests, checkpoint=True): | |||
| # pseudo-op: does not run under static graph; traced | |||
| # TODO: use shared memory | |||
| assert len(srcs) == len(dests) | |||
| if not self._active: | |||
| for s, d in zip(srcs, dests): | |||
| d.set_value(s, share=False) | |||
| return | |||
| jobs = [] | |||
| for s, d in zip(srcs, dests): | |||
| def callback(value, dest=d): | |||
| dest.set_value(value, share=False) | |||
| s = self._insert_barrier(s._symvar) | |||
| # NOTE: callbacks fire immediately in an eager graph | |||
| jobs.append(mgb.opr.callback_injector(s, callback)) | |||
| self._outspec.extend(jobs) | |||
| if checkpoint: | |||
| self._insert_checkpoint(*jobs, no_barrier=True) | |||
| def _forward_inputs(self, *args, **kwargs): | |||
| if self._kwargs is None: | |||
| self._kwargs = kwargs | |||
| elif self._kwargs != kwargs: | |||
| raise ValueError("kwargs must not change between invocations") | |||
| if self._args is None: | |||
| self._args = [] | |||
| for i in args: | |||
| if isinstance(i, Tensor): | |||
| self._args.append(tensor(dtype=i.dtype, device=i.device)) | |||
| self._args[-1].set_value(i, share=False) | |||
| else: | |||
| self._args.append(tensor(i)) | |||
| else: | |||
| if not len(args) == len(self._args): | |||
| raise TypeError | |||
| for i, proxy in zip(args, self._args): | |||
| proxy.set_value(i, share=False) | |||
| # XXX: sync? | |||
| def _make_outputs(self, outputs): | |||
| if outputs is None: | |||
| self._outputs = None | |||
| return | |||
| if isinstance(outputs, Tensor): | |||
| # no one is able to call barrier after this, so no need to checkpoint | |||
| # but a checkpoint does little harm anyway | |||
| (self._outputs,) = self._make_proxies([outputs]) | |||
| return | |||
| if not isinstance(outputs, (tuple, list)): | |||
| raise TypeError("should return (tuple of) tensor") | |||
| for i in outputs: | |||
| if not isinstance(i, Tensor): | |||
| raise TypeError("should return (tuple of) tensor") | |||
| self._outputs = self._make_proxies(outputs) | |||
| def _forward_outputs(self, outputs): | |||
| # pseudo-op: does not run under static graph; traced | |||
| if self._outputs is unset: | |||
| self._make_outputs(outputs) | |||
| if self._outputs is None: | |||
| if outputs is not None: | |||
| raise TypeError("should return None") | |||
| elif isinstance(self._outputs, Tensor): | |||
| if not isinstance(outputs, Tensor): | |||
| raise TypeError("should return a tensor") | |||
| self._forward([outputs], [self._outputs]) | |||
| else: | |||
| assert isinstance(self._outputs, tuple) | |||
| def check(): | |||
| if not isinstance(outputs, (tuple, list)): | |||
| return False | |||
| if len(self._outputs) != len(outputs): | |||
| return False | |||
| for x in outputs: | |||
| if not isinstance(x, Tensor): | |||
| return False | |||
| return True | |||
| if not check(): | |||
| raise TypeError( | |||
| "should return tuple of %d tensors" % len(self._outputs) | |||
| ) | |||
| self._forward(outputs, self._outputs) | |||
| def _apply_graph_options(self, cg): | |||
| # graph opt level | |||
| if self._graph_opt_level is not None: | |||
| cg.set_option("graph_opt_level", self._graph_opt_level) | |||
| # log level | |||
| if self._log_level is not None: | |||
| cg.set_option("log_level", self._log_level) | |||
| # sublinear | |||
| if self._sublinear_memory_config is not None: | |||
| cg.set_option("enable_sublinear_memory_opt", True) | |||
| cg.set_option( | |||
| "sublinear_mem_config.lb_memory", | |||
| self._sublinear_memory_config.lb_memory, | |||
| ) | |||
| cg.set_option( | |||
| "sublinear_mem_config.genetic_nr_iter", | |||
| self._sublinear_memory_config.genetic_nr_iter, | |||
| ) | |||
| cg.set_option( | |||
| "sublinear_mem_config.genetic_pool_size", | |||
| self._sublinear_memory_config.genetic_pool_size, | |||
| ) | |||
| cg.set_option( | |||
| "sublinear_mem_config.thresh_nr_try", | |||
| self._sublinear_memory_config.thresh_nr_try, | |||
| ) | |||
| cg.set_option( | |||
| "sublinear_mem_config.num_worker", | |||
| self._sublinear_memory_config.num_worker, | |||
| ) | |||
| # pack allreduce | |||
| if self._allreduce_pack_max_size is not None: | |||
| cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size) | |||
| # profile | |||
| if self._profiling: | |||
| self._profiler = CompGraphProfiler(cg) | |||
| def _get_graph(self, eager): | |||
| if eager: | |||
| if not hasattr(self, "_eager_graph"): | |||
| # pylint: disable=attribute-defined-outside-init | |||
| self._eager_graph = graph.Graph(eager_evaluation=True) | |||
| self._apply_graph_options(self._eager_graph) | |||
| return self._eager_graph | |||
| else: | |||
| if not hasattr(self, "_static_graph"): | |||
| # pylint: disable=attribute-defined-outside-init | |||
| self._static_graph = graph.Graph(eager_evaluation=False) | |||
| self._apply_graph_options(self._static_graph) | |||
| return self._static_graph | |||
| @contextlib.contextmanager | |||
| def _prepare(self, args, kwargs, enable): | |||
| # prepare for execution | |||
| self._forward_inputs(*args, **kwargs) | |||
| if not enable: | |||
| # XXX: use our own graph here? | |||
| cg = None | |||
| elif self._status == self._FINISHED: | |||
| cg = None | |||
| elif self._symbolic: | |||
| cg = self._get_graph(eager=False) | |||
| else: | |||
| cg = self._get_graph(eager=True) | |||
| try: | |||
| # NOTE: always trace in a new graph, so capturing an undetached tensor | |||
| # will never work (would work if tracing in default graph) | |||
| if cg is None: | |||
| yield | |||
| else: | |||
| with cg: | |||
| yield | |||
| finally: | |||
| # XXX: properly release memory | |||
| if cg: | |||
| cg.clear_device_memory() | |||
| @contextlib.contextmanager | |||
| def _activate(self): | |||
| # prepare for tracing | |||
| if self._status != self._UNSTARTED: | |||
| raise RuntimeError("cannot trace a second time") | |||
| if type(self)._active_instance is not None: | |||
| raise RuntimeError("nested trace is unsupported") | |||
| self._status = self._STARTED | |||
| type(self)._active_instance = self | |||
| self._user_cache = {} | |||
| try: | |||
| yield | |||
| finally: | |||
| self._status = self._FINISHED | |||
| self._user_cache = None | |||
| type(self)._active_instance = None | |||
| def _run_wrapped(self): | |||
| outputs = self.__wrapped__(*self._args, **self._kwargs) | |||
| self._forward_outputs(outputs) | |||
| return outputs | |||
| def _do_trace(self): | |||
| with self._activate(): | |||
| self._outspec = [] | |||
| outputs = self._run_wrapped() | |||
| if outputs is None: | |||
| self._sym_outputs = None | |||
| else: | |||
| if isinstance(outputs, Tensor): | |||
| outputs = [outputs] | |||
| # _run_wrapped has checked validity of outputs | |||
| self._sym_outputs = tuple(i._symvar for i in outputs) | |||
| mgb.comp_graph_tools.set_priority_to_id(self._outspec) | |||
| self._compiled_func = graph.get_default_graph().compile(None, self._outspec) | |||
| def trace(self, *args: Tensor, **kwargs): | |||
| """ | |||
| Trace wrapped callable with provided arguments. | |||
| """ | |||
| with self._prepare(args, kwargs, enable=True): | |||
| self._do_trace() | |||
| return self | |||
| def __call__(self, *args: Tensor, **kwargs): | |||
| """ | |||
| Evaluate on provided arguments, using compiled trace | |||
| instead of the original callable if applicable. | |||
| :return: ``None`` or :class:`~.Tensor` or tuple of :class:`~.Tensor`, depending on the | |||
| return value of wrapped callable. | |||
| """ | |||
| with self._prepare(args, kwargs, enable=self.enabled): | |||
| if not self.enabled: | |||
| self._run_wrapped() | |||
| elif self._status == self._FINISHED: | |||
| self._compiled_func() | |||
| else: | |||
| if self._status == self._UNSTARTED: | |||
| self._do_trace() | |||
| if self._symbolic: | |||
| self._compiled_func() | |||
| return self._outputs | |||
| def dump( | |||
| self, | |||
| fpath, | |||
| *, | |||
| arg_names=None, | |||
| append=False, | |||
| optimize_for_inference=False, | |||
| output_names=None, | |||
| **kwargs | |||
| ): | |||
| """ | |||
| Serialize trace to file system. | |||
| :param fpath: positional only argument. Path of output file. | |||
| :param arg_names: names of the input tensors in the traced function. | |||
| :param append: whether output is appended to ``fpath``. | |||
| :param optimize_for_inference: whether to enable optimize_for_inference | |||
| pass before dump. | |||
| :param output_names: names of the output tensors in the traced function; | |||
| the default names are used if not specified. | |||
| :param enable_io16xc32: whether to use float16 for I/O between oprs and use | |||
| float32 as internal computation precision. Note the output var would be | |||
| changed to float16. | |||
| :param enable_ioc16: whether to use float16 for both I/O and computation | |||
| precision. | |||
| :param enable_hwcd4: whether to use NHWCD4 data layout. This is faster on some | |||
| OpenCL backend. | |||
| :param enable_nchw88: whether to use NCHW88 data layout. It is currently | |||
| used in the X86 AVX backend. | |||
| :param enable_nchw44: whether to use NCHW44 data layout. It is currently | |||
| used in the ARM backend. | |||
| :param enable_nchw44_dot: whether to use NCHW44_dot data layout. It is currently | |||
| used in the armv8.2+dotprod backend. | |||
| :param enable_nchw4: whether to use NCHW4 data layout. It is currently | |||
| used in the nvidia backend (based on cudnn). | |||
| :param enable_nchw32: whether to use NCHW32 data layout. It is currently | |||
| used in the nvidia backend with tensorcore (based on cudnn). | |||
| :param enable_chwn4: whether to use CHWN4 data layout. It is currently | |||
| used in the nvidia backend with tensorcore. | |||
| :param enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||
| into one opr. | |||
| :param enable_fuse_conv_bias_with_z: whether to fuse conv_bias with the z | |||
| input for inference on the nvidia backend (this optimization pass may | |||
| cause a precision mismatch between the outputs of training and | |||
| inference). | |||
| """ | |||
| if self._status != self._FINISHED: | |||
| raise ValueError("not traced") | |||
| assert isinstance(self._sym_outputs, (tuple, type(None))) | |||
| if not self._sym_outputs: | |||
| raise ValueError("no outputs") | |||
| if arg_names is None: | |||
| arg_names = ["arg_%d" % i for i in range(len(self._args))] | |||
| elif len(arg_names) != len(self._args): | |||
| raise ValueError( | |||
| "len(arg_names) should be {}, got {}".format( | |||
| len(self._args), len(arg_names) | |||
| ) | |||
| ) | |||
| if isinstance(output_names, str): | |||
| output_names = [output_names] | |||
| if output_names is None: | |||
| output_names = [var.name for var in self._sym_outputs] | |||
| elif len(output_names) != len(self._sym_outputs): | |||
| raise ValueError( | |||
| "len(output_names) should be {}, got {}".format( | |||
| len(self._sym_outputs), len(output_names) | |||
| ) | |||
| ) | |||
| optimize_for_inference_args_map = { | |||
| "enable_io16xc32": "f16_io_f32_comp", | |||
| "enable_ioc16": "f16_io_comp", | |||
| "enable_hwcd4": "use_nhwcd4", | |||
| "enable_nchw4": "use_nchw4", | |||
| "enable_nchw88": "use_nchw88", | |||
| "enable_nchw32": "use_nchw32", | |||
| "enable_nchw44": "use_nchw44", | |||
| "enable_nchw44_dot": "use_nchw44_dot", | |||
| "enable_chwn4": "use_chwn4", | |||
| "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity", | |||
| "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z", | |||
| } | |||
| if optimize_for_inference: | |||
| optimize_for_inference_kwargs = {} | |||
| for k, v in optimize_for_inference_args_map.items(): | |||
| if kwargs.pop(k, False): | |||
| optimize_for_inference_kwargs[v] = True | |||
| else: | |||
| for k in optimize_for_inference_args_map: | |||
| if kwargs.get(k, False): | |||
| raise ValueError( | |||
| "cannot set %s when optimize_for_inference is not set" % k | |||
| ) | |||
| if kwargs: | |||
| raise ValueError("unknown options: %s" % list(kwargs)) | |||
| cg = self._sym_outputs[0].owner_graph | |||
| replace = {} | |||
| for t, name in zip(self._args, arg_names): | |||
| # relies on symvar dedup | |||
| s = t.__mgb_symvar__(comp_graph=cg) | |||
| replace[s] = mgb.make_arg( | |||
| t.device, cg, dtype=t.dtype, shape=t.shape, name=name | |||
| ) | |||
| # Convert VolatileSharedDeviceTensor to SharedDeviceTensor, | |||
| # otherwise some optimizations would not work. The conversion is | |||
| # safe because there simply is no way (using builtin ops) to make | |||
| # a VolatileSharedDeviceTensor actually volatile. | |||
| for s in mgb.cgtools.get_dep_vars( | |||
| self._sym_outputs, "VolatileSharedDeviceTensor" | |||
| ): | |||
| if s in replace: | |||
| continue # is an input | |||
| replace[s] = mgb.SharedND._from_symvar(s).symvar( | |||
| cg, name=s.name, volatile=False | |||
| ) | |||
| sym_outputs = mgb.cgtools.replace_vars(self._sym_outputs, replace) | |||
| sym_outputs = list(sym_outputs) | |||
| if optimize_for_inference: | |||
| sym_outputs = mgb.optimize_for_inference( | |||
| sym_outputs, **optimize_for_inference_kwargs | |||
| ) | |||
| for var, name in zip(sym_outputs, output_names): | |||
| var.rename(name) | |||
| mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append) | |||
| def get_profile(self): | |||
| """ | |||
| Get profiling result for compiled trace. | |||
| :return: a json compatible object. | |||
| """ | |||
| if not self._profiler: | |||
| raise RuntimeError("trace is not set with profiling=True") | |||
| return self._profiler.get() | |||
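| # Hedged usage sketch (not in the original file); `net` and `data` are | |||
| # hypothetical (a Module and an input Tensor): | |||
| # | |||
| #     @trace(symbolic=True) | |||
| #     def forward(x): | |||
| #         return net(x) | |||
| # | |||
| #     y = forward(data)          # first call traces and compiles the graph | |||
| #     forward.dump("model.mge", optimize_for_inference=True) | |||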
| @@ -1,56 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..core.device import get_device_count | |||
| class SublinearMemoryConfig: | |||
| r""" | |||
| Configuration for sublinear memory optimization. | |||
| :param thresh_nr_try: number of samples to try, both when searching in linear space | |||
| and when searching around the current thresh, in sublinear memory optimization. Default: 10. | |||
| It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'. | |||
| :param genetic_nr_iter: number of iterations to find the best checkpoints in genetic algorithm. | |||
| Default: 0. | |||
| It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'. | |||
| :param genetic_pool_size: number of samples for the crossover random selection | |||
| during genetic optimization. Default: 20. | |||
| It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'. | |||
| :param lb_memory: memory lower bound of bottleneck size in MB for sublinear memory optimization. | |||
| It can be used to perform manual tradeoff between memory and speed. Default: 0. | |||
| It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'. | |||
| :param num_worker: number of thread workers to search the optimum checkpoints | |||
| in sublinear memory optimization. Default: half of cpu number in the system. | |||
| Note: the value must be greater than or equal to one. | |||
| It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_WORKERS'. | |||
| Note that the environmental variable MGB_COMP_GRAPH_OPT must be set to 'enable_sublinear_memory_opt=1' | |||
| in order for the above environmental variable to be effective. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| thresh_nr_try: int = 10, | |||
| genetic_nr_iter: int = 0, | |||
| genetic_pool_size: int = 20, | |||
| lb_memory: int = 0, | |||
| num_worker: int = max(1, get_device_count("cpu") // 2), | |||
| ): | |||
| assert thresh_nr_try >= 0, "thresh_nr_try must be greater than or equal to zero" | |||
| self.thresh_nr_try = thresh_nr_try | |||
| assert genetic_nr_iter >= 0, "genetic_nr_iter must be greater than or equal to zero" | |||
| self.genetic_nr_iter = genetic_nr_iter | |||
| assert ( | |||
| genetic_pool_size >= 0 | |||
| ), "genetic_pool_size must be greater than or equal to zero" | |||
| self.genetic_pool_size = genetic_pool_size | |||
| self.lb_memory = lb_memory | |||
| assert num_worker > 0, "num_worker must be greater than or equal to one" | |||
| self.num_worker = num_worker | |||
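| # Hedged usage sketch (not in the original file): passing the config to the | |||
| # trace decorator defined elsewhere in this package; `train_step` is hypothetical. | |||
| # | |||
| #     config = SublinearMemoryConfig(genetic_nr_iter=50, lb_memory=1024) | |||
| # | |||
| #     @trace(symbolic=True, sublinear_memory_config=config) | |||
| #     def train_step(x): | |||
| #         ... | |||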
| @@ -1,231 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import contextlib | |||
| import logging | |||
| import os | |||
| import sys | |||
| _all_loggers = [] | |||
| _default_level_name = os.getenv("MEGENGINE_LOGGING_LEVEL", "INFO") | |||
| _default_level = logging.getLevelName(_default_level_name.upper()) | |||
| def set_log_file(fout, mode="a"): | |||
| r"""Sets log output file. | |||
| :type fout: str or file-like | |||
| :param fout: file-like object that supports write and flush, or string for | |||
| the filename | |||
| :type mode: str | |||
| :param mode: specify the mode to open log file if *fout* is a string | |||
| """ | |||
| if isinstance(fout, str): | |||
| fout = open(fout, mode) | |||
| MegEngineLogFormatter.log_fout = fout | |||
| class MegEngineLogFormatter(logging.Formatter): | |||
| log_fout = None | |||
| date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] " | |||
| date = "%(asctime)s " | |||
| msg = "%(message)s" | |||
| max_lines = 256 | |||
| def _color_exc(self, msg): | |||
| r"""Sets the color of message as the exception type. | |||
| """ | |||
| return "\x1b[34m{}\x1b[0m".format(msg) | |||
| def _color_dbg(self, msg): | |||
| r"""Sets the color of message as the debugging type. | |||
| """ | |||
| return "\x1b[36m{}\x1b[0m".format(msg) | |||
| def _color_warn(self, msg): | |||
| r"""Sets the color of message as the warning type. | |||
| """ | |||
| return "\x1b[1;31m{}\x1b[0m".format(msg) | |||
| def _color_err(self, msg): | |||
| r"""Sets the color of message as the error type. | |||
| """ | |||
| return "\x1b[1;4;31m{}\x1b[0m".format(msg) | |||
| def _color_omitted(self, msg): | |||
| r"""Sets the color of message as the omitted type. | |||
| """ | |||
| return "\x1b[35m{}\x1b[0m".format(msg) | |||
| def _color_normal(self, msg): | |||
| r"""Sets the color of message as the normal type. | |||
| """ | |||
| return msg | |||
| def _color_date(self, msg): | |||
| r"""Sets the color of message the same as date. | |||
| """ | |||
| return "\x1b[32m{}\x1b[0m".format(msg) | |||
| def format(self, record): | |||
| if record.levelno == logging.DEBUG: | |||
| mcl, mtxt = self._color_dbg, "DBG" | |||
| elif record.levelno == logging.WARNING: | |||
| mcl, mtxt = self._color_warn, "WRN" | |||
| elif record.levelno == logging.ERROR: | |||
| mcl, mtxt = self._color_err, "ERR" | |||
| else: | |||
| mcl, mtxt = self._color_normal, "" | |||
| if mtxt: | |||
| mtxt += " " | |||
| if self.log_fout: | |||
| self.__set_fmt(self.date_full + mtxt + self.msg) | |||
| formatted = super(MegEngineLogFormatter, self).format(record) | |||
| nr_line = formatted.count("\n") + 1 | |||
| if nr_line >= self.max_lines: | |||
| head, body = formatted.split("\n", 1) | |||
| formatted = "\n".join( | |||
| [ | |||
| head, | |||
| "BEGIN_LONG_LOG_{}_LINES{{".format(nr_line - 1), | |||
| body, | |||
| "}}END_LONG_LOG_{}_LINES".format(nr_line - 1), | |||
| ] | |||
| ) | |||
| self.log_fout.write(formatted) | |||
| self.log_fout.write("\n") | |||
| self.log_fout.flush() | |||
| self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) | |||
| formatted = super(MegEngineLogFormatter, self).format(record) | |||
| if record.exc_text or record.exc_info: | |||
| # handle exception format | |||
| b = formatted.find("Traceback ") | |||
| if b != -1: | |||
| s = formatted[b:] | |||
| s = self._color_exc(" " + s.replace("\n", "\n ")) | |||
| formatted = formatted[:b] + s | |||
| nr_line = formatted.count("\n") + 1 | |||
| if nr_line >= self.max_lines: | |||
| lines = formatted.split("\n") | |||
| remain = self.max_lines // 2 | |||
| removed = len(lines) - remain * 2 | |||
| if removed > 0: | |||
| mid_msg = self._color_omitted( | |||
| "[{} log lines omitted (would be written to output file " | |||
| "if set_log_file() has been called;\n" | |||
| " the threshold can be set at " | |||
| "MegEngineLogFormatter.max_lines)]".format(removed) | |||
| ) | |||
| formatted = "\n".join(lines[:remain] + [mid_msg] + lines[-remain:]) | |||
| return formatted | |||
| if sys.version_info.major < 3: | |||
| def __set_fmt(self, fmt): | |||
| self._fmt = fmt | |||
| else: | |||
| def __set_fmt(self, fmt): | |||
| self._style._fmt = fmt | |||
| def get_logger(name=None, formatter=MegEngineLogFormatter): | |||
| r"""Gets megengine logger with given name. | |||
| """ | |||
| logger = logging.getLogger(name) | |||
| if getattr(logger, "_init_done__", None): | |||
| return logger | |||
| logger._init_done__ = True | |||
| logger.propagate = False | |||
| logger.setLevel(_default_level) | |||
| handler = logging.StreamHandler() | |||
| handler.setFormatter(formatter(datefmt="%d %H:%M:%S")) | |||
| handler.setLevel(0) | |||
| del logger.handlers[:] | |||
| logger.addHandler(handler) | |||
| _all_loggers.append(logger) | |||
| return logger | |||
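| # Typical usage sketch (not in the original file): | |||
| # | |||
| #     logger = get_logger(__name__)      # colored stderr output by default | |||
| #     logger.info("loaded %d items", 3) | |||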
| def set_log_level(level, update_existing=True): | |||
| """Sets default logging level. | |||
| :type level: int e.g. logging.INFO | |||
| :param level: logging level given by the Python :mod:`logging` module | |||
| :param update_existing: whether to update existing loggers | |||
| """ | |||
| global _default_level # pylint: disable=global-statement | |||
| _default_level = level | |||
| if update_existing: | |||
| for i in _all_loggers: | |||
| i.setLevel(level) | |||
| _logger = get_logger(__name__) | |||
| try: | |||
| if sys.version_info.major < 3: | |||
| raise ImportError() | |||
| from megengine._internal.logconf import set_logger as _set_mgb_logger | |||
| class MegBrainLogFormatter(MegEngineLogFormatter): | |||
| date = "%(asctime)s[mgb] " | |||
| def _color_date(self, msg): | |||
| return "\x1b[33m{}\x1b[0m".format(msg) | |||
| _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter) | |||
| _set_mgb_logger(_megbrain_logger) | |||
| def set_mgb_log_level(level): | |||
| r"""Sets megbrain log level | |||
| :type level: int e.g. logging.INFO | |||
| :param level: new log level | |||
| :return: original log level | |||
| """ | |||
| logger = _megbrain_logger | |||
| rst = logger.getEffectiveLevel() | |||
| logger.setLevel(level) | |||
| return rst | |||
| except ImportError as exc: | |||
| def set_mgb_log_level(level): | |||
| raise NotImplementedError("megbrain has not been imported") | |||
| @contextlib.contextmanager | |||
| def replace_mgb_log_level(level): | |||
| r"""Replaces megbrain log level in a block and restore after exiting. | |||
| :type level: int e.g. logging.INFO | |||
| :param level: new log level | |||
| """ | |||
| old = set_mgb_log_level(level) | |||
| try: | |||
| yield | |||
| finally: | |||
| set_mgb_log_level(old) | |||
| def enable_debug_log(): | |||
| r"""Sets logging level to debug for all components. | |||
| """ | |||
| set_log_level(logging.DEBUG) | |||
| set_mgb_log_level(logging.DEBUG) | |||
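| # Editorial usage sketch (not in the original file): shows the intended flow of | |||
| # the API defined above; the megbrain helpers raise if megbrain is unavailable. | |||
| if __name__ == "__main__": | |||
|     demo = get_logger("demo")                  # colored stream handler attached | |||
|     demo.info("hello")                         # shown at the default level | |||
|     set_log_level(logging.WARNING)             # raises threshold for all loggers | |||
|     demo.info("suppressed")                    # below WARNING, not shown | |||
|     try: | |||
|         with replace_mgb_log_level(logging.ERROR): | |||
|             pass                               # megbrain logs below ERROR muted here | |||
|     except NotImplementedError: | |||
|         pass                                   # megbrain not imported in this build | |||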
| @@ -1,23 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
| from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| from .dropout import Dropout | |||
| from .elemwise import Elemwise | |||
| from .embedding import Embedding | |||
| from .identity import Identity | |||
| from .linear import Linear | |||
| from .module import Module | |||
| from .parampack import ParamPack | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| from .sequential import Sequential | |||
| @@ -1,231 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..core import Parameter | |||
| from ..functional import leaky_relu, prelu, relu, sigmoid, softmax | |||
| from .module import Module | |||
| class Softmax(Module): | |||
| r""" | |||
| Applies a softmax function. Softmax is defined as: | |||
| .. math:: | |||
| \text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)} | |||
| It is applied to an n-dimensional input Tensor and rescales it so that the elements of the | |||
| n-dimensional output Tensor lie in the range `[0, 1]` and sum to 1. | |||
| :param axis: An axis along which softmax will be applied. By default, | |||
| softmax will apply along the highest ranked axis. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| data = mge.tensor(np.array([-2,-1,0,1,2]).astype(np.float32)) | |||
| softmax = M.Softmax() | |||
| output = softmax(data) | |||
| with np.printoptions(precision=6): | |||
| print(output.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [0.011656 0.031685 0.086129 0.234122 0.636409] | |||
| """ | |||
| def __init__(self, axis=None): | |||
| super().__init__() | |||
| self.axis = axis | |||
| def forward(self, inputs): | |||
| return softmax(inputs, self.axis) | |||
| class Sigmoid(Module): | |||
| r""" | |||
| Applies the element-wise function: | |||
| .. math:: | |||
| \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)} | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32)) | |||
| sigmoid = M.Sigmoid() | |||
| output = sigmoid(data) | |||
| with np.printoptions(precision=6): | |||
| print(output.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [0.119203 0.268941 0.5 0.731059 0.880797] | |||
| """ | |||
| def forward(self, inputs): | |||
| return sigmoid(inputs) | |||
| class ReLU(Module): | |||
| r""" | |||
| Applies the element-wise function: | |||
| .. math:: | |||
| \text{ReLU}(x) = \max(x, 0) | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32)) | |||
| relu = M.ReLU() | |||
| output = relu(data) | |||
| with np.printoptions(precision=6): | |||
| print(output.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [0. 0. 0. 1. 2.] | |||
| """ | |||
| def forward(self, x): | |||
| return relu(x) | |||
| class PReLU(Module): | |||
| r""" | |||
| Applies the element-wise function: | |||
| .. math:: | |||
| \text{PReLU}(x) = \max(0,x) + a * \min(0,x) | |||
| or | |||
| .. math:: | |||
| \text{PReLU}(x) = | |||
| \begin{cases} | |||
| x, & \text{ if } x \geq 0 \\ | |||
| ax, & \text{ otherwise } | |||
| \end{cases} | |||
| Here :math:`a` is a learnable parameter. When called without arguments, `PReLU()` uses | |||
| a single parameter :math:`a` across all input channels. If called with `PReLU(num_of_channels)`, | |||
| a separate :math:`a` is used for each input channel. | |||
| :param num_parameters: number of :math:`a` to learn; only two | |||
| values are legitimate: 1, or the number of channels of the input. Default: 1 | |||
| :param init: the initial value of :math:`a`. Default: 0.25 | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| data = mge.tensor(np.array([-1.2, -3.7, 2.7]).astype(np.float32)) | |||
| prelu = M.PReLU() | |||
| output = prelu(data) | |||
| print(output.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [-0.3 -0.925 2.7 ] | |||
| """ | |||
| def __init__(self, num_parameters: int = 1, init: float = 0.25): | |||
| super().__init__() | |||
| self.num_parameters = num_parameters | |||
| if num_parameters > 1: | |||
| # Assume format is NCHW | |||
| self.weight = Parameter( | |||
| value=np.full((1, num_parameters, 1, 1), init, dtype=np.float32) | |||
| ) | |||
| else: | |||
| self.weight = Parameter(value=[init]) | |||
| def forward(self, inputs): | |||
| assert self.weight.shape == (1,) or self.weight.shape == ( | |||
| 1, | |||
| int(inputs.shape[1]), | |||
| 1, | |||
| 1, | |||
| ), "invalid weight's shape" | |||
| return prelu(inputs, self.weight) | |||
| class LeakyReLU(Module): | |||
| r""" | |||
| Applies the element-wise function: | |||
| .. math:: | |||
| \text{LeakyReLU}(x) = \max(0,x) + negative\_slope \times \min(0,x) | |||
| or | |||
| .. math:: | |||
| \text{LeakyReLU}(x) = | |||
| \begin{cases} | |||
| x, & \text{ if } x \geq 0 \\ | |||
| negative\_slope \times x, & \text{ otherwise } | |||
| \end{cases} | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| data = mge.tensor(np.array([-8, -12, 6, 10]).astype(np.float32)) | |||
| leakyrelu = M.LeakyReLU(0.01) | |||
| output = leakyrelu(data) | |||
| print(output.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [-0.08 -0.12 6. 10. ] | |||
| """ | |||
| def __init__(self, negative_slope: float = 0.01): | |||
| super().__init__() | |||
| self.negative_slope = negative_slope | |||
| def forward(self, inputs): | |||
| return leaky_relu(inputs, self.negative_slope) | |||
| @@ -1,257 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..core import Buffer, Parameter | |||
| from ..core.device import get_default_device | |||
| from ..functional import batch_norm2d, sync_batch_norm | |||
| from . import init | |||
| from .module import Module | |||
| class _BatchNorm(Module): | |||
| def __init__( | |||
| self, | |||
| num_features, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| track_running_stats=True, | |||
| ): | |||
| super(_BatchNorm, self).__init__() | |||
| self.num_features = num_features | |||
| self.eps = eps | |||
| self.momentum = momentum | |||
| self.affine = affine | |||
| self.track_running_stats = track_running_stats | |||
| if self.affine: | |||
| self.weight = Parameter(np.ones(num_features, dtype=np.float32)) | |||
| self.bias = Parameter(np.zeros(num_features, dtype=np.float32)) | |||
| else: | |||
| self.weight = None | |||
| self.bias = None | |||
| tshape = (1, self.num_features, 1, 1) | |||
| if self.track_running_stats: | |||
| self.running_mean = Buffer(np.zeros(tshape, dtype=np.float32)) | |||
| self.running_var = Buffer(np.ones(tshape, dtype=np.float32)) | |||
| else: | |||
| self.running_mean = None | |||
| self.running_var = None | |||
| def reset_running_stats(self) -> None: | |||
| if self.track_running_stats: | |||
| init.zeros_(self.running_mean) | |||
| init.ones_(self.running_var) | |||
| def reset_parameters(self) -> None: | |||
| self.reset_running_stats() | |||
| if self.affine: | |||
| init.ones_(self.weight) | |||
| init.zeros_(self.bias) | |||
| def _check_input_ndim(self, inp): | |||
| raise NotImplementedError | |||
| def forward(self, inp): | |||
| self._check_input_ndim(inp) | |||
| _ndims = len(inp.shape) | |||
| if _ndims != 4: | |||
| origin_shape = inp.shapeof() | |||
| if _ndims == 2: | |||
| n, c = inp.shapeof(0), inp.shapeof(1) | |||
| new_shape = (n, c, 1, 1) | |||
| elif _ndims == 3: | |||
| n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2) | |||
| new_shape = (n, c, h, 1) | |||
| inp = inp.reshape(new_shape) | |||
| if self.training and self.track_running_stats: | |||
| exponential_average_factor = self.momentum | |||
| else: | |||
| exponential_average_factor = 0.0 # useless | |||
| # FIXME currently rocm does not support real bn opr so we just use | |||
| # sync_batch_norm(as implemented by elemwise) here, | |||
| # we will fix it in the next version | |||
| if get_default_device() == "rocmx": | |||
| output = sync_batch_norm( | |||
| inp, | |||
| self.running_mean, | |||
| self.running_var, | |||
| self.weight, | |||
| self.bias, | |||
| self.training or not self.track_running_stats, | |||
| exponential_average_factor, | |||
| self.eps, | |||
| ) | |||
| else: | |||
| output = batch_norm2d( | |||
| inp, | |||
| self.running_mean, | |||
| self.running_var, | |||
| self.weight, | |||
| self.bias, | |||
| self.training or not self.track_running_stats, | |||
| exponential_average_factor, | |||
| self.eps, | |||
| ) | |||
| if _ndims != 4: | |||
| output = output.reshape(origin_shape) | |||
| return output | |||
| class SyncBatchNorm(_BatchNorm): | |||
| r""" | |||
| Applies Synchronized Batch Normalization. | |||
| """ | |||
| def _check_input_ndim(self, inp): | |||
| if len(inp.shape) not in {2, 3, 4}: | |||
| raise ValueError( | |||
| "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape)) | |||
| ) | |||
| def forward(self, inp): | |||
| self._check_input_ndim(inp) | |||
| _ndims = len(inp.shape) | |||
| if _ndims != 4: | |||
| origin_shape = inp.shapeof() | |||
| if _ndims == 2: | |||
| n, c = inp.shapeof(0), inp.shapeof(1) | |||
| new_shape = (n, c, 1, 1) | |||
| elif _ndims == 3: | |||
| n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2) | |||
| new_shape = (n, c, h, 1) | |||
| inp = inp.reshape(new_shape) | |||
| if self.training and self.track_running_stats: | |||
| exponential_average_factor = self.momentum | |||
| else: | |||
| exponential_average_factor = 0.0 # useless | |||
| output = sync_batch_norm( | |||
| inp, | |||
| self.running_mean, | |||
| self.running_var, | |||
| self.weight, | |||
| self.bias, | |||
| self.training or not self.track_running_stats, | |||
| exponential_average_factor, | |||
| self.eps, | |||
| ) | |||
| if _ndims != 4: | |||
| output = output.reshape(origin_shape) | |||
| return output | |||
| class BatchNorm1d(_BatchNorm): | |||
| r""" | |||
| Applies Batch Normalization over a 2D/3D tensor. | |||
| Refer to :class:`~.BatchNorm2d` for more information. | |||
| """ | |||
| def _check_input_ndim(self, inp): | |||
| if len(inp.shape) not in {2, 3}: | |||
| raise ValueError( | |||
| "expected 2D or 3D input (got {}D input)".format(len(inp.shape)) | |||
| ) | |||
| class BatchNorm2d(_BatchNorm): | |||
| r""" | |||
| Applies Batch Normalization over a 4D tensor. | |||
| .. math:: | |||
| y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta | |||
| The mean and standard-deviation are calculated per-dimension over | |||
| the mini-batches and :math:`\gamma` and :math:`\beta` are learnable | |||
| parameter vectors. | |||
| By default, during training this layer keeps running estimates of its | |||
| computed mean and variance, which are then used for normalization during | |||
| evaluation. The running estimates are kept with a default :attr:`momentum` | |||
| of 0.9. | |||
| If :attr:`track_running_stats` is set to ``False``, this layer will not | |||
| keep running estimates, and batch statistics are instead used during | |||
| evaluation time. | |||
| .. note:: | |||
| This :attr:`momentum` argument is different from one used in optimizer | |||
| classes and the conventional notion of momentum. Mathematically, the | |||
| update rule for running statistics here is | |||
| :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t`, | |||
| where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the | |||
| new observed value. | |||
| Because the Batch Normalization is done over the `C` dimension, computing | |||
| statistics on `(N, H, W)` slices, it's common terminology to call this | |||
| Spatial Batch Normalization. | |||
| :type num_features: int | |||
| :param num_features: usually the :math:`C` from an input of size | |||
| :math:`(N, C, H, W)` or the highest ranked dimension of an input with | |||
| less than 4D. | |||
| :type eps: float | |||
| :param eps: a value added to the denominator for numerical stability. | |||
| Default: 1e-5. | |||
| :type momentum: float | |||
| :param momentum: the value used for the `running_mean` and `running_var` | |||
| computation. | |||
| Default: 0.9 | |||
| :type affine: bool | |||
| :param affine: a boolean value that when set to ``True``, this module has | |||
| learnable affine parameters. Default: ``True`` | |||
| :type track_running_stats: bool | |||
| :param track_running_stats: when set to ``True``, this module tracks the | |||
| running mean and variance. When set to ``False``, this module does not | |||
| track such statistics and always uses batch statistics in both training | |||
| and eval modes. Default: ``True``. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| # With Learnable Parameters | |||
| m = M.BatchNorm2d(4) | |||
| inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32")) | |||
| oup = m(inp) | |||
| print(m.weight, m.bias) | |||
| # Without Learnable Parameters | |||
| m = M.BatchNorm2d(4, affine=False) | |||
| oup = m(inp) | |||
| print(m.weight, m.bias) | |||
| .. testoutput:: | |||
| Tensor([1. 1. 1. 1.]) Tensor([0. 0. 0. 0.]) | |||
| None None | |||
| """ | |||
| def _check_input_ndim(self, inp): | |||
| if len(inp.shape) != 4: | |||
| raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape))) | |||
| @@ -1,22 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable | |||
| from .. import functional as F | |||
| from ..core.tensor import Tensor | |||
| from .module import Module | |||
| class Concat(Module): | |||
| r""" | |||
| A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule` | |||
| version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
| return F.concat(inps, axis) | |||
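| # Editorial usage sketch (not in the original file), assuming the legacy API above. | |||
| if __name__ == "__main__": | |||
|     import numpy as np | |||
|     import megengine as mge | |||
|     cat = Concat() | |||
|     a = mge.tensor(np.ones((2, 3), dtype=np.float32)) | |||
|     b = mge.tensor(np.zeros((2, 3), dtype=np.float32)) | |||
|     print(cat([a, b], axis=0).shape)   # (4, 3): rows of a stacked over rows of b | |||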
| @@ -1,392 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import abstractmethod | |||
| from typing import Tuple, Union | |||
| import numpy as np | |||
| import megengine._internal as mgb | |||
| from .. import functional as F | |||
| from ..core import Parameter | |||
| from ..utils.types import _pair, _pair_nonzero | |||
| from . import init | |||
| from .module import Module | |||
| class _ConvNd(Module): | |||
| """base class for convolution modules, including transposed conv""" | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]], | |||
| padding: Union[int, Tuple[int, int]], | |||
| dilation: Union[int, Tuple[int, int]], | |||
| groups: int, | |||
| bias: bool = True, | |||
| ): | |||
| super().__init__() | |||
| if in_channels % groups != 0: | |||
| raise ValueError("in_channels must be divisible by groups") | |||
| if out_channels % groups != 0: | |||
| raise ValueError("out_channels must be divisible by groups") | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.kernel_size = kernel_size | |||
| self.stride = stride | |||
| self.padding = padding | |||
| self.dilation = dilation | |||
| self.groups = groups | |||
| self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32)) | |||
| self.bias = None | |||
| if bias: | |||
| self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32)) | |||
| self.reset_parameters() | |||
| @abstractmethod | |||
| def _get_fanin(self): | |||
| pass | |||
| def reset_parameters(self) -> None: | |||
| fanin = self._get_fanin() | |||
| std = np.sqrt(1 / fanin) | |||
| init.normal_(self.weight, 0.0, std) | |||
| if self.bias is not None: | |||
| init.zeros_(self.bias) | |||
| @abstractmethod | |||
| def _infer_weight_shape(self): | |||
| pass | |||
| @abstractmethod | |||
| def _infer_bias_shape(self): | |||
| pass | |||
| class Conv2d(_ConvNd): | |||
| r"""Applies a 2D convolution over an input tensor. | |||
| For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`, | |||
| this layer generates an output of the size | |||
| :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the | |||
| process described as below: | |||
| .. math:: | |||
| \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + | |||
| \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) | |||
| where :math:`\star` is the valid 2D cross-correlation operator, | |||
| :math:`N` is a batch size, :math:`C` denotes a number of channels, | |||
| :math:`H` is a height of input planes in pixels, and :math:`W` is | |||
| width in pixels. | |||
| When ``groups == in_channels`` and ``out_channels == K * in_channels``, | |||
| where `K` is a positive integer, this operation is also known as depthwise | |||
| convolution. | |||
| In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`, | |||
| a depthwise convolution with a depthwise multiplier `K`, can be constructed | |||
| by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. | |||
| :param in_channels: number of input channels. | |||
| :param out_channels: number of output channels. | |||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
| an :class:`int`, the actual kernel size would be | |||
| ``(kernel_size, kernel_size)``. Default: 1 | |||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||
| :param padding: size of the paddings added to the input on both sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param dilation: dilation of the 2D convolution operation. Default: 1 | |||
| :param groups: number of groups to divide input and output channels into, | |||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
| and there would be an extra dimension at the beginning of the weight's | |||
| shape. Specifically, the shape of weight would be ``(groups, | |||
| out_channels // groups, in_channels // groups, *kernel_size)``. | |||
| :param bias: whether to add a bias onto the result of convolution. Default: | |||
| True | |||
| :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||
| `CROSS_CORRELATION`. | |||
| :param compute_mode: When set to `DEFAULT`, no special requirements will be | |||
| placed on the precision of intermediate results. When set to `FLOAT32`, | |||
| float32 would be used for accumulator and intermediate result, but only | |||
| effective when input and output are of float16 dtype. | |||
| """ | |||
| _conv_mode_type = mgb.opr_param_defs.Convolution.Mode | |||
| _compute_mode_type = mgb.opr_param_defs.Convolution.ComputeMode | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| bias: bool = True, | |||
| conv_mode: str = "CROSS_CORRELATION", | |||
| compute_mode: str = "DEFAULT", | |||
| ): | |||
| kernel_size = _pair_nonzero(kernel_size) | |||
| stride = _pair_nonzero(stride) | |||
| padding = _pair(padding) | |||
| dilation = _pair_nonzero(dilation) | |||
| self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
| self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
| super().__init__( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| ) | |||
| def _get_fanin(self): | |||
| kh, kw = self.kernel_size | |||
| ic = self.in_channels | |||
| return kh * kw * ic | |||
| def _infer_weight_shape(self): | |||
| group = self.groups | |||
| ichl = self.in_channels | |||
| ochl = self.out_channels | |||
| kh, kw = self.kernel_size | |||
| if group == 1: | |||
| # Assume format is NCHW | |||
| return (ochl, ichl, kh, kw) | |||
| assert ( | |||
| ichl % group == 0 and ochl % group == 0 | |||
| ), "invalid config: input_channels={} output_channels={} group={}".format( | |||
| ichl, ochl, group | |||
| ) | |||
| # Assume format is NCHW | |||
| return (group, ochl // group, ichl // group, kh, kw) | |||
| def _infer_bias_shape(self): | |||
| # Assume format is NCHW | |||
| return (1, self.out_channels, 1, 1) | |||
| def calc_conv(self, inp, weight, bias): | |||
| return F.conv2d( | |||
| inp, | |||
| weight, | |||
| bias, | |||
| self.stride, | |||
| self.padding, | |||
| self.dilation, | |||
| self.groups, | |||
| self.conv_mode, | |||
| self.compute_mode, | |||
| ) | |||
| def forward(self, inp): | |||
| return self.calc_conv(inp, self.weight, self.bias) | |||
| class ConvTranspose2d(_ConvNd): | |||
| r"""Applies a 2D transposed convolution over an input tensor. | |||
| This module is also known as a deconvolution or a fractionally-strided convolution. | |||
| :class:`ConvTranspose2d` can be seen as the gradient of the :class:`Conv2d` operation | |||
| with respect to its input. | |||
| Convolution usually reduces the size of input, while transposed convolution works | |||
| the opposite way, transforming a smaller input to a larger output while preserving the | |||
| connectivity pattern. | |||
| :param in_channels: number of input channels. | |||
| :param out_channels: number of output channels. | |||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
| an :class:`int`, the actual kernel size would be | |||
| ``(kernel_size, kernel_size)``. Default: 1 | |||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||
| :param padding: size of the paddings added to the input on both sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param dilation: dilation of the 2D convolution operation. Default: 1 | |||
| :param groups: number of groups to divide input and output channels into, | |||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
| and there would be an extra dimension at the beginning of the weight's | |||
| shape. Specifically, the shape of weight would be ``(groups, | |||
| out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||
| :param bias: whether to add a bias onto the result of convolution. Default: | |||
| True | |||
| :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||
| `CROSS_CORRELATION`. | |||
| :param compute_mode: When set to `DEFAULT`, no special requirements will be | |||
| placed on the precision of intermediate results. When set to `FLOAT32`, | |||
| float32 would be used for accumulator and intermediate result, but only | |||
| effective when input and output are of float16 dtype. | |||
| """ | |||
| _conv_mode_type = mgb.opr_param_defs.Convolution.Mode | |||
| _compute_mode_type = mgb.opr_param_defs.Convolution.ComputeMode | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| bias: bool = True, | |||
| conv_mode: str = "CROSS_CORRELATION", | |||
| compute_mode: str = "DEFAULT", | |||
| ): | |||
| kernel_size = _pair_nonzero(kernel_size) | |||
| stride = _pair_nonzero(stride) | |||
| padding = _pair(padding) | |||
| dilation = _pair_nonzero(dilation) | |||
| self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
| self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
| super().__init__( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| ) | |||
| def _get_fanin(self): | |||
| kh, kw = self.kernel_size | |||
| oc = self.out_channels | |||
| return kh * kw * oc | |||
| def _infer_weight_shape(self): | |||
| group = self.groups | |||
| ichl = self.in_channels | |||
| ochl = self.out_channels | |||
| kh, kw = self.kernel_size | |||
| if group == 1: | |||
| # Assume format is NCHW | |||
| return (ichl, ochl, kh, kw) | |||
| assert ( | |||
| ichl % group == 0 and ochl % group == 0 | |||
| ), "invalid config: input_channels={} output_channels={} group={}".format( | |||
| ichl, ochl, group | |||
| ) | |||
| # Assume format is NCHW | |||
| return (group, ichl // group, ochl // group, kh, kw) | |||
| def _infer_bias_shape(self): | |||
| # Assume format is NCHW | |||
| return (1, self.out_channels, 1, 1) | |||
| def forward(self, inp): | |||
| return F.conv_transpose2d( | |||
| inp, | |||
| self.weight, | |||
| self.bias, | |||
| self.stride, | |||
| self.padding, | |||
| self.dilation, | |||
| self.groups, | |||
| self.conv_mode, | |||
| self.compute_mode, | |||
| ) | |||
| class LocalConv2d(Conv2d): | |||
| r"""Applies a spatial convolution with untied kernels over an input 4D tensor. | |||
| It is also known as the locally connected layer. | |||
| :param in_channels: number of input channels. | |||
| :param out_channels: number of output channels. | |||
| :param input_height: the height of the input images. | |||
| :param input_width: the width of the input images. | |||
| :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
| an :class:`int`, the actual kernel size would be | |||
| ``(kernel_size, kernel_size)``. Default: 1 | |||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||
| :param padding: size of the paddings added to the input on both sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param groups: number of groups to divide input and output channels into, | |||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``. | |||
| The shape of weight is ``(groups, output_height, output_width, | |||
| in_channels // groups, *kernel_size, out_channels // groups)``. | |||
| """ | |||
| _conv_mode_type = mgb.opr_param_defs.Convolution.Mode | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| input_height: int, | |||
| input_width: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| conv_mode: str = "CROSS_CORRELATION", | |||
| ): | |||
| self.input_height = input_height | |||
| self.input_width = input_width | |||
| super().__init__( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| dilation, | |||
| groups, | |||
| bias=False, | |||
| ) | |||
| def _infer_weight_shape(self): | |||
| group = self.groups | |||
| output_height = ( | |||
| self.input_height + self.padding[0] * 2 - self.kernel_size[0] | |||
| ) // self.stride[0] + 1 | |||
| output_width = ( | |||
| self.input_width + self.padding[1] * 2 - self.kernel_size[1] | |||
| ) // self.stride[1] + 1 | |||
| # Assume format is NCHW | |||
| return ( | |||
| group, | |||
| output_height, | |||
| output_width, | |||
| self.in_channels // group, | |||
| self.kernel_size[0], | |||
| self.kernel_size[1], | |||
| self.out_channels // group, | |||
| ) | |||
| def forward(self, inp): | |||
| return F.local_conv2d( | |||
| inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode | |||
| ) | |||
| class ConvRelu2d(Conv2d): | |||
| r""" | |||
| A fused :class:`~.Module` including Conv2d and relu. Could be replaced | |||
| with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using | |||
| :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inp): | |||
| return F.relu(self.calc_conv(inp, self.weight, self.bias)) | |||
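| # Editorial sketch (not in the original file): the depthwise case described in | |||
| # the Conv2d docstring, with groups == in_channels and out_channels == K * in_channels. | |||
| if __name__ == "__main__": | |||
|     import megengine as mge | |||
|     conv = Conv2d(in_channels=4, out_channels=8, kernel_size=3, padding=1, groups=4) | |||
|     print(conv.weight.shape)   # (groups, ochl // groups, ichl // groups, kh, kw) == (4, 2, 1, 3, 3) | |||
|     x = mge.tensor(np.ones((1, 4, 16, 16), dtype=np.float32)) | |||
|     print(conv(x).shape)       # (1, 8, 16, 16): each input channel gets K = 2 filters | |||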
| @@ -1,69 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Tuple, Union | |||
| from ..functional import relu | |||
| from .batchnorm import BatchNorm2d | |||
| from .conv import Conv2d | |||
| from .module import Module | |||
| class _ConvBnActivation2d(Module): | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| bias: bool = True, | |||
| conv_mode: str = "CROSS_CORRELATION", | |||
| compute_mode: str = "DEFAULT", | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| track_running_stats=True, | |||
| ): | |||
| super().__init__() | |||
| self.conv = Conv2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| conv_mode, | |||
| compute_mode, | |||
| ) | |||
| self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) | |||
| class ConvBn2d(_ConvBnActivation2d): | |||
| r""" | |||
| A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced | |||
| with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using | |||
| :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.bn(self.conv(inp)) | |||
| class ConvBnRelu2d(_ConvBnActivation2d): | |||
| r""" | |||
| A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced | |||
| with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using | |||
| :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inp): | |||
| return relu(self.bn(self.conv(inp))) | |||
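| # Editorial usage sketch (not in the original file), assuming the legacy API above. | |||
| if __name__ == "__main__": | |||
|     import numpy as np | |||
|     import megengine as mge | |||
|     m = ConvBnRelu2d(3, 8, kernel_size=3, stride=2, padding=1) | |||
|     x = mge.tensor(np.random.rand(2, 3, 32, 32).astype(np.float32)) | |||
|     print(m(x).shape)   # (2, 8, 16, 16): conv -> batchnorm -> relu in one module | |||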
| @@ -1,29 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..functional import dropout | |||
| from .module import Module | |||
| class Dropout(Module): | |||
| r"""Randomly set input elements to zeros with the probability :math:`drop\_prob` during training. Commonly used in large networks to prevent overfitting. | |||
| Note that we perform dropout only during training, we also rescale(multiply) the output tensor | |||
| by :math:`\frac{1}{1 - drop\_prob}`. During inference :class:`~.Dropout` is equal to :class:`~.Identity`. | |||
| :param drop_prob: The probability to drop (set to zero) each single element | |||
| """ | |||
| def __init__(self, drop_prob=0.0): | |||
| super().__init__() | |||
| self.drop_prob = drop_prob | |||
| def forward(self, inputs): | |||
| if self.training: | |||
| return dropout(inputs, self.drop_prob, rescale=True) | |||
| else: | |||
| return inputs | |||
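| # Editorial sketch (not in the original file): contrasts training and inference; | |||
| # assumes ``Module.train()``/``Module.eval()`` toggle ``self.training`` as usual. | |||
| if __name__ == "__main__": | |||
|     import numpy as np | |||
|     import megengine as mge | |||
|     drop = Dropout(drop_prob=0.5) | |||
|     x = mge.tensor(np.ones((2, 4), dtype=np.float32)) | |||
|     drop.train() | |||
|     y = drop(x)   # roughly half the entries zeroed, survivors scaled by 1 / (1 - 0.5) | |||
|     drop.eval() | |||
|     z = drop(x)   # identity at inference: z equals x | |||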
| @@ -1,90 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .. import _internal as mgb | |||
| from ..core import Tensor, wrap_io_tensor | |||
| from ..core.graph import _use_default_if_none | |||
| from .module import Module | |||
| @wrap_io_tensor | |||
| def _elemwise_func(mode, *inputs, **kwargs) -> Tensor: | |||
| if all(isinstance(i, (int, float)) for i in inputs): | |||
| device, comp_graph = _use_default_if_none(None, None) | |||
| ret = mgb.opr.elemwise( | |||
| *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs | |||
| ) | |||
| return ret.inferred_value[0] | |||
| return mgb.opr.elemwise(*inputs, mode=mode, **kwargs) | |||
| class Elemwise(Module): | |||
| r""" | |||
| A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule` | |||
| version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`. | |||
| :param method: the elemwise method, supporting the following strings. | |||
| It performs the normal elemwise operation for floats. | |||
| * "ADD": x + y | |||
| * "FUSE_ADD_RELU": max(x+y, 0) | |||
| * "MUL": x * y | |||
| * "MIN": min(x, y) | |||
| * "MAX": max(x, y) | |||
| * "SUB": x - y | |||
| * "TRUE_DIV": x / y | |||
| * "FUSE_ADD_SIGMOID": sigmoid(x + y) | |||
| * "FUSE_ADD_TANH": tanh(x + y) | |||
| * "RELU": x > 0 ? x : 0 | |||
| * "ABS": x > 0 ? x : -x | |||
| * "SIGMOID": sigmoid(x) | |||
| * "EXP": exp(x) | |||
| * "TANH": tanh(x) | |||
| * "FUSE_MUL_ADD3": x * y + z | |||
| * "FAST_TANH": fast_tanh(x) | |||
| * "NEGATE": -x | |||
| * "ACOS": acos(x) | |||
| * "ASIN": asin(x) | |||
| * "CEIL": ceil(x) | |||
| * "COS": cos(x) | |||
| * "EXPM1": expm1(x) | |||
| * "FLOOR": floor(x) | |||
| * "LOG": log(x) | |||
| * "LOG1P": log1p(x) | |||
| * "SIN": sin(x) | |||
| * "ROUND": round(x) | |||
| * "ERF": erf(x) | |||
| * "ERFINV": erfinv(x) | |||
| * "ERFC": erfc(x) | |||
| * "ERFCINV": erfcinv(x) | |||
| * "ABS_GRAD": abs_grad | |||
| * "FLOOR_DIV": floor_div | |||
| * "MOD": mod | |||
| * "SIGMOID_GRAD": sigmoid_grad | |||
| * "SWITCH_GT0": switch_gt0 | |||
| * "TANH_GRAD": tanh_grad | |||
| * "LT": lt | |||
| * "LEQ": leq | |||
| * "EQ": eq | |||
| * "POW": pow | |||
| * "LOG_SUM_EXP": log_sum_exp | |||
| * "FAST_TANH_GRAD": fast_tanh_grad | |||
| * "ATAN2": atan2 | |||
| * "COND_LEQ_MOV": cond_leq_mov | |||
| * "H_SWISH": h_swish | |||
| * "FUSE_ADD_H_SWISH": h_swish(x+y) | |||
| * "H_SWISH_GRAD": h_swish_grad | |||
| """ | |||
| _elemwise_mode_type = mgb.opr_param_defs.Elemwise.Mode | |||
| def __init__(self, method): | |||
| super().__init__() | |||
| self.method = self._elemwise_mode_type.convert(method) | |||
| def forward(self, *inps): | |||
| return _elemwise_func(self.method, *inps) | |||
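| # Editorial usage sketch (not in the original file), assuming the legacy API above. | |||
| if __name__ == "__main__": | |||
|     import numpy as np | |||
|     import megengine as mge | |||
|     add_relu = Elemwise("FUSE_ADD_RELU")   # fused max(x + y, 0) in a single op | |||
|     x = mge.tensor(np.array([-2.0, 1.0], dtype=np.float32)) | |||
|     y = mge.tensor(np.array([1.0, 1.0], dtype=np.float32)) | |||
|     print(add_relu(x, y).numpy())          # [0. 2.] | |||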
| @@ -1,171 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Optional | |||
| import numpy as np | |||
| from ..core import Parameter | |||
| from ..functional import embedding as embedding_func | |||
| from . import init | |||
| from .module import Module | |||
| class Embedding(Module): | |||
| r""" | |||
| A simple lookup table that stores embeddings of a fixed dictionary and size. | |||
| This module is often used to store word embeddings and retrieve them using indices. | |||
| The input to the module is a list of indices, and the output is the corresponding word embeddings. | |||
| The indices should be less than ``num_embeddings``. | |||
| :param num_embeddings: size of embedding dictionary. | |||
| :param embedding_dim: size of each embedding vector. | |||
| :param padding_idx: should be set to None; not supported now. | |||
| :param max_norm: should be set to None; not supported now. | |||
| :param norm_type: should be set to None; not supported now. | |||
| :param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim). | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32)) | |||
| data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32)) | |||
| embedding = M.Embedding(2, 5, initial_weight=weight) | |||
| output = embedding(data) | |||
| with np.printoptions(precision=6): | |||
| print(output.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[[1.2 2.3 3.4 4.5 5.6] | |||
| [0.1 1.1 2.1 3.1 4.1] | |||
| [0.1 1.1 2.1 3.1 4.1]] | |||
| [[0.1 1.1 2.1 3.1 4.1] | |||
| [1.2 2.3 3.4 4.5 5.6] | |||
| [0.1 1.1 2.1 3.1 4.1]] | |||
| [[1.2 2.3 3.4 4.5 5.6] | |||
| [1.2 2.3 3.4 4.5 5.6] | |||
| [0.1 1.1 2.1 3.1 4.1]]] | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_embeddings: int, | |||
| embedding_dim: int, | |||
| padding_idx: Optional[int] = None, | |||
| max_norm: Optional[float] = None, | |||
| norm_type: Optional[float] = None, | |||
| initial_weight: Parameter = None, | |||
| ): | |||
| super().__init__() | |||
| if padding_idx is not None: | |||
| raise ValueError("Not support padding index now.") | |||
| if max_norm is not None or norm_type is not None: | |||
| raise ValueError("Not support weight normalize now.") | |||
| self.padding_idx = padding_idx | |||
| self.max_norm = max_norm | |||
| self.norm_type = norm_type | |||
| self.num_embeddings = num_embeddings | |||
| self.embedding_dim = embedding_dim | |||
| if initial_weight is None: | |||
| self.weight = Parameter( | |||
| np.random.uniform( | |||
| size=(self.num_embeddings, self.embedding_dim) | |||
| ).astype(np.float32) | |||
| ) | |||
| self.reset_parameters() | |||
| else: | |||
| if initial_weight.shape != (num_embeddings, embedding_dim): | |||
| raise ValueError( | |||
| "The weight shape should match num_embeddings and embedding_dim" | |||
| ) | |||
| self.weight = Parameter(initial_weight.numpy()) | |||
| def reset_parameters(self) -> None: | |||
| init.normal_(self.weight) | |||
| def forward(self, inputs): | |||
| return embedding_func(inputs, self.weight) | |||
| @classmethod | |||
| def from_pretrained( | |||
| cls, | |||
| embeddings: Parameter, | |||
| freeze: Optional[bool] = True, | |||
| padding_idx: Optional[int] = None, | |||
| max_norm: Optional[float] = None, | |||
| norm_type: Optional[float] = None, | |||
| ): | |||
| r""" | |||
| Creates an Embedding instance from a given 2-dimensional FloatTensor. | |||
| :param embeddings: Tensor contained weight for the embedding. | |||
| :param freeze: If ``True``, the weight does not get updated during the learning process. Default: ``True``. | |||
| :param padding_idx: should be set to None; not supported now. | |||
| :param max_norm: should be set to None; not supported now. | |||
| :param norm_type: should be set to None; not supported now. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32)) | |||
| data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32)) | |||
| embedding = M.Embedding.from_pretrained(weight, freeze=False) | |||
| output = embedding(data) | |||
| print(output.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[[1.2 2.3 3.4 4.5 5.6] | |||
| [0.1 1.1 2.1 3.1 4.1] | |||
| [0.1 1.1 2.1 3.1 4.1]] | |||
| [[0.1 1.1 2.1 3.1 4.1] | |||
| [1.2 2.3 3.4 4.5 5.6] | |||
| [0.1 1.1 2.1 3.1 4.1]] | |||
| [[1.2 2.3 3.4 4.5 5.6] | |||
| [1.2 2.3 3.4 4.5 5.6] | |||
| [0.1 1.1 2.1 3.1 4.1]]] | |||
| """ | |||
| embeddings_shape = embeddings.shape | |||
| embeddings_dim = len(embeddings_shape) | |||
| if embeddings_dim != 2: | |||
| raise ValueError("Embeddings parameter is expected to be 2-dimensional") | |||
| rows = embeddings_shape[0] | |||
| cols = embeddings_shape[1] | |||
| embedding = cls( | |||
| num_embeddings=rows, | |||
| embedding_dim=cols, | |||
| initial_weight=embeddings, | |||
| padding_idx=padding_idx, | |||
| max_norm=max_norm, | |||
| norm_type=norm_type, | |||
| ) | |||
| embedding.weight.requires_grad = not freeze | |||
| return embedding | |||
| @@ -1,83 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..functional.external import ( | |||
| atlas_subgraph, | |||
| cambricon_subgraph, | |||
| extern_opr_subgraph, | |||
| ) | |||
| from .module import Module | |||
| class CambriconSubgraph(Module): | |||
| r"""Load a serialized Cambricon subgraph. | |||
| See :func:`~.cambricon_subgraph` for more details. | |||
| """ | |||
| def __init__( | |||
| self, data, symbol, tensor_dim_mutable, | |||
| ): | |||
| super(CambriconSubgraph, self).__init__() | |||
| self._data = data | |||
| self.symbol = symbol | |||
| self.tensor_dim_mutable = tensor_dim_mutable | |||
| @property | |||
| def data(self): | |||
| return self._data.tobytes() | |||
| @data.setter | |||
| def data(self, val): | |||
| self._data = np.frombuffer(val, dtype=np.uint8) | |||
| def forward(self, inputs): | |||
| outputs = cambricon_subgraph( | |||
| inputs, self._data, self.symbol, self.tensor_dim_mutable, | |||
| ) | |||
| return outputs | |||
| class AtlasSubgraph(Module): | |||
| r"""Load a serialized Atlas subgraph. | |||
| See :func:`~.atlas_subgraph` for more details. | |||
| """ | |||
| def __init__(self, data): | |||
| super(AtlasSubgraph, self).__init__() | |||
| self._data = data | |||
| @property | |||
| def data(self): | |||
| return self._data.tobytes() | |||
| @data.setter | |||
| def data(self, val): | |||
| self._data = np.frombuffer(val, dtype=np.uint8) | |||
| def forward(self, inputs): | |||
| outputs = atlas_subgraph(inputs, self._data) | |||
| return outputs | |||
| class ExternOprSubgraph(Module): | |||
| r"""Load a serialized extern opr subgraph. | |||
| """ | |||
| def __init__(self, data, name, output_shapes): | |||
| super(ExternOprSubgraph, self).__init__() | |||
| self.data = data | |||
| self.name = name | |||
| self.output_shapes = output_shapes | |||
| def forward(self, inputs): | |||
| outputs = extern_opr_subgraph(inputs, self.output_shapes, self.name, self.data,) | |||
| return outputs | |||
| @@ -1,17 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..functional import identity | |||
| from .module import Module | |||
| class Identity(Module): | |||
| r"""A placeholder identity operator that will ignore any argument.""" | |||
| def forward(self, x): | |||
| return identity(x) | |||
| @@ -1,264 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import math | |||
| from functools import reduce | |||
| from typing import Optional, Tuple, Union | |||
| import numpy as np | |||
| from ..core import Graph, Tensor | |||
| from ..random import gaussian, uniform | |||
| def fill_(tensor: Tensor, val: Union[float, int]) -> None: | |||
| """Fill the given ``tensor`` with value ``val``. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| :param val: The value to be filled throughout the tensor | |||
| """ | |||
| tensor.set_value(np.full(tensor.shape, val, tensor.dtype)) | |||
| def zeros_(tensor: Tensor) -> None: | |||
| """Fill the given ``tensor`` with scalar value `0`. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| """ | |||
| fill_(tensor, 0) | |||
| def ones_(tensor: Tensor) -> None: | |||
| """Fill the given ``tensor`` with the scalar value `1`. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| """ | |||
| fill_(tensor, 1) | |||
| def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: | |||
| r"""Fill the given ``tensor`` with random value sampled from uniform distribution | |||
| :math:`\mathcal{U}(\text{a}, \text{b})`. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| :param a: Lower bound of the sampling interval | |||
| :param b: Upper bound of the sampling interval | |||
| """ | |||
| with Graph(eager_evaluation=True): | |||
| tensor.set_value((b - a) * uniform(tensor.shape) + a) | |||
| def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: | |||
| r"""Fill the given ``tensor`` with random value sampled from normal distribution | |||
| :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| :param mean: The mean of the normal distribution | |||
| :param std: The standard deviation of the normal distribution | |||
| """ | |||
| with Graph(eager_evaluation=True): | |||
| tensor.set_value(gaussian(tensor.shape, mean=mean, std=std)) | |||
| def calculate_gain( | |||
| nonlinearity: str, param: Optional[Union[int, float]] = None | |||
| ) -> float: | |||
| r"""Return a recommended gain value (see the table below) for the given nonlinearity | |||
| function. | |||
| ================= ==================================================== | |||
| nonlinearity gain | |||
| ================= ==================================================== | |||
| Linear / Identity :math:`1` | |||
| Conv{1,2,3}D :math:`1` | |||
| Sigmoid :math:`1` | |||
| Tanh :math:`\frac{5}{3}` | |||
| ReLU :math:`\sqrt{2}` | |||
| Leaky ReLU :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` | |||
| ================= ==================================================== | |||
| :param nonlinearity: Name of the non-linear function | |||
| :param param: Optional parameter for leaky_relu. Only effective when | |||
| ``nonlinearity`` is "leaky_relu". | |||
| """ | |||
| linear_fns = [ | |||
| "linear", | |||
| "conv1d", | |||
| "conv2d", | |||
| "conv3d", | |||
| "conv_transpose1d", | |||
| "conv_transpose2d", | |||
| "conv_transpose3d", | |||
| ] | |||
| if nonlinearity in linear_fns or nonlinearity == "sigmoid": | |||
| return 1 | |||
| if nonlinearity == "tanh": | |||
| return 5.0 / 3 | |||
| if nonlinearity == "relu": | |||
| return math.sqrt(2.0) | |||
| if nonlinearity == "leaky_relu": | |||
| if param is None: | |||
| negative_slope = 0.01 | |||
| elif ( | |||
| not isinstance(param, bool) | |||
| and isinstance(param, int) | |||
| or isinstance(param, float) | |||
| ): | |||
| # True/False are instances of int, hence check above | |||
| negative_slope = param | |||
| else: | |||
| raise ValueError("negative_slope {} not a valid number".format(param)) | |||
| return math.sqrt(2.0 / (1 + negative_slope ** 2)) | |||
| raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) | |||
| def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: | |||
| """ | |||
| Calculate the fan_in / fan_out value for a given weight tensor. This function assumes | |||
| the input tensor is stored in NCHW format. | |||
| :param tensor: Weight tensor in NCHW format | |||
| """ | |||
| shape = tensor.shape | |||
| ndim = len(shape) | |||
| if ndim < 2: | |||
| raise ValueError( | |||
| "fan_in and fan_out can not be computed for tensor with fewer than 2 " | |||
| "dimensions" | |||
| ) | |||
| if ndim == 2: # Linear | |||
| fan_in = shape[1] | |||
| fan_out = shape[0] | |||
| else: | |||
| num_input_fmaps = shape[1] | |||
| num_output_fmaps = shape[0] | |||
| receptive_field_size = 1 | |||
| if ndim > 2: | |||
| receptive_field_size = reduce(lambda x, y: x * y, shape[2:], 1) | |||
| fan_in = num_input_fmaps * receptive_field_size | |||
| fan_out = num_output_fmaps * receptive_field_size | |||
| return fan_in, fan_out | |||
| def calculate_correct_fan(tensor: Tensor, mode: str) -> float: | |||
| """ | |||
| Calculate the fan_in or fan_out value for a given weight tensor, depending on the given | |||
| ``mode``. | |||
| See :func:`calculate_fan_in_and_fan_out` for details. | |||
| :param tensor: Weight tensor in NCHW format | |||
| :param mode: ``'fan_in'`` or ``'fan_out'`` | |||
| """ | |||
| mode = mode.lower() | |||
| valid_modes = ["fan_in", "fan_out"] | |||
| if mode not in valid_modes: | |||
| raise ValueError( | |||
| "Mode {} not supported, please use one of {}".format(mode, valid_modes) | |||
| ) | |||
| fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||
| return fan_in if mode == "fan_in" else fan_out | |||
| def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: | |||
| r"""Fill ``tensor`` with random values sampled from :math:`\mathcal{U}(-a, a)` | |||
| where | |||
| .. math:: | |||
| a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} | |||
| Also known as Glorot initialization. Detailed information can be retrieved from | |||
| `"Understanding the difficulty of training deep feedforward neural networks" <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| :param gain: Scaling factor for :math:`a`. | |||
| """ | |||
| fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||
| std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) | |||
| a = math.sqrt(3.0) * std | |||
| uniform_(tensor, -a, a) | |||
| def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: | |||
| r"""Fill ``tensor`` with random values sampled from | |||
| :math:`\mathcal{N}(0, \text{std}^2)` where | |||
| .. math:: | |||
| \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} | |||
| Also known as Glorot initialization. Detailed information can be retrieved from | |||
| `"Understanding the difficulty of training deep feedforward neural networks" <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| :param gain: Scaling factor for :math:`std`. | |||
| """ | |||
| fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||
| std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) | |||
| normal_(tensor, 0.0, std) | |||
| def msra_uniform_( | |||
| tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" | |||
| ) -> None: | |||
| r"""Fill ``tensor`` wilth random values sampled from | |||
| :math:`\mathcal{U}(-\text{bound}, \text{bound})` where | |||
| .. math:: | |||
| \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}} | |||
| Detailed information can be retrieved from | |||
| `"Delving deep into rectifiers: Surpassing human-level performance on ImageNet | |||
| classification" <https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| :param a: Optional parameter for calculating gain for leaky_relu. See | |||
| :func:`calculate_gain` for details. | |||
| :param mode: ``'fan_in'`` or ``'fan_out'``, used to select the fan value that | |||
| scales :math:`bound`. See :func:`calculate_fan_in_and_fan_out` for | |||
| details. | |||
| :param nonlinearity: Name of the non-linear function used to calculate :math:`gain`. | |||
| See :func:`calculate_gain` for details. | |||
| """ | |||
| fan = calculate_correct_fan(tensor, mode) | |||
| gain = calculate_gain(nonlinearity, a) | |||
| std = gain / math.sqrt(fan) | |||
| bound = math.sqrt(3.0) * std | |||
| uniform_(tensor, -bound, bound) | |||
| def msra_normal_( | |||
| tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" | |||
| ) -> None: | |||
| r"""Fill ``tensor`` wilth random values sampled from | |||
| :math:`\mathcal{N}(0, \text{std}^2)` where | |||
| .. math:: | |||
| \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}} | |||
| Detailed information can be retrieved from | |||
| `"Delving deep into rectifiers: Surpassing human-level performance on ImageNet | |||
| classification" <https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_. | |||
| :param tensor: An n-dimensional tensor to be initialized | |||
| :param a: Optional parameter for calculating gain for leaky_relu. See | |||
| :func:`calculate_gain` for details. | |||
| :param mode: ``'fan_in'`` or ``'fan_out'``, used to choose the fan value that | |||
| scales :math:`\text{std}`. See :func:`calculate_fan_in_and_fan_out` for | |||
| details. | |||
| :param nonlinearity: Name of the non-linear function used to calculate :math:`gain`. | |||
| See :func:`calculate_gain` for details. | |||
| """ | |||
| fan = calculate_correct_fan(tensor, mode) | |||
| gain = calculate_gain(nonlinearity, a) | |||
| std = gain / math.sqrt(fan) | |||
| normal_(tensor, 0, std) | |||
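| # A short sketch of the He/MSRA initializers above, e.g. for a conv weight | |||
| # feeding a relu (shape hypothetical; numpy assumed available as ``np``): | |||
| w = Tensor(np.zeros((64, 32, 3, 3), dtype=np.float32)) | |||
| msra_uniform_(w, mode="fan_in", nonlinearity="relu") | |||
| msra_normal_(w, a=0.1, mode="fan_out", nonlinearity="leaky_relu") | |||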
| @@ -1,61 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from .. import functional as F | |||
| from ..core import Parameter | |||
| from . import init | |||
| from .module import Module | |||
| class Linear(Module): | |||
| r"""Applies a linear transformation to the input. For instance, if input | |||
| is x, then output y is: | |||
| .. math:: | |||
| y = xW^T + b | |||
| where :math:`y_i= \sum_j W_{ij} x_j + b_i` | |||
| :param in_features: size of each input sample. | |||
| :param out_features: size of each output sample. | |||
| :param bias: If set to ``False``, the layer will not learn an additive bias. | |||
| Default: ``True`` | |||
| """ | |||
| def __init__( | |||
| self, in_features: int, out_features: int, bias: bool = True, **kwargs | |||
| ): | |||
| super().__init__(**kwargs) | |||
| self.out_features = out_features | |||
| self.in_features = in_features | |||
| w_shape = (out_features, in_features) | |||
| self.weight = Parameter(np.zeros(w_shape, dtype=np.float32)) | |||
| self.bias = None | |||
| if bias: | |||
| b_shape = (out_features,) | |||
| self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) | |||
| self.reset_parameters() | |||
| def _get_fanin(self): | |||
| return self.in_features | |||
| def reset_parameters(self) -> None: | |||
| fanin = self._get_fanin() | |||
| std = np.sqrt(1 / fanin) | |||
| init.normal_(self.weight, 0.0, std) | |||
| if self.bias is not None: | |||
| init.zeros_(self.bias) | |||
| def _calc_linear(self, x, weight, bias): | |||
| return F.linear(x, weight, bias) | |||
| def forward(self, x): | |||
| return self._calc_linear(x, self.weight, self.bias) | |||
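| # A minimal sketch of the Linear module above (batch size and feature sizes | |||
| # are hypothetical; ``Tensor`` is assumed importable from ``..core``): | |||
| m = Linear(in_features=128, out_features=10) | |||
| x = Tensor(np.random.randn(4, 128).astype(np.float32)) | |||
| y = m(x)  # shape (4, 10): y = x @ W.T + b | |||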
| @@ -1,507 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABCMeta, abstractmethod | |||
| from collections import OrderedDict | |||
| from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
| import numpy as np | |||
| from .._internal.dtype import is_quantize | |||
| from ..core import Buffer, Parameter, Tensor | |||
| from ..logger import get_logger | |||
| from ..utils.hook import HookHandler | |||
| logger = get_logger(__name__) | |||
| def _expand_structure(key, obj): | |||
| if isinstance(obj, (Tensor, Module)): | |||
| return [(key, obj)] | |||
| elif isinstance(obj, (list, tuple, dict)): | |||
| ret = [] | |||
| if isinstance(obj, dict): | |||
| targets = ((k, obj[k]) for k in sorted(obj)) | |||
| else: | |||
| targets = ((str(k), v) for k, v in enumerate(obj)) | |||
| for k, o in targets: | |||
| sub_ret = _expand_structure(k, o) | |||
| if sub_ret and not isinstance(k, str): | |||
| raise AssertionError( | |||
| "keys for Tensor and Module must be str, error key: {}".format(k) | |||
| ) | |||
| for kt, vt in sub_ret: | |||
| ret.extend([(key + "." + kt, vt)]) | |||
| return ret | |||
| else: | |||
| return [] | |||
| def _is_parameter(obj): | |||
| return isinstance(obj, Parameter) | |||
| def _is_buffer(obj): | |||
| return isinstance(obj, Buffer) | |||
| def _is_module(obj): | |||
| return isinstance(obj, Module) | |||
| class Module(metaclass=ABCMeta): | |||
| """Base Module class. | |||
| """ | |||
| def __init__(self): | |||
| # runtime attributes | |||
| self.training = True | |||
| self.quantize_disabled = False | |||
| # hooks | |||
| self._forward_pre_hooks = OrderedDict() | |||
| self._forward_hooks = OrderedDict() | |||
| @abstractmethod | |||
| def forward(self, inputs): | |||
| pass | |||
| def register_forward_pre_hook(self, hook: Callable) -> HookHandler: | |||
| """Register a hook to handle forward inputs. `hook` should be a function | |||
| Note that `inputs` keyword inputs | |||
| :param hook: a function that receive `module` and `inputs`, then return | |||
| a modified `inputs` or `None`. | |||
| :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook. | |||
| """ | |||
| return HookHandler(self._forward_pre_hooks, hook) | |||
| def register_forward_hook(self, hook: Callable) -> HookHandler: | |||
| """Register a hook to handle forward results. `hook` should be a function that | |||
| receive `module`, `inputs` and `outputs`, then return a modified `outputs` or `None`. | |||
| This method return a handler with :meth:`~.HookHandler.remove` interface to delete the hook. | |||
| """ | |||
| return HookHandler(self._forward_hooks, hook) | |||
| def __call__(self, *inputs, **kwargs): | |||
| for hook in self._forward_pre_hooks.values(): | |||
| modified_inputs = hook(self, inputs) | |||
| if modified_inputs is not None: | |||
| if not isinstance(modified_inputs, tuple): | |||
| modified_inputs = (modified_inputs,) | |||
| inputs = modified_inputs | |||
| outputs = self.forward(*inputs, **kwargs) | |||
| for hook in self._forward_hooks.values(): | |||
| modified_outputs = hook(self, inputs, outputs) | |||
| if modified_outputs is not None: | |||
| outputs = modified_outputs | |||
| return outputs | |||
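| # A small sketch of the hook machinery above: a pre-hook may rewrite the | |||
| # positional inputs, a forward hook may rewrite the outputs, and each | |||
| # registration returns a removable handle (``net`` is hypothetical): | |||
| # | |||
| #     def double_inputs(module, inputs): | |||
| #         return tuple(i * 2 for i in inputs) | |||
| # | |||
| #     handle = net.register_forward_pre_hook(double_inputs) | |||
| #     net.register_forward_hook(lambda m, inp, out: print("got", out)) | |||
| #     handle.remove()  # detach the pre-hook again | |||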
| def _flatten( | |||
| self, | |||
| *, | |||
| recursive: bool = True, | |||
| with_key: bool = False, | |||
| with_parent: bool = False, | |||
| prefix: Optional[str] = None, | |||
| predicate: Callable[[Any], bool] = lambda _: True, | |||
| seen: Optional[Set[int]] = None | |||
| ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: | |||
| """Scans the module object and returns an iterable for the :class:`~.Tensor` | |||
| and :class:`~.Module` attributes that agree with the ``predicate``. For multiple | |||
| calls of this function with same arguments, the order of objects within the | |||
| returned iterable is guaranteed to be identical, as long as all the involved | |||
| module objects' ``__dict__`` does not change thoughout those calls. | |||
| :param recursive: Whether to recursively scan all the submodules. | |||
| :param with_key: Whether to yield keys along with yielded objects. | |||
| :param with_parent: Whether to yield ``self`` along with yielded objects. | |||
| :param prefix: The prefix appended to the yielded keys. | |||
| :param predicate: The predicate function applied to scanned objects. | |||
| :param seen: A set of ids of objects that have already been traversed. | |||
| """ | |||
| if seen is None: | |||
| seen = set([id(self)]) | |||
| module_dict = vars(self) | |||
| _prefix = "" if prefix is None else prefix + "." | |||
| for key in sorted(module_dict): | |||
| for expanded_key, leaf in _expand_structure(key, module_dict[key]): | |||
| leaf_id = id(leaf) | |||
| if leaf_id in seen: | |||
| continue | |||
| seen.add(leaf_id) | |||
| if predicate(leaf): | |||
| if with_key and with_parent: | |||
| yield _prefix + expanded_key, leaf, self | |||
| elif with_key: | |||
| yield _prefix + expanded_key, leaf | |||
| elif with_parent: | |||
| yield leaf, self | |||
| else: | |||
| yield leaf | |||
| if recursive and isinstance(leaf, Module): | |||
| yield from leaf._flatten( | |||
| recursive=recursive, | |||
| with_key=with_key, | |||
| with_parent=with_parent, | |||
| prefix=_prefix + expanded_key if with_key else None, | |||
| predicate=predicate, | |||
| seen=seen, | |||
| ) | |||
| def parameters( | |||
| self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs | |||
| ) -> Iterable[Parameter]: | |||
| r"""Returns an iterable for the :class:`~.Parameter` of the module. | |||
| :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||
| attribute of returned :class:`.Parameter`. ``None`` for no limitation. | |||
| :param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||
| module, else only returns :class:`~.Parameter` that are direct attributes | |||
| of this module. | |||
| """ | |||
| def predicate(obj) -> bool: | |||
| return _is_parameter(obj) and ( | |||
| requires_grad is None or obj.requires_grad == requires_grad | |||
| ) | |||
| yield from self._flatten( | |||
| with_key=False, predicate=predicate, recursive=recursive, **kwargs | |||
| ) | |||
| def named_parameters( | |||
| self, | |||
| requires_grad: Optional[bool] = None, | |||
| prefix: Optional[str] = None, | |||
| recursive: bool = True, | |||
| **kwargs | |||
| ) -> Iterable[Tuple[str, Parameter]]: | |||
| """Returns an iterable for key :class:`~.Parameter` pairs of the module, where | |||
| ``key`` is the dotted path from this module to the :class:`~.Parameter` . | |||
| :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||
| attribute of returned :class:`~.Parameter` . ``None`` for no limitation. | |||
| :param prefix: The prefix prepended to the keys. | |||
| :param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||
| module, else only returns :class:`~.Parameter` that are direct attributes | |||
| of this module. | |||
| """ | |||
| def predicate(obj) -> bool: | |||
| return _is_parameter(obj) and ( | |||
| requires_grad is None or obj.requires_grad == requires_grad | |||
| ) | |||
| yield from self._flatten( | |||
| with_key=True, | |||
| prefix=prefix, | |||
| predicate=predicate, | |||
| recursive=recursive, | |||
| **kwargs, | |||
| ) | |||
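| # Sketch: the dotted keys yielded by ``named_parameters`` mirror attribute | |||
| # paths, e.g. a submodule assigned as ``self.fc`` contributes entries such | |||
| # as ("fc.weight", <Parameter>) (``net`` is hypothetical): | |||
| # | |||
| #     for name, param in net.named_parameters(requires_grad=True): | |||
| #         print(name, param.shape) | |||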
| def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]: | |||
| """Returns an iterable for the :class:`~.Buffer` of the module. | |||
| :param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||
| module, else only returns :class:`~.Buffer` that are direct attributes | |||
| of this module. | |||
| """ | |||
| yield from self._flatten( | |||
| with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs | |||
| ) | |||
| def named_buffers( | |||
| self, prefix: Optional[str] = None, recursive: bool = True, **kwargs | |||
| ) -> Iterable[Tuple[str, Buffer]]: | |||
| """Returns an iterable for key :class:`~.Buffer` pairs of the module, where | |||
| ``key`` is the dotted path from this module to the :class:`~.Buffer` . | |||
| :param prefix: The prefix prepended to the keys. | |||
| :param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||
| module, else only returns :class:`~.Buffer` that are direct attributes | |||
| of this module. | |||
| """ | |||
| yield from self._flatten( | |||
| with_key=True, | |||
| prefix=prefix, | |||
| predicate=_is_buffer, | |||
| recursive=recursive, | |||
| **kwargs, | |||
| ) | |||
| def children(self, **kwargs) -> "Iterable[Module]": | |||
| """Returns an iterable for all the submodules that are direct attributes of this | |||
| module. | |||
| """ | |||
| yield from self._flatten( | |||
| with_key=False, predicate=_is_module, recursive=False, **kwargs | |||
| ) | |||
| def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]": | |||
| """Returns an iterable of key-submodule pairs for all the submodules that are | |||
| direct attributes of this module, where 'key' is the attribute name of | |||
| submodules. | |||
| """ | |||
| yield from self._flatten( | |||
| with_key=True, predicate=_is_module, recursive=False, **kwargs | |||
| ) | |||
| def modules(self, **kwargs) -> "Iterable[Module]": | |||
| """Returns an iterable for all the modules within this module, including itself. | |||
| """ | |||
| if "with_parent" in kwargs and kwargs["with_parent"]: | |||
| yield self, None | |||
| else: | |||
| yield self | |||
| yield from self._flatten(with_key=False, predicate=_is_module, **kwargs) | |||
| def named_modules( | |||
| self, prefix: Optional[str] = None, **kwargs | |||
| ) -> "Iterable[Tuple[str, Module]]": | |||
| """Returns an iterable of key-module pairs for all the modules within this | |||
| module, including itself, where 'key' is the dotted path from this module to the | |||
| submodules. | |||
| :param prefix: The prefix prepended to the path. | |||
| """ | |||
| if "with_parent" in kwargs and kwargs["with_parent"]: | |||
| yield ("" if prefix is None else prefix), self, None | |||
| else: | |||
| yield ("" if prefix is None else prefix), self | |||
| yield from self._flatten( | |||
| with_key=True, prefix=prefix, predicate=_is_module, **kwargs | |||
| ) | |||
| def apply(self, fn: "Callable[[Module], Any]") -> None: | |||
| """Apply function ``fn`` to all the modules within this module, including | |||
| itself. | |||
| :param fn: The function to be applied on modules. | |||
| """ | |||
| for it in self.modules(): | |||
| fn(it) | |||
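| # Sketch: ``apply`` composes naturally with the init helpers, e.g. | |||
| # re-initializing every Linear submodule (``Linear``/``init`` imports and | |||
| # ``net`` are assumed): | |||
| # | |||
| #     def reinit(m): | |||
| #         if isinstance(m, Linear): | |||
| #             init.msra_normal_(m.weight) | |||
| #     net.apply(reinit) | |||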
| def zero_grad(self) -> None: | |||
| """Set all parameters' grads to zero | |||
| """ | |||
| for param in self.parameters(): | |||
| if param.grad is not None: | |||
| param.grad.reset_zero() | |||
| def train(self, mode: bool = True, recursive: bool = True) -> None: | |||
| """Set training mode of all the modules within this module (including itself) to | |||
| ``mode``. This effectively sets the ``training`` attributes of those modules | |||
| to ``mode``, but only has effect on certain modules (e.g. | |||
| :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`) | |||
| :param mode: the training mode to be set on modules. | |||
| :param recursive: whether to recursively call submodules' ``train()``. | |||
| """ | |||
| if not recursive: | |||
| self.training = mode | |||
| return | |||
| def fn(module: Module) -> None: | |||
| module.train(mode, recursive=False) | |||
| self.apply(fn) | |||
| def eval(self) -> None: | |||
| """Set training mode of all the modules within this module (including itself) to | |||
| ``False``. See :meth:`~.Module.train` for details. | |||
| """ | |||
| self.train(False) | |||
| def disable_quantize(self, value=True): | |||
| r""" | |||
| Set the ``quantize_disabled`` attribute of this module and all its submodules | |||
| to ``value``. | |||
| """ | |||
| def fn(module: Module) -> None: | |||
| module.quantize_disabled = value | |||
| self.apply(fn) | |||
| def replace_param( | |||
| self, params: dict, start_pos: int, seen: Optional[Set[int]] = None | |||
| ): | |||
| """Replace module's parameters with `params`, used by :class:`~.ParamPack` to | |||
| speedup multimachine training. | |||
| """ | |||
| offset = 0 | |||
| if seen is None: | |||
| seen = set([id(self)]) | |||
| module_dict = vars(self) | |||
| for key in sorted(module_dict): | |||
| hash_id = id(module_dict[key]) | |||
| if hash_id in seen: | |||
| continue | |||
| seen.add(hash_id) | |||
| if isinstance(module_dict[key], Parameter): | |||
| if start_pos + offset in params: | |||
| assert module_dict[key].shape == params[start_pos + offset].shape | |||
| module_dict[key] = params[start_pos + offset] | |||
| offset += 1 | |||
| if isinstance(module_dict[key], Module): | |||
| offset += module_dict[key].replace_param( | |||
| params, start_pos + offset, seen | |||
| ) | |||
| return offset | |||
| def state_dict(self, rst=None, prefix="", keep_var=False): | |||
| r"""Returns a dictionary containing whole states of the module. | |||
| """ | |||
| def is_state(obj): | |||
| return _is_parameter(obj) or _is_buffer(obj) | |||
| if rst is None: | |||
| rst = OrderedDict() | |||
| for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state): | |||
| assert prefix + k not in rst, "duplicated state: {}".format(k) | |||
| if keep_var: | |||
| rst[prefix + k] = v | |||
| else: | |||
| rst[prefix + k] = v.numpy() | |||
| for k, submodule in self._flatten( | |||
| recursive=False, | |||
| with_key=True, | |||
| predicate=lambda obj: isinstance(obj, Module), | |||
| ): | |||
| submodule.state_dict(rst, prefix + k + ".", keep_var) | |||
| return rst | |||
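| # Sketch: with the default ``keep_var=False`` the state dict holds plain | |||
| # numpy arrays, so it round-trips through ordinary serialization (paths are | |||
| # hypothetical): | |||
| # | |||
| #     snapshot = net.state_dict()  # OrderedDict[str, np.ndarray] | |||
| #     np.savez("checkpoint.npz", **snapshot) | |||
| #     net.load_state_dict(dict(np.load("checkpoint.npz"))) | |||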
| def load_state_dict( | |||
| self, | |||
| state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]], | |||
| strict=True, | |||
| ): | |||
| r"""Load a given dictionary created by :func:`state_dict` into this module. | |||
| If ``strict`` is ``True``, the keys in the given ``state_dict`` must exactly match | |||
| the keys returned by this module's :func:`state_dict`. | |||
| Users can also pass a closure: `Function[key: str, var: Tensor] -> Optional[np.ndarray]` | |||
| as a `state_dict`, in order to handle complex situations. For example, load everything | |||
| except for the final linear classifier: | |||
| .. code-block:: | |||
| state_dict = {...} # Dict[str, np.ndarray] | |||
| model.load_state_dict({ | |||
| k: None if k.startswith('fc') else v | |||
| for k, v in state_dict.items() | |||
| }, strict=False) | |||
| Here, returning `None` means skipping parameter `k`. | |||
| To prevent shape mismatch (e.g. when loading PyTorch weights), we can reshape before loading: | |||
| .. code-block:: | |||
| state_dict = {...} | |||
| def reshape_accordingly(k, v): | |||
| return state_dict[k].reshape(v.shape) | |||
| model.load_state_dict(reshape_accordingly) | |||
| We can also perform inplace re-initialization or pruning: | |||
| .. code-block:: | |||
| def reinit_and_pruning(k, v): | |||
| if 'bias' in k: | |||
| M.init.zeros_(v) | |||
| if 'conv' in k: | |||
| return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32") | |||
| model.load_state_dict(reinit_and_pruning, strict=False) | |||
| """ | |||
| unused = [] | |||
| if isinstance(state_dict, dict): | |||
| unused = state_dict.keys() | |||
| def closure(k, _): # var unused | |||
| return state_dict[k] if k in state_dict else None | |||
| elif callable(state_dict): | |||
| closure = state_dict | |||
| else: | |||
| raise ValueError( | |||
| "`state_dict` must load a dict or callable, got {}".format( | |||
| type(state_dict) | |||
| ) | |||
| ) | |||
| loaded, skipped = self._load_state_dict_with_closure(closure) | |||
| unused = set(unused) - loaded | |||
| if len(unused) != 0: | |||
| if strict: | |||
| raise KeyError( | |||
| "Unused params violate `strict=True`, unused={}".format(unused) | |||
| ) | |||
| else: | |||
| logger.warning( | |||
| "Unused params in `strict=False` mode, unused={}".format(unused) | |||
| ) | |||
| if len(skipped) != 0: | |||
| if strict: | |||
| raise KeyError( | |||
| "Missing params violate `strict=True`, missing={}".format(skipped) | |||
| ) | |||
| else: | |||
| logger.warning( | |||
| "Missing params in `strict=False` mode, missing={}".format(skipped) | |||
| ) | |||
| def _load_state_dict_with_closure(self, closure): | |||
| """Advance state_dict load through callable `closure` whose signature is | |||
| `closure(key: str, var: Tensor) -> Union[np.ndarry, None]` | |||
| """ | |||
| assert callable(closure), "closure must be a function" | |||
| loaded = [] | |||
| skipped = [] | |||
| local_state_dict = self.state_dict(keep_var=True) | |||
| for k, var in local_state_dict.items(): | |||
| to_be_load = closure(k, var) | |||
| if to_be_load is None: | |||
| skipped.append(k) | |||
| continue | |||
| assert isinstance( | |||
| to_be_load, np.ndarray | |||
| ), "closure should return a `np.ndarray`, now `{}` get {}".format( | |||
| k, to_be_load | |||
| ) | |||
| assert ( | |||
| var.shape == to_be_load.shape | |||
| ), "param `{}` shape mismatch, should be {}, get {}".format( | |||
| k, var.shape, to_be_load.shape | |||
| ) | |||
| # For quantized dtype, the initialized dtype | |||
| # scale/zero_points maybe invalid, use pretrained dtype instead. | |||
| if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||
| var.dtype = to_be_load.dtype | |||
| var.set_value(to_be_load) | |||
| loaded.append(k) | |||
| return set(loaded), set(skipped) | |||
| @@ -1,157 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from typing import Callable, Iterable, Optional, Tuple | |||
| import numpy as np | |||
| from .._internal.opr import param_pack_split | |||
| from ..core import Parameter, Tensor | |||
| from .module import Module | |||
| class ParamPack(Module): | |||
| r"""Pack module's parameters by gathering their memory to continuous address. | |||
| Using (device, dtype, requires_grad) as key, for example ('gpu0', float32, True), | |||
| parameters with same key will be packed togather. | |||
| It helps a lot for multimachine training by speeding up allreduce gradients. | |||
| :param model: the module you want to pack parameters. | |||
| :param nr_ignore_first: how many parameters will be unpacked at first. | |||
| :param max_size_per_group: upper bound of packed parameters' size in MB. | |||
| :param max_nr_params_per_group: upper bound of the number of parameters of each group. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| model: Module, | |||
| nr_ignore_first: int = 8, | |||
| max_size_per_group: int = 10, | |||
| max_nr_params_per_group: int = 100, | |||
| group_func: Callable = lambda name, param: 0, | |||
| ): | |||
| super().__init__() | |||
| self._model = model | |||
| self._nr_ignore_first = nr_ignore_first | |||
| self._max_size_per_group = max_size_per_group | |||
| self._max_nr_params_per_group = max_nr_params_per_group | |||
| self._group_func = group_func | |||
| self._grouped_params = [] | |||
| self._packed_params = [] | |||
| params = model.named_parameters() | |||
| self._pack_params(params) | |||
| def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]: | |||
| for param in self._packed_params: | |||
| if requires_grad is None or param.requires_grad == requires_grad: | |||
| yield param | |||
| def named_parameters( | |||
| self, requires_grad: Optional[bool] = None | |||
| ) -> Iterable[Tuple[str, Parameter]]: | |||
| for idx, param in enumerate(self._packed_params): | |||
| if requires_grad is None or param.requires_grad == requires_grad: | |||
| yield "packed_param_" + str(idx), param | |||
| def _pack_params(self, params: Iterable[Tuple[str, Parameter]]): | |||
| groups = collections.defaultdict(list) | |||
| ignored = 0 | |||
| param_id = 0 | |||
| for name, param in params: | |||
| if self._nr_ignore_first > ignored: | |||
| ignored += 1 | |||
| self._grouped_params.append([{"shape": param.shape, "id": param_id}]) | |||
| param.pack_group_key = self._group_func(name, param) | |||
| self._packed_params.append(param) | |||
| else: | |||
| key = ( | |||
| param.dtype, | |||
| param.device, | |||
| param.requires_grad, | |||
| self._group_func(name, param), | |||
| ) | |||
| groups[key].append({"tensor": param, "id": param_id}) | |||
| param_id += 1 | |||
| for (dtype, device, requires_grad, group_key) in groups.keys(): | |||
| dtype_sz = np.dtype(dtype).itemsize | |||
| align = device.mem_align | |||
| if align < dtype_sz: | |||
| align = 1 | |||
| else: | |||
| assert align % dtype_sz == 0 | |||
| align //= dtype_sz | |||
| group = groups[(dtype, device, requires_grad, group_key)] | |||
| while group: | |||
| aligned_pos = [] | |||
| offset = 0 | |||
| params = [] | |||
| idx = 0 | |||
| while idx < len(group): | |||
| param = group[idx] | |||
| assert param["tensor"].device == device | |||
| padding = (align - (offset & (align - 1))) & (align - 1) | |||
| offset += padding | |||
| aligned_pos.append(offset) | |||
| params.append(param) | |||
| offset += int(np.prod(param["tensor"].shape)) | |||
| idx += 1 | |||
| if ( | |||
| offset * dtype_sz >= self._max_size_per_group * 1024 * 1024 | |||
| or idx >= self._max_nr_params_per_group | |||
| ): | |||
| break | |||
| group = group[idx:] | |||
| if idx == 1: | |||
| # ignore param packs with only one item | |||
| params[0]["tensor"].pack_group_key = group_key | |||
| self._packed_params.append(params[0]["tensor"]) | |||
| self._grouped_params.append( | |||
| [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}] | |||
| ) | |||
| continue | |||
| packed_value = np.zeros((offset,), dtype=dtype) | |||
| for param, pos in zip(params, aligned_pos): | |||
| val = param["tensor"].numpy() | |||
| packed_value[pos : pos + val.size] = val.flatten() | |||
| new_param = Parameter( | |||
| value=packed_value, | |||
| device=device, | |||
| dtype=dtype, | |||
| requires_grad=requires_grad, | |||
| ) | |||
| new_param.pack_group_key = group_key | |||
| self._packed_params.append(new_param) | |||
| self._grouped_params.append( | |||
| [{"shape": i["tensor"].shape, "id": i["id"]} for i in params] | |||
| ) | |||
| def forward(self, *args, **kwargs): | |||
| replace_param = dict() | |||
| for i in range(len(self._packed_params)): | |||
| packed_param = self._packed_params[i] | |||
| grouped_params = self._grouped_params[i] | |||
| if len(grouped_params) == 1: | |||
| continue | |||
| split = param_pack_split( | |||
| packed_param._symvar, [i["shape"] for i in grouped_params] | |||
| ) | |||
| split = [ | |||
| Parameter(Tensor(i, requires_grad=packed_param.requires_grad)) | |||
| for i in split | |||
| ] | |||
| for j in range(len(split)): | |||
| replace_param[grouped_params[j]["id"]] = split[j] | |||
| self._model.replace_param(replace_param, 0) | |||
| return self._model.forward(*args, **kwargs) | |||
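| # Sketch of the intended usage: wrap a model so gradients of packed parameters | |||
| # can be all-reduced in fewer, larger chunks (``net`` and ``inp`` hypothetical): | |||
| packed = ParamPack(net, nr_ignore_first=0, max_size_per_group=10) | |||
| out = packed(inp)  # forwards to net with packed parameters spliced back in | |||
| params = list(packed.parameters())  # hand the packed parameters to the optimizer | |||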
| @@ -1,80 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import abstractmethod | |||
| from typing import Tuple, Union | |||
| from ..functional import avg_pool2d, max_pool2d | |||
| from .module import Module | |||
| class _PoolNd(Module): | |||
| def __init__( | |||
| self, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = None, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| ): | |||
| super(_PoolNd, self).__init__() | |||
| self.kernel_size = kernel_size | |||
| self.stride = stride or kernel_size | |||
| self.padding = padding | |||
| @abstractmethod | |||
| def forward(self, inp): | |||
| pass | |||
| class MaxPool2d(_PoolNd): | |||
| r"""Applies a 2D max pooling over an input. | |||
| For instance, given an input of the size :math:`(N, C, H, W)` and | |||
| :attr:`kernel_size` :math:`(kH, kW)`, this layer generates the output of | |||
| the size :math:`(N, C, H_{out}, W_{out})` through a process described as: | |||
| .. math:: | |||
| \begin{aligned} | |||
| out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} | |||
| \text{input}(N_i, C_j, \text{stride[0]} \times h + m, | |||
| \text{stride[1]} \times w + n) | |||
| \end{aligned} | |||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on | |||
| both sides for :attr:`padding` number of points. | |||
| :param kernel_size: the size of the window to take a max over. | |||
| :param stride: the stride of the window. Default value is ``kernel_size``. | |||
| :param padding: implicit zero padding to be added on both sides. | |||
| """ | |||
| def forward(self, inp): | |||
| return max_pool2d(inp, self.kernel_size, self.stride, self.padding) | |||
| class AvgPool2d(_PoolNd): | |||
| r"""Applies a 2D average pooling over an input. | |||
| For instance, given an input of the size :math:`(N, C, H, W)` and | |||
| :attr:`kernel_size` :math:`(kH, kW)`, this layer generates the output of | |||
| the size :math:`(N, C, H_{out}, W_{out})` through a process described as: | |||
| .. math:: | |||
| out(N_i, C_j, h, w) = \frac{1}{kH \times kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} | |||
| \text{input}(N_i, C_j, \text{stride[0]} \times h + m, \text{stride[1]} \times w + n) | |||
| If :attr:`padding` is non-zero, then the input is implicitly zero-padded on | |||
| both sides for :attr:`padding` number of points. | |||
| :param kernel_size: the size of the window. | |||
| :param stride: the stride of the window. Default value is ``kernel_size``. | |||
| :param padding: implicit zero padding to be added on both sides. | |||
| """ | |||
| def forward(self, inp): | |||
| return avg_pool2d(inp, self.kernel_size, self.stride, self.padding) | |||
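| # Sketch: both pooling modules share the _PoolNd constructor; stride defaults | |||
| # to the kernel size (input tensor ``x`` is hypothetical): | |||
| pool = MaxPool2d(kernel_size=2) | |||
| y = pool(x)  # (N, C, H, W) -> (N, C, H // 2, W // 2) | |||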
| @@ -1,9 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .pytorch import PyTorchModule | |||
| @@ -1,451 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import copy | |||
| import functools | |||
| import os | |||
| from typing import Any, Callable, List, Optional, Tuple | |||
| import torch | |||
| from torch.utils.cpp_extension import load as load_torch_extension | |||
| import megengine._internal as mgb | |||
| from megengine._internal import CompGraph | |||
| from megengine._internal.mgb import CompGraphCallbackValueProxy | |||
| from ...core import Parameter, Tensor, get_default_device | |||
| from ..module import Module | |||
| from .utils import device_to_torch_device, torch_dtype_to_numpy_dtype | |||
| # A global dict to map opr during graph copy | |||
| _copy_dict = {} | |||
| @functools.lru_cache(None) | |||
| def _get_torch_mem_fwd_lib(): | |||
| source_file = os.path.join(os.path.dirname(__file__), "torch_mem_fwd.cpp") | |||
| return load_torch_extension( | |||
| "torch_mem_fwd", | |||
| [source_file], | |||
| extra_include_paths=[mgb.config.get_include_path()], | |||
| ) | |||
| def inp_mem_fwd(pubapi_dev_tensor_ptr: int) -> torch.Tensor: | |||
| """Forward a MegBrain tensor to torch tensor | |||
| :param pubapi_dev_tensor_ptr: pointer to MegBrain tensor | |||
| """ | |||
| return _get_torch_mem_fwd_lib().inp_mem_fwd(pubapi_dev_tensor_ptr) | |||
| def oup_mem_fwd( | |||
| pubapi_dev_tensor_ptr: int, tensor: torch.Tensor, keep_data_ptr: bool = True | |||
| ) -> None: | |||
| """Forward a torch tensor to a contiguous MegBrain tensor | |||
| :param pubapi_dev_tensor_ptr: Pointer to the MegBrain tensor | |||
| :param tensor: The input torch tensor | |||
| :param keep_data_ptr: if True, no memory copy is allowed here, so the | |||
| input torch tensor itself must already be contiguous. | |||
| Defaults to True. | |||
| """ | |||
| _get_torch_mem_fwd_lib().oup_mem_fwd(pubapi_dev_tensor_ptr, tensor, keep_data_ptr) | |||
| def torch_param_to_mge( | |||
| name: str, param: torch.nn.Parameter, device, comp_graph: CompGraph | |||
| ) -> Parameter: | |||
| """Convert a torch parameter to a megengine parameter | |||
| :param name: parameter name | |||
| :param param: torch parameter | |||
| :param device: the device on which the megengine parameter is, | |||
| should be physically the same as the one on torch parameter | |||
| :param comp_graph: the owner graph of megengine parameter | |||
| :return: megengine parameter | |||
| """ | |||
| assert isinstance(param, torch.nn.Parameter) | |||
| dtype = torch_dtype_to_numpy_dtype(param.dtype) | |||
| mge_param = Parameter(None, dtype=dtype) | |||
| shared_nd = mge_param._Tensor__val | |||
| oup_mem_fwd(shared_nd.pubapi_dev_tensor_ptr, param.data, True) | |||
| return mge_param | |||
| class _PyTorchSubgraphGradOpr(mgb.craniotome.CraniotomeBase): | |||
| __nr_inputs__ = None | |||
| __nr_outputs__ = None | |||
| __allow_duplicate__ = False | |||
| __disable_sys_mem_alloc__ = True | |||
| __is_dynamic_output_shape__ = True | |||
| _forward_opr = None # type: PyTorchSubgraphImplOpr | |||
| _shape_infer_func = None | |||
| _condensed_out_grad_idx = None # type: List[Optional[int]] | |||
| _forward_input_cnt = None | |||
| _forward_output_cnt = None | |||
| _output_grad_cnt = None | |||
| _param_cnt = None | |||
| def setup( | |||
| self, forward_opr, condensed_out_grad_idx: List[Optional[int]], infer_shape=None | |||
| ): | |||
| self._forward_opr = forward_opr | |||
| self._forward_input_cnt = forward_opr.input_cnt | |||
| self._forward_output_cnt = forward_opr.output_cnt | |||
| self._param_cnt = forward_opr.param_cnt | |||
| self._output_grad_cnt = sum([idx is not None for idx in condensed_out_grad_idx]) | |||
| self.__nr_inputs__ = ( | |||
| self._forward_input_cnt | |||
| + self._param_cnt | |||
| + self._forward_output_cnt | |||
| + self._output_grad_cnt | |||
| ) | |||
| self.__nr_outputs__ = self._forward_input_cnt + self._param_cnt | |||
| self._condensed_out_grad_idx = condensed_out_grad_idx | |||
| self._shape_infer_func = infer_shape | |||
| if infer_shape is not None: | |||
| type(self).__is_dynamic_output_shape__ = False | |||
| def execute( | |||
| self, | |||
| inputs: Tuple[CompGraphCallbackValueProxy, ...], | |||
| outputs: Tuple[mgb.SharedND, ...], | |||
| ): | |||
| # re-run forward if its cached outputs are unavailable, e.g. when forward | |||
| # was not executed in this iteration | |||
| if self._forward_opr._last_forward_outputs is None: | |||
| self._forward_opr.execute(inputs[: self.__nr_outputs__], None) | |||
| out_grads = [ | |||
| inp_mem_fwd(inputs[idx].pubapi_dev_tensor_ptr) if idx is not None else None | |||
| for idx in self._condensed_out_grad_idx | |||
| ] | |||
| grads = torch.autograd.grad( | |||
| self._forward_opr._last_forward_outputs, | |||
| self._forward_opr._last_forward_inputs | |||
| + self._forward_opr._last_forward_params, | |||
| out_grads, # type: ignore | |||
| only_inputs=True, | |||
| allow_unused=True, | |||
| ) | |||
| for ovar, oten in zip(outputs, grads): | |||
| oup_mem_fwd(ovar.pubapi_dev_tensor_ptr, oten) | |||
| def grad(self, wrt_idx, inputs, outputs, out_grad): | |||
| raise NotImplementedError("Apply grad to a grad opr is not supported") | |||
| def infer_shape(self, inp_shapes): | |||
| if callable(self._shape_infer_func): | |||
| return self._shape_infer_func(inp_shapes) | |||
| raise NotImplementedError( | |||
| "No shape inference function specified on PyTorchSubgraphImplOpr" | |||
| ) | |||
| def copy(self): | |||
| ret = type(self)() | |||
| d0 = self.__dict__.copy() | |||
| d0.pop("this") | |||
| d0.pop("_forward_opr") | |||
| later_copy = self._forward_opr in _copy_dict | |||
| if later_copy: | |||
| assert len(_copy_dict) == 1 | |||
| forward_opr_copy = _copy_dict[self._forward_opr] | |||
| else: | |||
| forward_opr_copy = self._forward_opr | |||
| ret.__dict__["_forward_opr"] = forward_opr_copy | |||
| ret.__dict__.update(copy.deepcopy(d0)) | |||
| _copy_dict[self] = ret | |||
| if later_copy: | |||
| forward_opr_copy._grad_opr = ret | |||
| _copy_dict.clear() | |||
| return ret | |||
| class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase): | |||
| # pylint: disable=abstract-method | |||
| """This is a pytorch module wrapper to operator""" | |||
| __nr_inputs__ = None # type: int | |||
| __nr_outputs__ = None # type: int | |||
| __allow_duplicate__ = False | |||
| __disable_sys_mem_alloc__ = True | |||
| __is_dynamic_output_shape__ = True | |||
| _grad_opr = None | |||
| _func = None # type: Callable[[Any], Any] | |||
| input_cnt = None # type: int | |||
| output_cnt = None # type: int | |||
| param_cnt = None # type: int | |||
| _shape_infer_func = None | |||
| _last_forward_inputs = None | |||
| _last_forward_outputs = None # type: List[torch.Tensor] | |||
| _last_forward_params = None # type: List[torch.Tensor] | |||
| def setup(self, *, input_cnt, output_cnt, func, params, infer_shape=None): | |||
| """Setup the operator by accepted kwargs | |||
| :param input_cnt: input count of torch module | |||
| :param output_cnt: output count of torch module | |||
| :param func: a callable object that accepts inputs and returns outputs, | |||
| usually the torch module itself | |||
| :param params: parameters of the torch module | |||
| :param infer_shape: a callable infers output shapes from input shapes, | |||
| defaults to None | |||
| """ | |||
| param_cnt = len(params) | |||
| self.input_cnt = input_cnt | |||
| self.output_cnt = output_cnt | |||
| self.param_cnt = param_cnt | |||
| self.__nr_inputs__ = input_cnt + param_cnt | |||
| self.__nr_outputs__ = output_cnt | |||
| self._func = func | |||
| self._shape_infer_func = infer_shape | |||
| if infer_shape is not None: | |||
| type(self).__is_dynamic_output_shape__ = False | |||
| self._last_forward_params = params | |||
| def execute( | |||
| self, | |||
| inputs: Tuple[CompGraphCallbackValueProxy, ...], | |||
| outputs: Optional[Tuple[mgb.SharedND, ...]], | |||
| ): | |||
| """execute the operator, read values from *inputs*, | |||
| forward them to torch tensor and do execution by self.func | |||
| and forward results to outputs | |||
| :param inputs: values for each input var | |||
| :param outputs: values for each output var | |||
| """ | |||
| input_value_proxys = inputs[: self.input_cnt] | |||
| input_torch_tensors = [ | |||
| inp_mem_fwd(ivar.pubapi_dev_tensor_ptr).requires_grad_() | |||
| for ivar in input_value_proxys | |||
| ] | |||
| output_torch_tensors = self._func(*input_torch_tensors) | |||
| if isinstance(output_torch_tensors, torch.Tensor): | |||
| output_torch_tensors = [output_torch_tensors] | |||
| # `execute` may be called in _PyTorchSubgraphGradOpr with None as outputs | |||
| if outputs: | |||
| for ovar, oten in zip(outputs, output_torch_tensors): | |||
| oup_mem_fwd(ovar.pubapi_dev_tensor_ptr, oten) | |||
| # Retain input / output tensors for backward | |||
| self._last_forward_inputs = input_torch_tensors | |||
| self._last_forward_outputs = output_torch_tensors | |||
| def grad( | |||
| self, | |||
| wrt_idx, | |||
| inputs: Tuple[mgb.SymbolVar, ...], | |||
| outputs: Tuple[mgb.SymbolVar, ...], | |||
| out_grads: Tuple[mgb.SymbolVar, ...], | |||
| ): | |||
| """generate a grad opr which calculates grad by torch.autograd.grad and cache it | |||
| :param wrt_idx: the input var with respect to which the gradient should | |||
| be computed | |||
| :param inputs: operator inputs | |||
| :param outputs: operator outputs | |||
| :param out_grads: gradients of each output var | |||
| :return: an initialized grad opr | |||
| """ | |||
| if self._grad_opr is None: | |||
| condensed_out_grad = [] | |||
| condensed_out_grad_idx = [] # type: List[Optional[int]] | |||
| idx = self.__nr_inputs__ + len(outputs) | |||
| for out_grad in out_grads: | |||
| if out_grad is None: | |||
| condensed_out_grad_idx.append(None) | |||
| else: | |||
| condensed_out_grad.append(out_grad) | |||
| condensed_out_grad_idx.append(idx) | |||
| idx += 1 | |||
| self._grad_opr = _PyTorchSubgraphGradOpr.make( | |||
| *(inputs + outputs + tuple(condensed_out_grad)), | |||
| forward_opr=self, | |||
| condensed_out_grad_idx=condensed_out_grad_idx, | |||
| ) | |||
| return self._grad_opr | |||
| def infer_shape(self, inp_shapes): | |||
| """infer output shape from input shapes | |||
| :param inp_shapes: input shapes as tuple | |||
| :return: output shapes | |||
| """ | |||
| if callable(self._shape_infer_func): | |||
| return self._shape_infer_func(inp_shapes) | |||
| raise NotImplementedError( | |||
| "No shape inference function specified on PyTorchSubgraphImplOpr" | |||
| ) | |||
| def copy(self): | |||
| ret = type(self)() | |||
| d0 = self.__dict__.copy() | |||
| d0.pop("this") | |||
| ret.__dict__["_last_forward_inputs"] = d0.pop("_last_forward_inputs") | |||
| ret.__dict__["_last_forward_outputs"] = d0.pop("_last_forward_outputs") | |||
| ret.__dict__["_last_forward_params"] = d0.pop("_last_forward_params") | |||
| ret.__dict__["_func"] = d0.pop("_func") | |||
| d0.pop("_grad_opr") | |||
| later_copy = self._grad_opr in _copy_dict | |||
| if later_copy: | |||
| assert len(_copy_dict) == 1 | |||
| grad_opr_copy = _copy_dict[self._grad_opr] | |||
| else: | |||
| grad_opr_copy = self._grad_opr | |||
| ret.__dict__["_grad_opr"] = grad_opr_copy | |||
| ret.__dict__.update(copy.deepcopy(d0)) | |||
| _copy_dict[self] = ret | |||
| if later_copy: | |||
| grad_opr_copy._forward_opr = ret | |||
| _copy_dict.clear() | |||
| return ret | |||
| class PyTorchModule(Module): | |||
| """Wrap a pytorch module as megengine module | |||
| :param torch_module: torch module to be wrapped | |||
| :param device: target device this module would be in | |||
| :param output_cnt: output count of this module | |||
| :param input_shape: input shape inferrer | |||
| :param comp_graph: target comp_graph on which this module would be in | |||
| """ | |||
| __torch_module = None # type: torch.nn.Module | |||
| __output_cnt = None | |||
| __infer_shape = None | |||
| __comp_graph = None | |||
| __device = None | |||
| _torch_params = None | |||
| _param_inputs = None | |||
| _name_param_list = None # type: List[Tuple[str, Parameter]] | |||
| def __init__( | |||
| self, | |||
| torch_module, | |||
| device=None, | |||
| output_cnt=1, | |||
| *, | |||
| infer_shape=None, | |||
| comp_graph=None | |||
| ): | |||
| super().__init__() | |||
| if not isinstance(torch_module, torch.nn.Module): | |||
| raise TypeError( | |||
| "torch_module should either be an instance of torch.nn.Module " | |||
| "or its subclass" | |||
| ) | |||
| self.__torch_module = torch_module | |||
| if not isinstance(output_cnt, int): | |||
| raise TypeError("output_cnt must be int") | |||
| if output_cnt <= 0: | |||
| raise ValueError("output_cnt must be greater than zero") | |||
| self.__output_cnt = output_cnt | |||
| if infer_shape and not callable(infer_shape): | |||
| raise TypeError("infer_shape should either be None or a callable object") | |||
| self.__infer_shape = infer_shape | |||
| if comp_graph and not isinstance(comp_graph, mgb.CompGraph): | |||
| raise TypeError("comp_graph shoud eighter be None or a mgb.CompGraph") | |||
| self.__comp_graph = comp_graph | |||
| self._torch_params = [] | |||
| self._param_inputs = [] | |||
| self._name_param_list = [] | |||
| if device is None: | |||
| device = get_default_device() | |||
| if isinstance(device, str): | |||
| device = mgb.comp_node(device) | |||
| self.device = device | |||
| def init_params(self): | |||
| """forward torch parameters to megengine parameters and store, | |||
| would be called in constructor and setter of device | |||
| """ | |||
| self._torch_params = [] | |||
| self._param_inputs = [] | |||
| self._name_param_list = [] | |||
| for name, torch_param in self.__torch_module.named_parameters(recurse=True): | |||
| formated_name = "_torch_{}_{}".format(id(self.__torch_module), name) | |||
| mge_param = torch_param_to_mge( | |||
| formated_name, torch_param, self.device, self.__comp_graph | |||
| ) | |||
| self._param_inputs.append(mge_param) | |||
| self._torch_params.append(torch_param) | |||
| self._name_param_list.append((name, mge_param)) | |||
| def get_param_by_name(self, param_name: str) -> Parameter: | |||
| """find parameter by its name | |||
| :param param_name: name of parameter | |||
| :return: the parameter | |||
| """ | |||
| for name, param in self._name_param_list: | |||
| if param_name == name: | |||
| return param | |||
| raise KeyError("Cannot find param: {}".format(param_name)) | |||
| def forward(self, *inputs): | |||
| """apply the module on given inputs | |||
| :return: output vars | |||
| """ | |||
| param_inputs = [param._symvar for param in self._param_inputs] | |||
| inputs = [tensor._symvar for tensor in list(inputs)] + param_inputs | |||
| out = PyTorchSubgraphImplOpr.make( | |||
| *inputs, | |||
| input_cnt=len(inputs) - len(param_inputs), | |||
| output_cnt=self.__output_cnt, | |||
| func=self.__torch_module.forward, | |||
| params=self._torch_params, | |||
| infer_shape=self.__infer_shape, | |||
| ) | |||
| if isinstance(out, mgb.SymbolVar): | |||
| return Tensor(out) | |||
| assert isinstance(out, collections.abc.Iterable) | |||
| return [Tensor(sym) for sym in out] | |||
| def get_device(self): | |||
| """get the device this module belongs to""" | |||
| return self.__device | |||
| def set_device(self, device: mgb.CompNode): | |||
| """set the device and move torch module to corresponding device""" | |||
| torch_device = device_to_torch_device(device) | |||
| self.__torch_module.to(device=torch_device) | |||
| self.__device = device | |||
| self.init_params() | |||
| device = property(get_device, set_device) | |||
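| # Sketch of wrapping a torch module (the torchvision model and shapes are | |||
| # hypothetical; ``infer_shape`` lets megengine do static shape inference): | |||
| resnet = torchvision.models.resnet18()  # assumes torchvision is installed | |||
| wrapped = PyTorchModule( | |||
| resnet, output_cnt=1, infer_shape=lambda shapes: [(shapes[0][0], 1000)] | |||
| ) | |||
| logits = wrapped(mge_input)  # mge_input is a hypothetical megengine Tensor | |||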
| @@ -1,148 +0,0 @@ | |||
| /** | |||
| * \file python_module/megengine/module/pytorch/torch_mem_fwd.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "torch/extension.h" | |||
| #include "megbrain_pubapi.h" | |||
| using MGBTensor = mgb::pubapi::DeviceTensor; | |||
| torch::Tensor mgb_to_torch(const MGBTensor *src) { | |||
| mgb::pubapi::CallbackOnce deleter; | |||
| void* tensor_raw_ptr; | |||
| src->forward_to(&tensor_raw_ptr, &deleter); | |||
| auto deleter_wrap = [deleter](void*) mutable { | |||
| deleter.consume(); | |||
| }; | |||
| // TODO: support non-contiguous layout | |||
| std::vector<int64_t> sizes; | |||
| for (size_t i = 0; i < src->desc.ndim; ++ i) { | |||
| sizes.push_back(src->desc.shape[i]); | |||
| } | |||
| torch::TensorOptions options; | |||
| switch (src->desc.dtype) { | |||
| #define map_dtype(mgb_dtype, torch_dtype) \ | |||
| case MGBTensor::DataType::mgb_dtype: \ | |||
| options = options.dtype(caffe2::TypeMeta::Make<torch_dtype>()); \ | |||
| break; | |||
| map_dtype(FLOAT32, float); | |||
| map_dtype(FLOAT16, torch::Half); | |||
| map_dtype(INT32, int); | |||
| map_dtype(INT16, int16_t); | |||
| map_dtype(INT8, int8_t); | |||
| map_dtype(UINT8, uint8_t); | |||
| #undef map_dtype | |||
| default: | |||
| throw std::runtime_error("bad case for data type."); | |||
| } | |||
| // TODO: Maybe we should impl copy on different devices? | |||
| switch (src->desc.type) { | |||
| case MGBTensor::Type::CUDA: { | |||
| int device_id = src->desc.cuda_ctx.device; | |||
| if (device_id >= 0) { | |||
| options = options.device(torch::DeviceType::CUDA, device_id); | |||
| } else { | |||
| throw std::runtime_error("bad case for device(cuda) id."); | |||
| } | |||
| // TODO: consider cuda synchronization here | |||
| // Maybe all tasks issued on cuda_ctx(device, stream) should be done? | |||
| break; | |||
| } | |||
| case MGBTensor::Type::CPU: | |||
| options = options.device(torch::DeviceType::CPU); | |||
| // Torch's API are all synchronous. | |||
| src->sync(); | |||
| break; | |||
| default: | |||
| throw std::runtime_error("bad case for device type."); | |||
| } | |||
| auto tensor = torch::from_blob(tensor_raw_ptr, sizes, deleter_wrap, options); | |||
| return tensor; | |||
| } | |||
| void torch_to_mgb(MGBTensor* dst, torch::Tensor src) { | |||
| MGBTensor::Desc desc; | |||
| desc.dev_ptr = src.data_ptr(); | |||
| // src is contiguous torch tensor here, so no strides needed | |||
| std::vector<size_t> shape; | |||
| // desc.shape is the pointer to a size array used to construct | |||
| // an inner-mgb tensor; it must stay valid until the call to | |||
| // forward_other_memory returns | |||
| for (auto &&i : src.sizes()) { | |||
| shape.push_back(i); | |||
| } | |||
| desc.shape = shape.data(); | |||
| desc.ndim = shape.size(); | |||
| switch (src.scalar_type()) { | |||
| #define map_dtype(mgb_dtype, torch_dtype) \ | |||
| case torch::ScalarType::torch_dtype: \ | |||
| desc.dtype = MGBTensor::DataType::mgb_dtype; \ | |||
| break; | |||
| map_dtype(FLOAT32, Float); | |||
| map_dtype(FLOAT16, Half); | |||
| map_dtype(INT32, Int); | |||
| map_dtype(INT16, Short); | |||
| map_dtype(INT8, Char); | |||
| map_dtype(UINT8, Byte); | |||
| #undef map_dtype | |||
| default: | |||
| throw std::runtime_error("bad case for data type."); | |||
| } | |||
| // TODO: cuda setting and synchronization like mgb_to_torch | |||
| if (src.device().type() == torch::DeviceType::CUDA) { | |||
| desc.type = MGBTensor::Type::CUDA; | |||
| desc.cuda_ctx.device = src.get_device(); | |||
| desc.cuda_ctx.stream = nullptr; | |||
| } else { | |||
| assert(src.device().type() == torch::DeviceType::CPU); | |||
| desc.type = MGBTensor::Type::CPU; | |||
| } | |||
| mgb::pubapi::CallbackOnce deleter; | |||
| deleter.user_data = new torch::Tensor(src); | |||
| deleter.fptr = [](void* ptr) { | |||
| delete static_cast<torch::Tensor*>(ptr); | |||
| }; | |||
| dst->forward_other_memory(desc, deleter); | |||
| } | |||
| torch::Tensor inp_mem_fwd(uintptr_t dv_ptr) { | |||
| // construct torch Tensor from mgb DeviceTensor stored in dv_ptr. | |||
| return mgb_to_torch(reinterpret_cast<MGBTensor*>(dv_ptr)); | |||
| } | |||
| void oup_mem_fwd(uintptr_t dv_ptr, torch::Tensor src, | |||
| bool keep_data_ptr=false) { | |||
| // forward the storage of a torch Tensor to an mgb DeviceTensor | |||
| // keep_data_ptr: set to true to guarantee that the data_ptr of \p src is | |||
| // forwarded to megbrain as-is; otherwise src may be copied into a new | |||
| // contiguous tensor storage (contiguous() returns src itself if it is | |||
| // already contiguous) | |||
| auto src_contig = src.contiguous(); | |||
| if (keep_data_ptr && src_contig.data_ptr() != src.data_ptr()) { | |||
| throw std::runtime_error("should keep tensor data ptr, but it changed"); | |||
| } | |||
| torch_to_mgb(reinterpret_cast<MGBTensor*>(dv_ptr), src_contig); | |||
| } | |||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |||
| m.def("inp_mem_fwd", &inp_mem_fwd, "Forward mgb DeviceTensor ptr into torch Tensor as network input."); | |||
| m.def("oup_mem_fwd", &oup_mem_fwd, "Forward torch network Tensor to corresponding mgb VarNode.", | |||
| py::arg("dv_ptr"), py::arg("src"), py::arg("keep_data_ptr") = false); | |||
| } | |||
| @@ -1,67 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| import torch | |||
| import megengine._internal as mgb | |||
| _TORCH_NUMPY_MAPPING = { | |||
| torch.float16: np.float16, | |||
| torch.float32: np.float32, | |||
| torch.float64: np.float64, | |||
| torch.int8: np.int8, | |||
| torch.int16: np.int16, | |||
| torch.int32: np.int32, | |||
| } | |||
| def torch_dtype_to_numpy_dtype(torch_dtype: torch.dtype): | |||
| """map torch dtype to numpy dtype | |||
| :param torch_dtype: torch dtype | |||
| :return: numpy dtype | |||
| """ | |||
| if not isinstance(torch_dtype, torch.dtype): | |||
| raise TypeError("Argument `torch_dtype` should be an instance of torch.dtype") | |||
| if torch_dtype not in _TORCH_NUMPY_MAPPING: | |||
| raise ValueError("Unknown PyTorch dtype: {}".format(torch_dtype)) | |||
| return _TORCH_NUMPY_MAPPING[torch_dtype] | |||
| def torch_device_to_device(device: torch.device): | |||
| """map torch device to device | |||
| :param device: torch device | |||
| :return: device | |||
| """ | |||
| if not isinstance(device, torch.device): | |||
| raise TypeError("Argument `device` should be an instance of torch.device") | |||
| index = device.index | |||
| if index is None: | |||
| index = "x" | |||
| if device.type == "cpu": | |||
| return "cpu{}".format(index) | |||
| elif device.type == "cuda": | |||
| return "gpu{}".format(index) | |||
| raise ValueError("Unknown PyTorch device: {}".format(device)) | |||
| def device_to_torch_device(device: mgb.CompNode): | |||
| """map device to torch device | |||
| :param device: megbrain compute node | |||
| :return: corresponding torch device | |||
| """ | |||
| t, d, _ = device.locator_physical | |||
| if t == "CUDA": | |||
| return torch.device("cuda", d) | |||
| elif t == "CPU": | |||
| return torch.device("cpu", d) | |||
| else: | |||
| raise Exception("Unsupported device type: {}".format(t)) | |||
| @@ -1,14 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvRelu2d | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| from .elemwise import Elemwise | |||
| from .linear import Linear | |||
| from .module import QATModule | |||
| from .quant_dequant import DequantStub, QuantStub | |||