You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn.cpp 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. /**
  2. * \file dnn/src/common/rnn.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/common/rnn.h"
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. namespace megdnn {
  15. void RNN::deduce_layout(
  16. const TensorLayout& input, const TensorLayout& hx,
  17. const TensorLayout& /*flatten_weights*/, TensorLayout& output, TensorLayout& hy,
  18. TensorLayout& reserve_space) {
  19. size_t seq_len = input.shape[0];
  20. size_t batch_size = input.shape[1];
  21. size_t D = param().bidirectional ? 2 : 1;
  22. size_t hidden_size = hx.shape[2];
  23. output = TensorLayout(
  24. TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
  25. hy = TensorLayout(hx);
  26. reserve_space = {{get_reserve_size_in_bytes(input)}, input.dtype};
  27. }
  28. void RNN::check_exec(
  29. const TensorLayout& input, const TensorLayout& hx,
  30. const TensorLayout& flatten_weights, const TensorLayout& output,
  31. const TensorLayout& hy, const TensorLayout& /*reserve_space*/,
  32. size_t /*workspace_in_bytes*/) {
  33. auto errmsg = [&]() {
  34. std::string msg;
  35. msg.append("input=");
  36. msg.append(input.to_string());
  37. msg.append(", output=");
  38. msg.append(output.to_string());
  39. msg.append(", hx=");
  40. msg.append(hx.to_string());
  41. msg.append(", flatten_weights=");
  42. msg.append(flatten_weights.to_string());
  43. msg.append(", hy=");
  44. msg.append(hy.to_string());
  45. msg.append(", hidden_size=");
  46. msg.append(std::to_string(param().hidden_size));
  47. msg.append(", num_layers=");
  48. msg.append(std::to_string(param().num_layers));
  49. msg.append(", bidirectional=");
  50. msg.append(std::to_string(param().bidirectional));
  51. return msg;
  52. };
  53. size_t D = param().bidirectional ? 2 : 1;
  54. size_t b = param().bias ? 1 : 0;
  55. size_t num_layers = param().num_layers;
  56. size_t input_size = input.shape[2];
  57. size_t gate_hidden_size = param().hidden_size;
  58. // calculate size_dim1 the same as lstm
  59. size_t size_dim1 = D * (input_size + param().hidden_size) +
  60. (num_layers - 1) * D * ((D + 1) * param().hidden_size) +
  61. b * 2 * D * num_layers;
  62. #define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
  63. ASSERT_BRIEF(hx.ndim == 3)
  64. ASSERT_BRIEF(input.ndim == 3)
  65. ASSERT_BRIEF(output.ndim == 3)
  66. ASSERT_BRIEF(hy.ndim == 3)
  67. ASSERT_BRIEF(flatten_weights.shape[0] == gate_hidden_size)
  68. ASSERT_BRIEF(flatten_weights.shape[0] == size_dim1)
  69. ASSERT_BRIEF(hx.shape[0] == D * num_layers)
  70. ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size
  71. ASSERT_BRIEF(hx.shape[2] == param().hidden_size)
  72. ASSERT_BRIEF(output.shape[0] == input.shape[0])
  73. ASSERT_BRIEF(output.shape[1] == input.shape[1])
  74. ASSERT_BRIEF(output.shape[2] == D * param().hidden_size)
  75. ASSERT_BRIEF(hy.shape[0] == hx.shape[0])
  76. ASSERT_BRIEF(hy.shape[1] == hx.shape[1])
  77. ASSERT_BRIEF(hy.shape[2] == hx.shape[2])
  78. #undef ASSERT_BRIEF
  79. }
  80. void RNNBackward::deduce_layout(
  81. const TensorLayout& x, const TensorLayout& /*y*/, const TensorLayout& hx,
  82. const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
  83. const TensorLayout& flatten_weights, const TensorLayout& /*reserve_space*/,
  84. TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw) {
  85. dx = x;
  86. dhx = hx;
  87. dw = flatten_weights;
  88. }
  89. void RNNBackward::check_exec(
  90. const TensorLayout& /*x*/, const TensorLayout& /*y*/,
  91. const TensorLayout& /*hx*/, const TensorLayout& /*dy*/,
  92. const TensorLayout& /*dhy*/, const TensorLayout& /*flatten_weights*/,
  93. const TensorLayout& /*reserve_space*/, const TensorLayout& /*dx*/,
  94. const TensorLayout& /*dhx*/, const TensorLayout& /*dw*/,
  95. size_t /*workspace_in_bytes*/) {}
  96. } // namespace megdnn