You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

concat_split.cpp 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. /**
  2. * \file dnn/src/common/concat_split.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs.h"
  12. #include "src/common/utils.h"
  13. #include <numeric>
  14. namespace megdnn {
  15. ConcatSplitBase::ConcatSplitBase(Handle* handle)
  16. : OperatorBase(handle),
  17. m_get_layout([](const TensorND& tensor) { return tensor.layout; }),
  18. m_get_shape([](const TensorLayout& layout) { return TensorShape(layout); }) {}
  19. void ConcatSplitBase::check_layout_common(
  20. const TensorLayoutArray& srcs, const TensorLayout& dst) {
  21. // ensure same data type
  22. for (auto&& src : srcs) {
  23. megdnn_assert(src.dtype == dst.dtype);
  24. }
  25. // ensure all layouts are contiguous
  26. for (auto&& src : srcs) {
  27. megdnn_assert_contiguous(src);
  28. }
  29. megdnn_assert_contiguous(dst);
  30. // ensure all layouts have the same ndim
  31. auto ndim = dst.ndim;
  32. for (auto&& src : srcs) {
  33. megdnn_assert_eq_size_t(src.ndim, ndim);
  34. }
  35. // ensure param().axis is correct
  36. megdnn_assert(
  37. param().axis < static_cast<int32_t>(ndim), "param().axis=%u, ndim=%zu",
  38. param().axis, ndim);
  39. // ensure shape size for each axis is correct
  40. for (size_t i = 0; i < ndim; ++i) {
  41. if (i == static_cast<size_t>(param().axis)) {
  42. size_t sum = 0_z;
  43. for (auto&& src : srcs)
  44. sum += src.shape[i];
  45. megdnn_assert_eq_size_t(sum, dst.shape[i]);
  46. } else {
  47. for (auto&& src : srcs) {
  48. megdnn_assert(src.shape[i] == dst.shape[i]);
  49. megdnn_assert_eq_size_t(src.shape[i], dst.shape[i]);
  50. }
  51. }
  52. }
  53. }
  54. void ConcatSplitBase::get_ABC(
  55. const TensorShapeArray& srcs, size_t& A, size_t* B, size_t& C) {
  56. auto axis = param().axis;
  57. auto shape_arr = srcs[0].shape;
  58. auto ndim = srcs[0].ndim;
  59. A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies<size_t>());
  60. for (size_t i = 0u; i < srcs.size(); ++i) {
  61. B[i] = srcs[i].shape[axis];
  62. }
  63. C = std::accumulate(
  64. shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies<size_t>());
  65. }
  66. void ConcatForward::deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst) {
  67. dst = srcs[0];
  68. auto i = param().axis;
  69. dst.shape[i] = 0u;
  70. for (auto&& src : srcs) {
  71. dst.shape[i] += src.shape[i];
  72. }
  73. dst.init_contiguous_stride();
  74. }
  75. void ConcatForward::check_exec(
  76. const TensorLayoutArray& srcs, const TensorLayout& dst,
  77. size_t workspace_in_bytes) {
  78. check_layout_common(srcs, dst);
  79. auto required_workspace_in_bytes = get_workspace_in_bytes(srcs, dst);
  80. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  81. }
  82. void SplitForward::check_exec(
  83. const TensorLayout& src, const TensorLayoutArray& dsts,
  84. size_t workspace_in_bytes) {
  85. check_layout_common(dsts, src);
  86. auto required_workspace_in_bytes = get_workspace_in_bytes(src, dsts);
  87. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  88. }
  89. } // namespace megdnn
  90. // vim: syntax=cpp.doxygen