You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vision.cpp 2.7 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  #include <sstream>
  #include <string>

  #include <torch/extension.h>

  #include "ROIAlign/ROIAlign.h"
  #include "ROIAlignRotated/ROIAlignRotated.h"
  #include "box_iou_rotated/box_iou_rotated.h"
  #include "deformable/deform_conv.h"
  #include "nms_rotated/nms_rotated.h"
  8. namespace detectron2 {
  9. #ifdef WITH_CUDA
  10. extern int get_cudart_version();
  11. #endif
  12. std::string get_cuda_version() {
  13. #ifdef WITH_CUDA
  14. std::ostringstream oss;
  15. // copied from
  16. // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
  17. auto printCudaStyleVersion = [&](int v) {
  18. oss << (v / 1000) << "." << (v / 10 % 100);
  19. if (v % 10 != 0) {
  20. oss << "." << (v % 10);
  21. }
  22. };
  23. printCudaStyleVersion(get_cudart_version());
  24. return oss.str();
  25. #else
  26. return std::string("not available");
  27. #endif
  28. }
  29. // similar to
  30. // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
  31. std::string get_compiler_version() {
  32. std::ostringstream ss;
  33. #if defined(__GNUC__)
  34. #ifndef __clang__
  35. { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
  36. #endif
  37. #endif
  38. #if defined(__clang_major__)
  39. {
  40. ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
  41. << __clang_patchlevel__;
  42. }
  43. #endif
  44. #if defined(_MSC_VER)
  45. { ss << "MSVC " << _MSC_FULL_VER; }
  46. #endif
  47. return ss.str();
  48. }
  49. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  50. m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
  51. m.def("get_cuda_version", &get_cuda_version, "get_cuda_version");
  52. m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes");
  53. m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward");
  54. m.def(
  55. "deform_conv_backward_input",
  56. &deform_conv_backward_input,
  57. "deform_conv_backward_input");
  58. m.def(
  59. "deform_conv_backward_filter",
  60. &deform_conv_backward_filter,
  61. "deform_conv_backward_filter");
  62. m.def(
  63. "modulated_deform_conv_forward",
  64. &modulated_deform_conv_forward,
  65. "modulated_deform_conv_forward");
  66. m.def(
  67. "modulated_deform_conv_backward",
  68. &modulated_deform_conv_backward,
  69. "modulated_deform_conv_backward");
  70. m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes");
  71. m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
  72. m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
  73. m.def(
  74. "roi_align_rotated_forward",
  75. &ROIAlignRotated_forward,
  76. "Forward pass for Rotated ROI-Align Operator");
  77. m.def(
  78. "roi_align_rotated_backward",
  79. &ROIAlignRotated_backward,
  80. "Backward pass for Rotated ROI-Align Operator");
  81. }
  82. } // namespace detectron2

No Description