You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

setup.py 3.7 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #!/usr/bin/env python
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. import glob
  4. import os
  5. import shutil
  6. from setuptools import find_packages, setup
  7. from typing import List
  8. import torch
  9. from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
  10. torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
  11. assert torch_ver >= [1, 3], "Requires PyTorch >= 1.3"
  12. def get_extensions():
  13. this_dir = os.path.dirname(os.path.abspath(__file__))
  14. extensions_dir = os.path.join(this_dir, "detectron2", "layers", "csrc")
  15. main_source = os.path.join(extensions_dir, "vision.cpp")
  16. sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
  17. source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
  18. os.path.join(extensions_dir, "*.cu")
  19. )
  20. sources = [main_source] + sources
  21. extension = CppExtension
  22. extra_compile_args = {"cxx": []}
  23. define_macros = []
  24. if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
  25. extension = CUDAExtension
  26. sources += source_cuda
  27. define_macros += [("WITH_CUDA", None)]
  28. extra_compile_args["nvcc"] = [
  29. "-DCUDA_HAS_FP16=1",
  30. "-D__CUDA_NO_HALF_OPERATORS__",
  31. "-D__CUDA_NO_HALF_CONVERSIONS__",
  32. "-D__CUDA_NO_HALF2_OPERATORS__",
  33. ]
  34. # It's better if pytorch can do this by default ..
  35. CC = os.environ.get("CC", None)
  36. if CC is not None:
  37. extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
  38. include_dirs = [extensions_dir]
  39. ext_modules = [
  40. extension(
  41. "detectron2._C",
  42. sources,
  43. include_dirs=include_dirs,
  44. define_macros=define_macros,
  45. extra_compile_args=extra_compile_args,
  46. )
  47. ]
  48. return ext_modules
  49. def get_model_zoo_configs() -> List[str]:
  50. """
  51. Return a list of configs to include in package for model zoo. Copy over these configs inside
  52. detectron2/model_zoo.
  53. """
  54. # Use absolute paths while symlinking.
  55. source_configs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
  56. destination = os.path.join(
  57. os.path.dirname(os.path.realpath(__file__)), "detectron2", "model_zoo", "configs"
  58. )
  59. # Symlink the config directory inside package to have a cleaner pip install.
  60. if os.path.exists(destination):
  61. # Remove stale symlink/directory from a previous build.
  62. if os.path.islink(destination):
  63. os.unlink(destination)
  64. else:
  65. shutil.rmtree(destination)
  66. try:
  67. os.symlink(source_configs_dir, destination)
  68. except OSError:
  69. # Fall back to copying if symlink fails: ex. on Windows.
  70. shutil.copytree(source_configs_dir, destination)
  71. config_paths = glob.glob("configs/**/*.yaml", recursive=True)
  72. return config_paths
  73. setup(
  74. name="detectron2",
  75. version="0.1",
  76. author="FAIR",
  77. url="https://github.com/facebookresearch/detectron2",
  78. description="Detectron2 is FAIR's next-generation research "
  79. "platform for object detection and segmentation.",
  80. packages=find_packages(exclude=("configs", "tests")),
  81. package_data={"detectron2.model_zoo": get_model_zoo_configs()},
  82. python_requires=">=3.6",
  83. install_requires=[
  84. "termcolor>=1.1",
  85. "Pillow>=6.0",
  86. "yacs>=0.1.6",
  87. "tabulate",
  88. "cloudpickle",
  89. "matplotlib",
  90. "tqdm>4.29.0",
  91. "tensorboard",
  92. "imagesize",
  93. ],
  94. extras_require={"all": ["shapely", "psutil"]},
  95. ext_modules=get_extensions(),
  96. cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
  97. )

No Description