You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

setup.py 7.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. #!/usr/bin/env python
  2. # Copyright (c) OpenMMLab. All rights reserved.
  3. import os
  4. import os.path as osp
  5. import shutil
  6. import sys
  7. import warnings
  8. from setuptools import find_packages, setup
  9. import torch
  10. from torch.utils.cpp_extension import (BuildExtension, CppExtension,
  11. CUDAExtension)
  12. def readme():
  13. with open('README.md', encoding='utf-8') as f:
  14. content = f.read()
  15. return content
  16. version_file = 'mmdet/version.py'
  17. def get_version():
  18. with open(version_file, 'r') as f:
  19. exec(compile(f.read(), version_file, 'exec'))
  20. return locals()['__version__']
  21. def make_cuda_ext(name, module, sources, sources_cuda=[]):
  22. define_macros = []
  23. extra_compile_args = {'cxx': []}
  24. if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
  25. define_macros += [('WITH_CUDA', None)]
  26. extension = CUDAExtension
  27. extra_compile_args['nvcc'] = [
  28. '-D__CUDA_NO_HALF_OPERATORS__',
  29. '-D__CUDA_NO_HALF_CONVERSIONS__',
  30. '-D__CUDA_NO_HALF2_OPERATORS__',
  31. ]
  32. sources += sources_cuda
  33. else:
  34. print(f'Compiling {name} without CUDA')
  35. extension = CppExtension
  36. return extension(
  37. name=f'{module}.{name}',
  38. sources=[os.path.join(*module.split('.'), p) for p in sources],
  39. define_macros=define_macros,
  40. extra_compile_args=extra_compile_args)
  41. def parse_requirements(fname='requirements.txt', with_version=True):
  42. """Parse the package dependencies listed in a requirements file but strips
  43. specific versioning information.
  44. Args:
  45. fname (str): path to requirements file
  46. with_version (bool, default=False): if True include version specs
  47. Returns:
  48. List[str]: list of requirements items
  49. CommandLine:
  50. python -c "import setup; print(setup.parse_requirements())"
  51. """
  52. import sys
  53. from os.path import exists
  54. import re
  55. require_fpath = fname
  56. def parse_line(line):
  57. """Parse information from a line in a requirements text file."""
  58. if line.startswith('-r '):
  59. # Allow specifying requirements in other files
  60. target = line.split(' ')[1]
  61. for info in parse_require_file(target):
  62. yield info
  63. else:
  64. info = {'line': line}
  65. if line.startswith('-e '):
  66. info['package'] = line.split('#egg=')[1]
  67. elif '@git+' in line:
  68. info['package'] = line
  69. else:
  70. # Remove versioning from the package
  71. pat = '(' + '|'.join(['>=', '==', '>']) + ')'
  72. parts = re.split(pat, line, maxsplit=1)
  73. parts = [p.strip() for p in parts]
  74. info['package'] = parts[0]
  75. if len(parts) > 1:
  76. op, rest = parts[1:]
  77. if ';' in rest:
  78. # Handle platform specific dependencies
  79. # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
  80. version, platform_deps = map(str.strip,
  81. rest.split(';'))
  82. info['platform_deps'] = platform_deps
  83. else:
  84. version = rest # NOQA
  85. info['version'] = (op, version)
  86. yield info
  87. def parse_require_file(fpath):
  88. with open(fpath, 'r') as f:
  89. for line in f.readlines():
  90. line = line.strip()
  91. if line and not line.startswith('#'):
  92. for info in parse_line(line):
  93. yield info
  94. def gen_packages_items():
  95. if exists(require_fpath):
  96. for info in parse_require_file(require_fpath):
  97. parts = [info['package']]
  98. if with_version and 'version' in info:
  99. parts.extend(info['version'])
  100. if not sys.version.startswith('3.4'):
  101. # apparently package_deps are broken in 3.4
  102. platform_deps = info.get('platform_deps')
  103. if platform_deps is not None:
  104. parts.append(';' + platform_deps)
  105. item = ''.join(parts)
  106. yield item
  107. packages = list(gen_packages_items())
  108. return packages
  109. def add_mim_extension():
  110. """Add extra files that are required to support MIM into the package.
  111. These files will be added by creating a symlink to the originals if the
  112. package is installed in `editable` mode (e.g. pip install -e .), or by
  113. copying from the originals otherwise.
  114. """
  115. # parse installment mode
  116. if 'develop' in sys.argv:
  117. # installed by `pip install -e .`
  118. mode = 'symlink'
  119. elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv:
  120. # installed by `pip install .`
  121. # or create source distribution by `python setup.py sdist`
  122. mode = 'copy'
  123. else:
  124. return
  125. filenames = ['tools', 'configs', 'demo', 'model-index.yml']
  126. repo_path = osp.dirname(__file__)
  127. mim_path = osp.join(repo_path, 'mmdet', '.mim')
  128. os.makedirs(mim_path, exist_ok=True)
  129. for filename in filenames:
  130. if osp.exists(filename):
  131. src_path = osp.join(repo_path, filename)
  132. tar_path = osp.join(mim_path, filename)
  133. if osp.isfile(tar_path) or osp.islink(tar_path):
  134. os.remove(tar_path)
  135. elif osp.isdir(tar_path):
  136. shutil.rmtree(tar_path)
  137. if mode == 'symlink':
  138. src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
  139. os.symlink(src_relpath, tar_path)
  140. elif mode == 'copy':
  141. if osp.isfile(src_path):
  142. shutil.copyfile(src_path, tar_path)
  143. elif osp.isdir(src_path):
  144. shutil.copytree(src_path, tar_path)
  145. else:
  146. warnings.warn(f'Cannot copy file {src_path}.')
  147. else:
  148. raise ValueError(f'Invalid mode {mode}')
  149. if __name__ == '__main__':
  150. add_mim_extension()
  151. setup(
  152. name='mmdet',
  153. version=get_version(),
  154. description='OpenMMLab Detection Toolbox and Benchmark',
  155. long_description=readme(),
  156. long_description_content_type='text/markdown',
  157. author='MMDetection Contributors',
  158. author_email='openmmlab@gmail.com',
  159. keywords='computer vision, object detection',
  160. url='https://github.com/open-mmlab/mmdetection',
  161. packages=find_packages(exclude=('configs', 'tools', 'demo')),
  162. include_package_data=True,
  163. classifiers=[
  164. 'Development Status :: 5 - Production/Stable',
  165. 'License :: OSI Approved :: Apache Software License',
  166. 'Operating System :: OS Independent',
  167. 'Programming Language :: Python :: 3',
  168. 'Programming Language :: Python :: 3.6',
  169. 'Programming Language :: Python :: 3.7',
  170. 'Programming Language :: Python :: 3.8',
  171. 'Programming Language :: Python :: 3.9',
  172. ],
  173. license='Apache License 2.0',
  174. setup_requires=parse_requirements('requirements/build.txt'),
  175. tests_require=parse_requirements('requirements/tests.txt'),
  176. install_requires=parse_requirements('requirements/runtime.txt'),
  177. extras_require={
  178. 'all': parse_requirements('requirements.txt'),
  179. 'tests': parse_requirements('requirements/tests.txt'),
  180. 'build': parse_requirements('requirements/build.txt'),
  181. 'optional': parse_requirements('requirements/optional.txt'),
  182. },
  183. ext_modules=[],
  184. cmdclass={'build_ext': BuildExtension},
  185. zip_safe=False)

No Description

Contributors (2)