
layer_norm.cpp

/**
 * \file dnn/src/common/layer_norm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
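
/**
 * Deduce output layouts for the forward pass: \p dst takes the layout of
 * \p data, while \p mean and \p rstd cover only the leading (unnormalized)
 * dimensions of \p data and are stored as float32.
 */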
void LayerNormBase::deduce_layout_fwd(
        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
        TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
    MEGDNN_MARK_USED_VAR(weight);
    MEGDNN_MARK_USED_VAR(bias);
    auto p = param();
    TensorShape unnormalized_shape;
    unnormalized_shape.ndim = data.ndim - p.normalized_dim;
    for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
        unnormalized_shape.shape[i] = data.shape[i];
    }
    TensorLayout unnormalized_layout =
            TensorLayout(unnormalized_shape, dtype::Float32());
    dst = data;
    mean = unnormalized_layout;
    rstd = unnormalized_layout;
}
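
/**
 * Validate forward layouts: all tensors must be contiguous, \p dst must match
 * \p data, \p weight must match \p bias, \p mean must match \p rstd, and the
 * normalized/unnormalized dimensions must be consistent with the operator
 * parameters.
 */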
void LayerNormBase::check_layout_fwd(
        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
        const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) {
    megdnn_assert_contiguous(data);
    megdnn_assert_contiguous(weight);
    megdnn_assert_contiguous(bias);
    megdnn_assert_contiguous(dst);
    megdnn_assert_contiguous(mean);
    megdnn_assert_contiguous(rstd);
    auto errmsg = [&]() {
        return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " +
               megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " +
               megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd);
    };
    MEGDNN_MARK_USED_VAR(errmsg);

    auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool {
        if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype &&
              lhs.format == rhs.format))
            return false;
        for (size_t i = 0; i < lhs.ndim; ++i) {
            if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) {
                return false;
            }
        }
        return true;
    };

    megdnn_assert(equal_layout(data, dst), "%s", errmsg().c_str());
    megdnn_assert(equal_layout(weight, bias), "%s", errmsg().c_str());
    megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str());

    auto p = param();
    uint64_t normalized_dim = p.normalized_dim;
    size_t unnormalized_dim = data.ndim - normalized_dim;
    megdnn_assert(
            normalized_dim < data.ndim,
            "the dims of normalized shape should be smaller than input dims");
    for (size_t i = 0; i < unnormalized_dim; ++i) {
        megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str());
    }
    if (p.affine) {
        for (size_t i = 0; i < normalized_dim; ++i) {
            megdnn_assert(
                    data.shape[unnormalized_dim + i] == weight.shape[i], "%s",
                    errmsg().c_str());
        }
    }
}
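
/**
 * Forward operator: layout deduction delegates to deduce_layout_fwd(), and
 * check_exec() additionally verifies that the provided workspace is large
 * enough.
 */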
void LayerNormForward::deduce_layout(
        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
        TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
    deduce_layout_fwd(data, weight, bias, dst, mean, rstd);
}

void LayerNormForward::check_exec(
        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
        const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd,
        size_t workspace_in_bytes) {
    check_layout_fwd(data, weight, bias, dst, mean, rstd);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(data, weight, bias, dst, mean, rstd);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
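
/**
 * Backward layout deduction: the gradient w.r.t. data matches the data layout,
 * and the gradients w.r.t. weight and bias both match the weight layout.
 */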
void LayerNormBackward::deduce_layout(
        const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
        const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata,
        TensorLayout& dweight, TensorLayout& dbias) {
    MEGDNN_MARK_USED_VAR(diff);
    MEGDNN_MARK_USED_VAR(mean);
    MEGDNN_MARK_USED_VAR(rstd);
    ddata = data;
    dweight = weight;
    dbias = weight;
}
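
/**
 * Validate backward layouts: checks the workspace size, contiguity of all
 * tensors (weight/dweight/dbias only when affine), matching gradient layouts,
 * and shape consistency between the unnormalized dims of data and mean.
 */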
void LayerNormBackward::check_exec(
        const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
        const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata,
        const TensorLayout& dweight, const TensorLayout& dbias,
        size_t workspace_in_bytes) {
    auto p = param();
    auto required_workspace_in_bytes = get_workspace_in_bytes(
            diff, data, weight, mean, rstd, ddata, dweight, dbias);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);

    megdnn_assert_contiguous(diff);
    megdnn_assert_contiguous(data);
    megdnn_assert_contiguous(mean);
    megdnn_assert_contiguous(rstd);
    megdnn_assert_contiguous(ddata);
    if (p.affine) {
        megdnn_assert_contiguous(weight);
        megdnn_assert_contiguous(dweight);
        megdnn_assert_contiguous(dbias);
    }

    auto errmsg = [&]() {
        return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " +
               megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " +
               megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " +
               megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias);
    };
    MEGDNN_MARK_USED_VAR(errmsg);

    auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool {
        if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype &&
              lhs.format == rhs.format))
            return false;
        for (size_t i = 0; i < lhs.ndim; ++i) {
            if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) {
                return false;
            }
        }
        return true;
    };

    megdnn_assert(equal_layout(data, ddata), "%s", errmsg().c_str());
    megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str());
    if (p.affine) {
        megdnn_assert(equal_layout(weight, dweight), "%s", errmsg().c_str());
        megdnn_assert(equal_layout(weight, dbias), "%s", errmsg().c_str());
    }

    size_t normalized_dim = p.normalized_dim;
    size_t unnormalized_dim = data.ndim - normalized_dim;
    for (size_t i = 0; i < unnormalized_dim; ++i) {
        megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str());
    }
    if (p.affine) {
        for (size_t i = 0; i < normalized_dim; ++i) {
            megdnn_assert(
                    data.shape[unnormalized_dim + i] == weight.shape[i], "%s",
                    errmsg().c_str());
        }
    }
}
}  // namespace megdnn

// vim: syntax=cpp.doxygen