You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

correlation.cpp 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. /**
  2. * \file dnn/src/common/correlation.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. namespace megdnn {
  15. void CorrelationBase::deduce_layout_fwd(
  16. const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst) {
  17. megdnn_assert_contiguous(data1);
  18. megdnn_assert_contiguous(data2);
  19. megdnn_assert_contiguous(dst);
  20. auto errmsg = [&]() {
  21. return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) + ", " +
  22. megdnn_layout_msg(dst);
  23. };
  24. MEGDNN_MARK_USED_VAR(errmsg);
  25. using Format = CorrelationBase::Param::Format;
  26. megdnn_assert(param().format == Format::NCHW);
  27. auto data1_dtype = data1.dtype, data2_dtype = data2.dtype;
  28. megdnn_assert(
  29. data1_dtype == data2_dtype &&
  30. data1_dtype.category() == DTypeCategory::FLOAT);
  31. megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str());
  32. megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str());
  33. uint32_t pad_size = param().pad_size;
  34. uint32_t kernel_size = param().kernel_size;
  35. uint32_t stride1 = param().stride1;
  36. uint32_t stride2 = param().stride2;
  37. uint32_t max_displacement = param().max_displacement;
  38. int paddedbottomheight = data1[2] + 2 * pad_size;
  39. int paddedbottomwidth = data1[3] + 2 * pad_size;
  40. uint32_t kernel_radius = (kernel_size - 1) / 2;
  41. uint32_t border_size = max_displacement + kernel_radius;
  42. uint32_t top_width =
  43. ceil(static_cast<float>(paddedbottomwidth - border_size * 2) /
  44. static_cast<float>(stride1));
  45. uint32_t top_height =
  46. ceil(static_cast<float>(paddedbottomheight - border_size * 2) /
  47. static_cast<float>(stride1));
  48. uint32_t neighborhood_grid_radius = max_displacement / stride2;
  49. uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
  50. uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width;
  51. megdnn_assert(top_width >= 1 && top_height >= 1);
  52. dst = TensorLayout{{data1[0], top_channels, top_height, top_width}, data1.dtype};
  53. }
  54. void CorrelationBase::check_layout_fwd(
  55. const TensorLayout& data1, const TensorLayout& data2, const TensorLayout& dst) {
  56. TensorLayout dst_expected;
  57. megdnn_assert_eq_dtype(data1, dst);
  58. megdnn_assert_eq_shape(data1, data2);
  59. deduce_layout_fwd(data1, data2, dst_expected);
  60. megdnn_assert_eq_shape(dst_expected, dst);
  61. }
  62. void CorrelationForward::deduce_layout(
  63. const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst) {
  64. deduce_layout_fwd(data1, data2, dst);
  65. }
  66. void CorrelationForward::check_exec(
  67. const TensorLayout& data1, const TensorLayout& data2, const TensorLayout& dst,
  68. size_t workspace_in_bytes) {
  69. check_layout_fwd(data1, data2, dst);
  70. auto required_workspace_in_bytes = get_workspace_in_bytes(data1, data2, dst);
  71. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  72. }
  73. void CorrelationBackwardData1::check_exec(
  74. const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
  75. const TensorLayout& grad1, size_t workspace_in_bytes) {
  76. check_layout_fwd(grad1, data2, diff);
  77. megdnn_assert_eq_shape(data1, data2);
  78. auto required_workspace_in_bytes =
  79. get_workspace_in_bytes(diff, data1, data2, grad1);
  80. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  81. }
  82. void CorrelationBackwardData2::check_exec(
  83. const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
  84. const TensorLayout& grad2, size_t workspace_in_bytes) {
  85. check_layout_fwd(data1, grad2, diff);
  86. megdnn_assert_eq_shape(data1, data2);
  87. auto required_workspace_in_bytes =
  88. get_workspace_in_bytes(diff, data1, data2, grad2);
  89. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  90. }
  91. void CorrelationBackwardData2::deduce_layout(
  92. const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
  93. TensorLayout& grad) {
  94. megdnn_assert_eq_shape(data1, data2);
  95. check_layout_fwd(data1, data2, diff);
  96. grad = data2;
  97. }
  98. void CorrelationBackwardData1::deduce_layout(
  99. const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2,
  100. TensorLayout& grad) {
  101. megdnn_assert_eq_shape(data1, data2);
  102. check_layout_fwd(data1, data2, diff);
  103. grad = data1;
  104. }
  105. } // namespace megdnn
  106. // vim: syntax=cpp.doxygen