You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

setup.py 2.2 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #!/usr/bin/env python
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. import glob
  4. import os
  5. from setuptools import find_packages, setup
  6. import torch
  7. from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
  # Major/minor version of the installed PyTorch as a comparable list, e.g. [1, 3].
  torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
  # Explicit check rather than `assert`: assertions are stripped when Python
  # runs with -O, which would silently skip this version gate.
  if torch_ver < [1, 3]:
      raise RuntimeError(
          "Requires PyTorch >= 1.3 (found {})".format(torch.__version__)
      )
  def get_extensions():
      """Build the native extension module list for ``tensormask._C``.

      C++ sources come from ``tensormask/layers/csrc`` next to this file:
      ``vision.cpp`` first, then every ``*.cpp`` exactly one directory level
      below ``csrc``. CUDA (``*.cu``) sources, both directly in ``csrc`` and
      one level below it, are added only when CUDA is usable (a GPU is visible
      and ``CUDA_HOME`` is set) or the ``FORCE_CUDA=1`` environment variable
      is set.

      Returns:
          list: a one-element list holding a ``CppExtension`` or a
          ``CUDAExtension`` named ``"tensormask._C"``.
      """
      this_dir = os.path.dirname(os.path.abspath(__file__))
      extensions_dir = os.path.join(this_dir, "tensormask", "layers", "csrc")

      main_source = os.path.join(extensions_dir, "vision.cpp")
      # Without recursive=True, "**" behaves like "*" and matches exactly one
      # subdirectory level, so vision.cpp (top level) is not picked up twice.
      # Results are sorted so the build order does not depend on filesystem
      # listing order. glob returns paths already prefixed with extensions_dir,
      # so no further path joining is needed.
      sources = sorted(glob.glob(os.path.join(extensions_dir, "**", "*.cpp")))
      source_cuda = sorted(
          glob.glob(os.path.join(extensions_dir, "**", "*.cu"))
          + glob.glob(os.path.join(extensions_dir, "*.cu"))
      )
      sources = [main_source] + sources

      extension = CppExtension
      extra_compile_args = {"cxx": []}
      define_macros = []

      use_cuda = (
          torch.cuda.is_available() and CUDA_HOME is not None
      ) or os.getenv("FORCE_CUDA", "0") == "1"
      if use_cuda:
          extension = CUDAExtension
          sources += source_cuda
          define_macros += [("WITH_CUDA", None)]
          extra_compile_args["nvcc"] = [
              "-DCUDA_HAS_FP16=1",
              "-D__CUDA_NO_HALF_OPERATORS__",
              "-D__CUDA_NO_HALF_CONVERSIONS__",
              "-D__CUDA_NO_HALF2_OPERATORS__",
          ]

          # Let nvcc use the same host compiler as the C++ build when CC is set.
          # It's better if pytorch can do this by default ..
          CC = os.environ.get("CC", None)
          if CC is not None:
              extra_compile_args["nvcc"].append("-ccbin={}".format(CC))

      include_dirs = [extensions_dir]

      ext_modules = [
          extension(
              "tensormask._C",
              sources,
              include_dirs=include_dirs,
              define_macros=define_macros,
              extra_compile_args=extra_compile_args,
          )
      ]
      return ext_modules
  # Package metadata for TensorMask. The native extension list is produced by
  # get_extensions(), and PyTorch's BuildExtension handles the build_ext step
  # so the C++/CUDA sources are compiled with the right flags.
  _setup_kwargs = dict(
      name="tensormask",
      version="0.1",
      author="FAIR",
      packages=find_packages(exclude=("configs", "tests")),
      python_requires=">=3.6",
      ext_modules=get_extensions(),
      cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
  )
  setup(**_setup_kwargs)

No Description