GitOrigin-RevId: 0cd484e753
tags/v1.8.0
| @@ -1936,6 +1936,75 @@ protected: | |||||
| const TensorLayout& grad_s, size_t workspace_in_bytes); | const TensorLayout& grad_s, size_t workspace_in_bytes); | ||||
| }; | }; | ||||
| class LayerNormBase : public OperatorBase { | |||||
| DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase); | |||||
| DEF_OPR_PARAM(LayerNorm); | |||||
| protected: | |||||
| void deduce_layout_fwd( | |||||
| const TensorLayout& data, const TensorLayout& weight, | |||||
| const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, | |||||
| TensorLayout& rstd); | |||||
| void check_layout_fwd( | |||||
| const TensorLayout& data, const TensorLayout& weight, | |||||
| const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||||
| const TensorLayout& rstd); | |||||
| }; | |||||
| class LayerNormForward : public LayerNormBase { | |||||
| DEF_OPR_IMPL(LayerNormForward, LayerNormBase, 3, 3); | |||||
| public: | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout( | |||||
| const TensorLayout& data, const TensorLayout& weight, | |||||
| const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, | |||||
| TensorLayout& rstd); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& data, const TensorLayout& weight, | |||||
| const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||||
| const TensorLayout& rstd) = 0; | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& data, const TensorLayout& weight, | |||||
| const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||||
| const TensorLayout& rstd, size_t workspace_in_bytes); | |||||
| }; | |||||
| using LayerNorm = LayerNormForward; | |||||
| class LayerNormBackward : public LayerNormBase { | |||||
| DEF_OPR_IMPL(LayerNormBackward, LayerNormBase, 5, 3); | |||||
| public: | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout( | |||||
| const TensorLayout& diff, const TensorLayout& data, | |||||
| const TensorLayout& weight, const TensorLayout& mean, | |||||
| const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight, | |||||
| TensorLayout& dbias); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& diff, const TensorLayout& data, | |||||
| const TensorLayout& weight, const TensorLayout& mean, | |||||
| const TensorLayout& rstd, const TensorLayout& ddata, | |||||
| const TensorLayout& dweight, const TensorLayout& dbias) = 0; | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& diff, const TensorLayout& data, | |||||
| const TensorLayout& weight, const TensorLayout& mean, | |||||
| const TensorLayout& rstd, const TensorLayout& ddata, | |||||
| const TensorLayout& dweight, const TensorLayout& dbias, | |||||
| size_t workspace_in_bytes); | |||||
| }; | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #include "megdnn/internal/opr_header_epilogue.h" | #include "megdnn/internal/opr_header_epilogue.h" | ||||
| @@ -1212,3 +1212,10 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | |||||
| member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES] | member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES] | ||||
| ) | ) | ||||
| ) | ) | ||||
| (pdef('LayerNorm') | |||||
| .add_fields('bool', 'affine', 'true') | |||||
| .add_fields('float32', 'eps', '1e-5f') | |||||
| .add_fields('uint64', 'normalized_dim', '1') | |||||
| .add_fields('uint64', 'normalized_size', '1') | |||||
| ) | |||||
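For reference, the fields above map onto the standard LayerNorm formulation (restated here, not text from the patch), with statistics taken over the trailing `normalized_dim` axes, i.e. over `normalized_size` elements per slice:

```latex
\mu = \frac{1}{H}\sum_{j=1}^{H} x_j,\qquad
\sigma^2 = \frac{1}{H}\sum_{j=1}^{H}\bigl(x_j-\mu\bigr)^2,\qquad
\mathrm{rstd} = \frac{1}{\sqrt{\sigma^2+\varepsilon}},\qquad
y_j = \mathrm{rstd}\,(x_j-\mu)\,\gamma_j + \beta_j
```

where H = `normalized_size`, ε = `eps`, and γ/β (the `weight`/`bias` tensors) are applied only when `affine` is true. The forward operator also outputs μ and rstd per slice so the backward pass can reuse them.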
| @@ -209,7 +209,10 @@ private: | |||||
| cb(LSQBackward) \ | cb(LSQBackward) \ | ||||
| cb(Fill) \ | cb(Fill) \ | ||||
| cb(PaddingForward) \ | cb(PaddingForward) \ | ||||
| cb(PaddingBackward) | |||||
| cb(PaddingBackward) \ | |||||
| cb(LayerNormForward) \ | |||||
| cb(LayerNormBackward) | |||||
| // clang-format on | // clang-format on | ||||
| /*! | /*! | ||||
| @@ -0,0 +1,180 @@ | |||||
| /** | |||||
| * \file dnn/src/common/layer_norm.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/common/utils.h" | |||||
| namespace megdnn { | |||||
| void LayerNormBase::deduce_layout_fwd( | |||||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||||
| TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||||
| MEGDNN_MARK_USED_VAR(weight); | |||||
| MEGDNN_MARK_USED_VAR(bias); | |||||
| auto p = param(); | |||||
| TensorShape unnormalized_shape; | |||||
| unnormalized_shape.ndim = data.ndim - p.normalized_dim; | |||||
| for (size_t i = 0; i < unnormalized_shape.ndim; ++i) { | |||||
| unnormalized_shape.shape[i] = data.shape[i]; | |||||
| } | |||||
| TensorLayout unnormalized_layout = | |||||
| TensorLayout(unnormalized_shape, dtype::Float32()); | |||||
| dst = data; | |||||
| mean = unnormalized_layout; | |||||
| rstd = unnormalized_layout; | |||||
| } | |||||
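A minimal host-side sketch of this deduction (illustrative only; the helper name and the concrete shapes are not from the patch): with `data` of shape (N, C, H, W) and `normalized_dim = 2`, `dst` keeps the full input shape while `mean`/`rstd` keep only the leading, unnormalized axes.

```cpp
#include <cstddef>
#include <vector>

// Mirrors deduce_layout_fwd on plain shape vectors: dst copies the input
// shape; mean/rstd drop the trailing `normalized_dim` axes.
// (In the patch, mean/rstd are always Float32 regardless of the input dtype.)
static void deduce_shapes(
        const std::vector<size_t>& data, size_t normalized_dim,
        std::vector<size_t>& dst, std::vector<size_t>& stat) {
    dst = data;
    stat.assign(data.begin(), data.end() - normalized_dim);
}

int main() {
    std::vector<size_t> dst, stat;
    deduce_shapes({8, 16, 32, 32}, 2, dst, stat);
    // dst  == {8, 16, 32, 32}
    // stat == {8, 16}   -> one mean/rstd value per normalized slice
}
```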
| void LayerNormBase::check_layout_fwd( | |||||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||||
| const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) { | |||||
| megdnn_assert_contiguous(data); | |||||
| megdnn_assert_contiguous(weight); | |||||
| megdnn_assert_contiguous(bias); | |||||
| megdnn_assert_contiguous(dst); | |||||
| megdnn_assert_contiguous(mean); | |||||
| megdnn_assert_contiguous(rstd); | |||||
| auto errmsg = [&]() { | |||||
| return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " + | |||||
| megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " + | |||||
| megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd); | |||||
| }; | |||||
| MEGDNN_MARK_USED_VAR(errmsg); | |||||
| auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool { | |||||
| if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype && | |||||
| lhs.format == rhs.format)) | |||||
| return false; | |||||
| for (size_t i = 0; i < lhs.ndim; ++i) { | |||||
| if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| }; | |||||
| megdnn_assert(equal_layout(data, dst), "%s", errmsg().c_str()); | |||||
| megdnn_assert(equal_layout(weight, bias), "%s", errmsg().c_str()); | |||||
| megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str()); | |||||
| auto p = param(); | |||||
| uint64_t normalized_dim = p.normalized_dim; | |||||
| size_t unnormalized_dim = data.ndim - normalized_dim; | |||||
| megdnn_assert( | |||||
| normalized_dim < data.ndim, | |||||
| "the dims of normalized shape should smaller than input dims"); | |||||
| for (size_t i = 0; i < unnormalized_dim; ++i) { | |||||
| megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str()); | |||||
| } | |||||
| if (p.affine) { | |||||
| for (size_t i = 0; i < normalized_dim; ++i) { | |||||
| megdnn_assert( | |||||
| data.shape[unnormalized_dim + i] == weight.shape[i], "%s", | |||||
| errmsg().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| void LayerNormForward::deduce_layout( | |||||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||||
| TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||||
| deduce_layout_fwd(data, weight, bias, dst, mean, rstd); | |||||
| } | |||||
| void LayerNormForward::check_exec( | |||||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||||
| const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd, | |||||
| size_t workspace_in_bytes) { | |||||
| check_layout_fwd(data, weight, bias, dst, mean, rstd); | |||||
| auto required_workspace_in_bytes = | |||||
| get_workspace_in_bytes(data, weight, bias, dst, mean, rstd); | |||||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
| } | |||||
| void LayerNormBackward::deduce_layout( | |||||
| const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, | |||||
| const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata, | |||||
| TensorLayout& dweight, TensorLayout& dbias) { | |||||
| MEGDNN_MARK_USED_VAR(diff); | |||||
| MEGDNN_MARK_USED_VAR(mean); | |||||
| MEGDNN_MARK_USED_VAR(rstd); | |||||
| ddata = data; | |||||
| dweight = weight; | |||||
| dbias = weight; | |||||
| } | |||||
| void LayerNormBackward::check_exec( | |||||
| const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, | |||||
| const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata, | |||||
| const TensorLayout& dweight, const TensorLayout& dbias, | |||||
| size_t workspace_in_bytes) { | |||||
| auto p = param(); | |||||
| auto required_workspace_in_bytes = get_workspace_in_bytes( | |||||
| diff, data, weight, mean, rstd, ddata, dweight, dbias); | |||||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
| megdnn_assert_contiguous(diff); | |||||
| megdnn_assert_contiguous(data); | |||||
| megdnn_assert_contiguous(mean); | |||||
| megdnn_assert_contiguous(rstd); | |||||
| megdnn_assert_contiguous(ddata); | |||||
| if (p.affine) { | |||||
| megdnn_assert_contiguous(weight); | |||||
| megdnn_assert_contiguous(dweight); | |||||
| megdnn_assert_contiguous(dbias); | |||||
| } | |||||
| auto errmsg = [&]() { | |||||
| return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " + | |||||
| megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " + | |||||
| megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " + | |||||
| megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias); | |||||
| }; | |||||
| MEGDNN_MARK_USED_VAR(errmsg); | |||||
| auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool { | |||||
| if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype && | |||||
| lhs.format == rhs.format)) | |||||
| return false; | |||||
| for (size_t i = 0; i < lhs.ndim; ++i) { | |||||
| if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| }; | |||||
| megdnn_assert(equal_layout(data, ddata), "%s", errmsg().c_str()); | |||||
| megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str()); | |||||
| if (p.affine) { | |||||
| megdnn_assert(equal_layout(weight, dweight), "%s", errmsg().c_str()); | |||||
| megdnn_assert(equal_layout(weight, dbias), "%s", errmsg().c_str()); | |||||
| } | |||||
| size_t normalized_dim = p.normalized_dim; | |||||
| size_t unnormalized_dim = data.ndim - normalized_dim; | |||||
| for (size_t i = 0; i < unnormalized_dim; ++i) { | |||||
| megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str()); | |||||
| } | |||||
| if (p.affine) { | |||||
| for (size_t i = 0; i < normalized_dim; ++i) { | |||||
| megdnn_assert( | |||||
| data.shape[unnormalized_dim + i] == weight.shape[i], "%s", | |||||
| errmsg().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -135,6 +135,8 @@ DEF(CheckNonFinite, 2, true, true); | |||||
| DEF(LSQForward, 5, true, true); | DEF(LSQForward, 5, true, true); | ||||
| DEF(LSQBackward, 7, true, false); | DEF(LSQBackward, 7, true, false); | ||||
| DEF(Fill, 1, true, false); | DEF(Fill, 1, true, false); | ||||
| DEF(LayerNormForward, 6, true, true); | |||||
| DEF(LayerNormBackward, 8, true, true); | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -45,6 +45,7 @@ | |||||
| #include "src/cuda/images2neibs/opr_impl.h" | #include "src/cuda/images2neibs/opr_impl.h" | ||||
| #include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | #include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | ||||
| #include "src/cuda/indexing_one_hot/opr_impl.h" | #include "src/cuda/indexing_one_hot/opr_impl.h" | ||||
| #include "src/cuda/layer_norm/opr_impl.h" | |||||
| #include "src/cuda/linspace/opr_impl.h" | #include "src/cuda/linspace/opr_impl.h" | ||||
| #include "src/cuda/local/opr_impl.h" | #include "src/cuda/local/opr_impl.h" | ||||
| #include "src/cuda/local_share/opr_impl.h" | #include "src/cuda/local_share/opr_impl.h" | ||||
| @@ -0,0 +1,664 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/layer_norm/layer_norm_cuda.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include <thrust/pair.h> | |||||
| #include <thrust/tuple.h> | |||||
| #include <cfloat> | |||||
| #include "megdnn/arch.h" | |||||
| #include "megdnn/dtype.h" | |||||
| #include "src/cuda/cuda_shfl_compat.cuh" | |||||
| #include "src/cuda/layer_norm/layer_norm_cuda.cuh" | |||||
| #include "src/cuda/utils.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace layer_norm { | |||||
| constexpr int kCUDANumThreads = 256; | |||||
| constexpr int vec_size = 4; | |||||
| // the warp size may be used as an array length or in host functions, | |||||
| // so we define WARP_SIZE rather than relying on the device-only warpSize | |||||
| #define WARP_SIZE 32 | |||||
| #if defined(__clang__) | |||||
| #define __ubsan_ignore_float_divide_by_zero__ \ | |||||
| __attribute__((no_sanitize("float-divide-by-zero"))) | |||||
| #else | |||||
| #define __ubsan_ignore_float_divide_by_zero__ | |||||
| #endif | |||||
| struct WelfordStat { | |||||
| float mean; | |||||
| float sigma2; | |||||
| float count; | |||||
| MEGDNN_HOST MEGDNN_DEVICE WelfordStat() : mean(0.f), sigma2(0.f), count(0.f) {} | |||||
| MEGDNN_HOST MEGDNN_DEVICE WelfordStat(float mean, float sigma2, float count) | |||||
| : mean(mean), sigma2(sigma2), count(count) {} | |||||
| }; | |||||
| template <typename T, typename combine_t> | |||||
| struct WelfordData { | |||||
| T mean; | |||||
| T sigma2; | |||||
| combine_t count; | |||||
| MEGDNN_HOST MEGDNN_DEVICE WelfordData() : mean(0), sigma2(0), count(0) {} | |||||
| MEGDNN_HOST MEGDNN_DEVICE WelfordData(T mean, T sigma2, combine_t count) | |||||
| : mean(mean), sigma2(sigma2), count(count) {} | |||||
| }; | |||||
| template <typename T, typename combine_t, typename res_t> | |||||
| struct WelfordOps { | |||||
| public: | |||||
| using WelfordData_T = WelfordData<T, combine_t>; | |||||
| inline MEGDNN_DEVICE WelfordData_T reduce(WelfordData_T acc, T data) const { | |||||
| T delta = data - acc.mean; | |||||
| T new_mean = static_cast<T>(acc.mean + delta / (acc.count + 1)); | |||||
| T new_delta = static_cast<T>(data - new_mean); | |||||
| return { | |||||
| new_mean, | |||||
| acc.sigma2 + delta * new_delta, | |||||
| combine_t(acc.count + 1), | |||||
| }; | |||||
| } | |||||
| inline MEGDNN_DEVICE WelfordData_T | |||||
| combine(WelfordData_T lhs, WelfordData_T rhs) const { | |||||
| if (lhs.count != 0 && rhs.count != 0) { | |||||
| T delta = rhs.mean - lhs.mean; | |||||
| combine_t new_count = lhs.count + rhs.count; | |||||
| T nb_over_n = rhs.count / new_count; | |||||
| return {lhs.mean + delta * nb_over_n, | |||||
| lhs.sigma2 + rhs.sigma2 + delta * delta * lhs.count * nb_over_n, | |||||
| new_count}; | |||||
| } else { | |||||
| return (lhs.count != 0) ? lhs : rhs; | |||||
| } | |||||
| } | |||||
| inline MEGDNN_DEVICE res_t | |||||
| project(WelfordData_T acc) const __ubsan_ignore_float_divide_by_zero__ { | |||||
| const auto mean = static_cast<T>(acc.mean); | |||||
| const combine_t divisor = static_cast<combine_t>(acc.count); | |||||
| const auto var = acc.sigma2 / divisor; | |||||
| res_t results(var, mean); | |||||
| return results; | |||||
| } | |||||
| #if defined(__CUDACC__) || defined(__HIPCC__) | |||||
| inline MEGDNN_DEVICE WelfordData_T | |||||
| warp_shfl_down(WelfordData_T acc, int offset) const { | |||||
| return {__shfl_down(acc.mean, offset, warpSize), | |||||
| __shfl_down(acc.sigma2, offset, warpSize), | |||||
| __shfl_down(acc.count, offset, warpSize)}; | |||||
| } | |||||
| #endif | |||||
| MEGDNN_HOST MEGDNN_DEVICE WelfordOps() {} | |||||
| }; | |||||
| template <typename T, int vec_size> | |||||
| struct alignas(sizeof(T) * vec_size) aligned_vector { | |||||
| T val[vec_size]; | |||||
| }; | |||||
| template <typename T, bool is_cuda> | |||||
| using acc_type = T; | |||||
| template <typename U> | |||||
| MEGDNN_DEVICE WelfordStat | |||||
| update_welford_stat_online(const U val, const WelfordStat& curr_sum) { | |||||
| U delta = static_cast<U>(val - curr_sum.mean); | |||||
| U new_count = static_cast<U>(curr_sum.count + 1.f); | |||||
| U new_mean = static_cast<U>(curr_sum.mean + delta * (1.f / new_count)); | |||||
| return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; | |||||
| } | |||||
| MEGDNN_DEVICE WelfordStat | |||||
| combine_welford_stat(const WelfordStat lhs, const WelfordStat rhs) { | |||||
| using U = decltype(lhs.count); | |||||
| U delta = lhs.mean - rhs.mean; | |||||
| U count = rhs.count + lhs.count; | |||||
| U mean, sigma2; | |||||
| if (count > decltype(lhs.count){0}) { | |||||
| auto coef = 1.f / count; | |||||
| auto nA = rhs.count * coef; | |||||
| auto nB = lhs.count * coef; | |||||
| mean = nA * rhs.mean + nB * lhs.mean; | |||||
| sigma2 = rhs.sigma2 + lhs.sigma2 + delta * delta * rhs.count * nB; | |||||
| } else { | |||||
| mean = U(0); | |||||
| sigma2 = U(0); | |||||
| } | |||||
| return {mean, sigma2, count}; | |||||
| } | |||||
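The two helpers above are Welford's online update and the parallel (Chan et al.) combination of partial statistics. A host-side C++ sketch of the same recurrences, without the CUDA plumbing (illustrative only; names are not from the patch):

```cpp
#include <cstdio>
#include <vector>

struct Welford {
    float mean = 0.f, m2 = 0.f, count = 0.f;  // m2 = sum of squared deviations
};

// Online update: the same recurrence as update_welford_stat_online().
static Welford update(Welford s, float x) {
    float delta = x - s.mean;
    s.count += 1.f;
    s.mean += delta / s.count;
    s.m2 += delta * (x - s.mean);
    return s;
}

// Parallel combine: the same recurrence as combine_welford_stat().
static Welford combine(Welford a, Welford b) {
    Welford r;
    r.count = a.count + b.count;
    if (r.count == 0.f) return r;
    float delta = b.mean - a.mean;
    r.mean = (a.count * a.mean + b.count * b.mean) / r.count;
    r.m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / r.count;
    return r;
}

int main() {
    std::vector<float> xs = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
    Welford lo, hi;
    for (int i = 0; i < 3; ++i) lo = update(lo, xs[i]);  // first half
    for (int i = 3; i < 6; ++i) hi = update(hi, xs[i]);  // second half
    Welford all = combine(lo, hi);
    std::printf("mean=%.3f var=%.3f\n", all.mean, all.m2 / all.count);  // 3.500 2.917
}
```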
| template <typename T> | |||||
| MEGDNN_DEVICE WelfordStat | |||||
| compute_stats(const T* __restrict__ X, const int slice_len, float* buf) { | |||||
| using vec_t = aligned_vector<T, vec_size>; | |||||
| using acc_t = acc_type<T, true>; | |||||
| const vec_t* X_vec = reinterpret_cast<const vec_t*>(X); | |||||
| const int numx = blockDim.x * blockDim.y; | |||||
| const int thrx = threadIdx.x + threadIdx.y * blockDim.x; | |||||
| const int n_vec_to_read = slice_len / vec_size; | |||||
| WelfordStat w_stat(0.f, 0.f, 0.f); | |||||
| for (int i = thrx; i < n_vec_to_read; i += numx) { | |||||
| vec_t data = X_vec[i]; | |||||
| #pragma unroll | |||||
| for (int ii = 0; ii < vec_size; ii++) { | |||||
| w_stat = update_welford_stat_online( | |||||
| static_cast<acc_t>(data.val[ii]), w_stat); | |||||
| } | |||||
| } | |||||
| // intra-warp reduction | |||||
| #pragma unroll | |||||
| for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { | |||||
| WelfordStat w_tmp{ | |||||
| __shfl_down(w_stat.mean, offset, warpSize), | |||||
| __shfl_down(w_stat.sigma2, offset, warpSize), | |||||
| __shfl_down(w_stat.count, offset, warpSize)}; | |||||
| w_stat = combine_welford_stat(w_stat, w_tmp); | |||||
| } | |||||
| // threadIdx.x == 0 has correct values for each warp | |||||
| // inter-warp reductions | |||||
| if (blockDim.y > 1) { | |||||
| float* mean_sigma_buf = buf; | |||||
| float* count_buf = buf + blockDim.y; | |||||
| for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { | |||||
| // upper half of warps write to shared | |||||
| if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { | |||||
| const int wrt_y = threadIdx.y - offset; | |||||
| mean_sigma_buf[2 * wrt_y] = w_stat.mean; | |||||
| mean_sigma_buf[2 * wrt_y + 1] = w_stat.sigma2; | |||||
| count_buf[wrt_y] = w_stat.count; | |||||
| } | |||||
| __syncthreads(); | |||||
| // lower half merges | |||||
| if (threadIdx.x == 0 && threadIdx.y < offset) { | |||||
| WelfordStat w_tmp{ | |||||
| mean_sigma_buf[2 * threadIdx.y], | |||||
| mean_sigma_buf[2 * threadIdx.y + 1], count_buf[threadIdx.y]}; | |||||
| w_stat = combine_welford_stat(w_stat, w_tmp); | |||||
| } | |||||
| __syncthreads(); | |||||
| } | |||||
| if (threadIdx.x == 0 && threadIdx.y == 0) { | |||||
| mean_sigma_buf[0] = w_stat.mean; | |||||
| mean_sigma_buf[1] = w_stat.sigma2 / float(slice_len); | |||||
| } | |||||
| __syncthreads(); | |||||
| return WelfordStat{mean_sigma_buf[0], mean_sigma_buf[1], 0.f}; | |||||
| } else { | |||||
| return WelfordStat{ | |||||
| __shfl(w_stat.mean, 0, warpSize), | |||||
| __shfl(w_stat.sigma2, 0, warpSize) / float(slice_len), 0.f}; | |||||
| } | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| __global__ void vectorized_layer_norm_forward_affine_kernel( | |||||
| const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight, | |||||
| const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) { | |||||
| // if we made smem WelfordStat type, there would be bank conflicts, | |||||
| // as one thread would have to write 3 consecutive floats | |||||
| extern __shared__ float s_data[]; | |||||
| auto slice_id = blockIdx.x; | |||||
| const T* slice = X + slice_id * slice_len; | |||||
| WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data); | |||||
| using vec_t = aligned_vector<T, vec_size>; | |||||
| const vec_t* X_vec = reinterpret_cast<const vec_t*>(slice); | |||||
| vec_t* Y_vec = reinterpret_cast<vec_t*>(Y + slice_id * slice_len); | |||||
| const int numx = blockDim.x * blockDim.y; | |||||
| const int thrx = threadIdx.x + threadIdx.y * blockDim.x; | |||||
| const int n_vec_to_read = slice_len / vec_size; | |||||
| T_ACC rstd_val = static_cast<T_ACC>(rsqrt(slice_w_stat.sigma2 + eps)); | |||||
| for (int i = thrx; i < n_vec_to_read; i += numx) { | |||||
| vec_t data = X_vec[i]; | |||||
| vec_t out; | |||||
| // computation is performed in T_ACC, X is cast to T_ACC and result is | |||||
| // implicitly cast to T | |||||
| #pragma unroll | |||||
| for (int ii = 0; ii < vec_size; ii++) { | |||||
| out.val[ii] = static_cast<T_ACC>(weight[i * vec_size + ii]) * | |||||
| (rstd_val * (static_cast<T_ACC>(data.val[ii]) - | |||||
| slice_w_stat.mean)) + | |||||
| static_cast<T_ACC>(bias[i * vec_size + ii]); | |||||
| } | |||||
| Y_vec[i] = out; | |||||
| } | |||||
| if (thrx == 0) { | |||||
| mean[slice_id] = slice_w_stat.mean; | |||||
| rstd[slice_id] = rstd_val; | |||||
| } | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| __global__ void vectorized_layer_norm_forward_kernel( | |||||
| const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight, | |||||
| const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) { | |||||
| extern __shared__ float s_data[]; | |||||
| auto slice_id = blockIdx.x; | |||||
| const T* slice = X + slice_id * slice_len; | |||||
| WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data); | |||||
| using vec_t = aligned_vector<T, vec_size>; | |||||
| const vec_t* X_vec = reinterpret_cast<const vec_t*>(slice); | |||||
| vec_t* Y_vec = reinterpret_cast<vec_t*>(Y + slice_id * slice_len); | |||||
| const int numx = blockDim.x * blockDim.y; | |||||
| const int thrx = threadIdx.x + threadIdx.y * blockDim.x; | |||||
| const int n_vec_to_read = slice_len / vec_size; | |||||
| T_ACC rstd_val = static_cast<T_ACC>(rsqrt(slice_w_stat.sigma2 + eps)); | |||||
| for (int i = thrx; i < n_vec_to_read; i += numx) { | |||||
| vec_t data = X_vec[i]; | |||||
| vec_t out; | |||||
| #pragma unroll | |||||
| for (int ii = 0; ii < vec_size; ii++) { | |||||
| out.val[ii] = | |||||
| rstd_val * (static_cast<T_ACC>(data.val[ii]) - slice_w_stat.mean); | |||||
| } | |||||
| Y_vec[i] = out; | |||||
| } | |||||
| if (thrx == 0) { | |||||
| mean[slice_id] = slice_w_stat.mean; | |||||
| rstd[slice_id] = rstd_val; | |||||
| } | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| void launch_vectorized_layer_norm_forward_kernel( | |||||
| int64_t slice_len, int64_t slice_num, T_ACC eps, const T* X_data, | |||||
| const T* weight_data, const T* bias_data, T* Y_data, T_ACC* mean_data, | |||||
| T_ACC* rstd_data, cudaStream_t stream) { | |||||
| const int num_threads = 128; | |||||
| const dim3 threads(WARP_SIZE, num_threads / WARP_SIZE, 1); | |||||
| const dim3 blocks(slice_num); | |||||
| int nshared = threads.y > 1 ? threads.y * 3 / 2 * sizeof(T_ACC) : 0; | |||||
| if (weight_data == nullptr && bias_data == nullptr) { | |||||
| vectorized_layer_norm_forward_kernel<<<blocks, threads, nshared, stream>>>( | |||||
| slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data, | |||||
| Y_data); | |||||
| } else { | |||||
| vectorized_layer_norm_forward_affine_kernel<<< | |||||
| blocks, threads, nshared, stream>>>( | |||||
| slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data, | |||||
| Y_data); | |||||
| } | |||||
| after_kernel_launch(); | |||||
| } | |||||
| template <typename T, class ReduceOp> | |||||
| __inline__ MEGDNN_DEVICE T welford_warp_reduce(T val, const ReduceOp& op) { | |||||
| #pragma unroll | |||||
| for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { | |||||
| val = op.combine(val, op.warp_shfl_down(val, offset)); | |||||
| } | |||||
| return val; | |||||
| } | |||||
| template <typename T, class ReduceOp> | |||||
| __inline__ MEGDNN_DEVICE T | |||||
| welford_block_reduce(T val, const ReduceOp& op, const T& identity_element, T* shared) { | |||||
| const int lid = threadIdx.x % warpSize; | |||||
| const int wid = threadIdx.x / warpSize; | |||||
| val = welford_warp_reduce(val, op); | |||||
| __syncthreads(); | |||||
| if (lid == 0) { | |||||
| shared[wid] = val; | |||||
| } | |||||
| __syncthreads(); | |||||
| val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : identity_element; | |||||
| if (wid == 0) { | |||||
| val = welford_warp_reduce(val, op); | |||||
| } | |||||
| return val; | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| __global__ void get_input_mean_and_rstd_kernel( | |||||
| int64_t slice_len, T_ACC eps, const T* X, T_ACC* mean, T_ACC* rstd) { | |||||
| using WelfordType = WelfordData<T_ACC, T_ACC>; | |||||
| using WelfordOp = WelfordOps<T_ACC, T_ACC, thrust::pair<T_ACC, T_ACC>>; | |||||
| __shared__ typename std::aligned_storage< | |||||
| sizeof(WelfordType), alignof(WelfordType)>::type val_shared[WARP_SIZE]; | |||||
| WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared); | |||||
| const int64_t i = blockIdx.x; | |||||
| WelfordOp welford_op; | |||||
| WelfordType val( | |||||
| static_cast<T_ACC>(0), static_cast<T_ACC>(0), static_cast<T_ACC>(0)); | |||||
| for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) { | |||||
| const int64_t index = i * slice_len + j; | |||||
| val = welford_op.reduce(val, static_cast<T_ACC>(X[index])); | |||||
| } | |||||
| val = welford_block_reduce( | |||||
| val, welford_op, | |||||
| WelfordType( | |||||
| static_cast<T_ACC>(0), static_cast<T_ACC>(0), | |||||
| static_cast<T_ACC>(0)), | |||||
| val_shared_ptr); | |||||
| if (threadIdx.x == 0) { | |||||
| T_ACC slice_mean; | |||||
| T_ACC slice_sigma2; | |||||
| thrust::tie(slice_sigma2, slice_mean) = welford_op.project(val); | |||||
| mean[i] = slice_mean; | |||||
| rstd[i] = rsqrt(slice_sigma2 + eps); | |||||
| } | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| __global__ void layer_norm_forward_kernel( | |||||
| int64_t slice_len, const T* X, const T_ACC* mean, const T_ACC* rstd, | |||||
| const T* weight, const T* bias, T* Y) { | |||||
| const int64_t i = blockIdx.x; | |||||
| for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) { | |||||
| const int64_t index = i * slice_len + j; | |||||
| const T_ACC weight_v = | |||||
| weight == nullptr ? T_ACC(1) : static_cast<T_ACC>(weight[j]); | |||||
| const T_ACC bias_v = bias == nullptr ? T_ACC(0) : static_cast<T_ACC>(bias[j]); | |||||
| Y[index] = (static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) * | |||||
| static_cast<T_ACC>(rstd[i]) * weight_v + | |||||
| bias_v; | |||||
| } | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| void forward( | |||||
| T* X, T* weight, T* bias, int64_t slice_num, int64_t slice_len, T_ACC eps, T* Y, | |||||
| T_ACC* mean, T_ACC* rstd, cudaStream_t stream) { | |||||
| auto can_vectorize = [&](const T* ptr, int alignment) { | |||||
| uint64_t addr = reinterpret_cast<uint64_t>(ptr); | |||||
| return addr % alignment == 0; | |||||
| }; | |||||
| constexpr int num_vec_elems = vec_size; | |||||
| constexpr int alignment = num_vec_elems * sizeof(T); | |||||
| if ((std::is_same<T, dt_float32>::value || std::is_same<T, dt_float16>::value || | |||||
| std::is_same<T, dt_bfloat16>::value) && | |||||
| slice_len <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && | |||||
| slice_len % num_vec_elems == 0 && can_vectorize(X, alignment) && | |||||
| can_vectorize(Y, alignment)) { | |||||
| launch_vectorized_layer_norm_forward_kernel<T, T_ACC>( | |||||
| slice_len, slice_num, static_cast<T_ACC>(eps), X, weight, bias, Y, mean, | |||||
| rstd, stream); | |||||
| after_kernel_launch(); | |||||
| } else { | |||||
| get_input_mean_and_rstd_kernel<T, T_ACC> | |||||
| <<<slice_num, 512, 0, stream>>>(slice_len, eps, X, mean, rstd); | |||||
| after_kernel_launch(); | |||||
| layer_norm_forward_kernel<T, T_ACC><<<slice_num, kCUDANumThreads, 0, stream>>>( | |||||
| slice_len, X, mean, rstd, weight, bias, Y); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| } | |||||
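The dispatch above takes the vectorized path only under specific conditions; they are restated here as a standalone helper for clarity (a sketch only: the helper name is not from the patch, and the fp32/fp16/bf16 dtype check is omitted). The 2^24 bound presumably keeps the per-slice element count within the range where float32 accumulation is exact.

```cpp
#include <cstdint>
#include <limits>

// Conditions for selecting the vectorized forward kernel:
//  * input/output pointers aligned for 4-wide aligned_vector accesses,
//  * slice length divisible by the vector width,
//  * slice length no larger than 2^24 (std::numeric_limits<float>::digits == 24),
//    the largest integer range float32 represents exactly.
template <typename T>
bool can_use_vectorized_path(const T* x, const T* y, int64_t slice_len) {
    constexpr int vec_size = 4;
    constexpr int alignment = vec_size * sizeof(T);
    auto aligned = [](const void* p, int a) {
        return reinterpret_cast<uintptr_t>(p) % a == 0;
    };
    return slice_len % vec_size == 0 &&
           slice_len <= (int64_t{1} << std::numeric_limits<float>::digits) &&
           aligned(x, alignment) && aligned(y, alignment);
}
```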
| template <typename T> | |||||
| __inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) { | |||||
| #pragma unroll | |||||
| for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { | |||||
| val += __shfl_down(val, offset, warpSize); | |||||
| } | |||||
| return val; | |||||
| } | |||||
| template <typename T> | |||||
| __inline__ MEGDNN_DEVICE T block_reduce_sum(T val, T* shared) { | |||||
| const int lid = threadIdx.x % warpSize; | |||||
| const int wid = threadIdx.x / warpSize; | |||||
| val = warp_reduce_sum(val); | |||||
| __syncthreads(); | |||||
| if (lid == 0) { | |||||
| shared[wid] = val; | |||||
| } | |||||
| __syncthreads(); | |||||
| val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0); | |||||
| if (wid == 0) { | |||||
| val = warp_reduce_sum(val); | |||||
| } | |||||
| return val; | |||||
| } | |||||
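warp_reduce_sum/block_reduce_sum form the usual two-level tree reduction: a shuffle-down reduction within each 32-lane warp, then one warp reduces the per-warp partials staged in shared memory. A host-side simulation of the shuffle-down stage (illustration only; real warps execute all lanes in lock-step, and the simulation names are not from the patch):

```cpp
#include <cstdio>
#include <vector>

// Simulate a __shfl_down-based warp sum on 32 lane values: after the loop,
// lane 0 holds the sum of all lanes (other lanes hold partial garbage,
// exactly as in the device version).
static float simulate_warp_reduce_sum(std::vector<float> lanes) {
    const int warp_size = 32;
    for (int offset = warp_size / 2; offset > 0; offset >>= 1) {
        for (int lane = 0; lane < warp_size; ++lane) {
            int src = lane + offset;
            if (src < warp_size) lanes[lane] += lanes[src];  // val += __shfl_down(val, offset)
        }
    }
    return lanes[0];
}

int main() {
    std::vector<float> lanes(32, 1.f);  // each lane contributes 1
    std::printf("%.1f\n", simulate_warp_reduce_sum(lanes));  // 32.0
}
```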
| template <typename T, typename T_ACC> | |||||
| __inline__ MEGDNN_DEVICE void layer_norm_grad_input_kernel_impl( | |||||
| const T* __restrict__ dY, const T* __restrict__ X, | |||||
| const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd, | |||||
| const T* __restrict__ weight, T* dX, const int slice_len, T_ACC* buf) { | |||||
| const auto slice_id = blockIdx.x; | |||||
| const T_ACC mean_val = mean[slice_id]; | |||||
| const T_ACC rstd_val = rstd[slice_id]; | |||||
| T_ACC stats_x1{0}, stats_x2{0}; | |||||
| constexpr int unroll = 4; | |||||
| auto l = unroll * threadIdx.x; | |||||
| const T* X_i = X + slice_id * slice_len; | |||||
| const T* dY_i = dY + slice_id * slice_len; | |||||
| T* dX_i = dX + slice_id * slice_len; | |||||
| // vectorized reads don't improve perf, so use regular unrolling | |||||
| for (; l + unroll - 1 < slice_len; l += blockDim.x * unroll) { | |||||
| #pragma unroll | |||||
| for (int k = 0; k < unroll; k++) { | |||||
| T_ACC weight_val = | |||||
| (weight != nullptr) ? static_cast<T_ACC>(weight[l + k]) : T_ACC(1); | |||||
| const T_ACC c_h = static_cast<T_ACC>(X_i[l + k]); | |||||
| const T_ACC c_loss = static_cast<T_ACC>(dY_i[l + k]); | |||||
| stats_x1 += c_loss * weight_val; | |||||
| stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val; | |||||
| } | |||||
| } | |||||
| for (; l < slice_len; l++) { | |||||
| T_ACC weight_val = | |||||
| (weight != nullptr) ? static_cast<T_ACC>(weight[l]) : T_ACC(1); | |||||
| const T_ACC c_h = static_cast<T_ACC>(X_i[l]); | |||||
| const T_ACC c_loss = static_cast<T_ACC>(dY_i[l]); | |||||
| stats_x1 += c_loss * weight_val; | |||||
| stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val; | |||||
| } | |||||
| stats_x1 = block_reduce_sum(stats_x1, buf); | |||||
| stats_x2 = block_reduce_sum(stats_x2, buf); | |||||
| if (threadIdx.x == 0) { | |||||
| buf[0] = stats_x1; | |||||
| buf[1] = stats_x2; | |||||
| } | |||||
| __syncthreads(); | |||||
| stats_x1 = buf[0]; | |||||
| stats_x2 = buf[1]; | |||||
| T_ACC fH = slice_len; | |||||
| T_ACC term1 = (T_ACC(1) / fH) * rstd_val; | |||||
| for (int l = threadIdx.x; l < slice_len; l += blockDim.x) { | |||||
| const T_ACC x = X_i[l]; | |||||
| const T_ACC dy = dY_i[l]; | |||||
| T_ACC weight_val = | |||||
| (weight != nullptr) ? static_cast<T_ACC>(weight[l]) : T_ACC(1); | |||||
| T_ACC f_grad_input = fH * weight_val * dy; | |||||
| f_grad_input -= (x - mean_val) * rstd_val * stats_x2; | |||||
| f_grad_input -= stats_x1; | |||||
| f_grad_input *= term1; | |||||
| dX_i[l] = f_grad_input; | |||||
| } | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| __global__ void layer_norm_grad_input_kernel( | |||||
| const T* __restrict__ dY, const T* __restrict__ X, | |||||
| const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd, | |||||
| const T* __restrict__ weight, T* dX, const int slice_len) { | |||||
| alignas(sizeof(double)) extern __shared__ char s_data1[]; | |||||
| T_ACC* buf = reinterpret_cast<T_ACC*>(&s_data1); | |||||
| layer_norm_grad_input_kernel_impl(dY, X, mean, rstd, weight, dX, slice_len, buf); | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| __global__ void layer_norm_grad_weight_bias_simple_kernel( | |||||
| int64_t slice_num, int64_t slice_len, const T* dY, const T* X, | |||||
| const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) { | |||||
| const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; | |||||
| if (j < slice_len) { | |||||
| T_ACC sum1 = 0; | |||||
| T_ACC sum2 = 0; | |||||
| for (int64_t i = 0; i < slice_num; ++i) { | |||||
| const int64_t index = i * slice_len + j; | |||||
| sum1 += dweight == nullptr ? T_ACC(0) | |||||
| : static_cast<T_ACC>(dY[index]) * | |||||
| (static_cast<T_ACC>(X[index]) - | |||||
| static_cast<T_ACC>(mean[i])) * | |||||
| static_cast<T_ACC>(rstd[i]); | |||||
| sum2 += dbias == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index]); | |||||
| } | |||||
| if (dweight != nullptr) { | |||||
| dweight[j] = sum1; | |||||
| } | |||||
| if (dbias != nullptr) { | |||||
| dbias[j] = sum2; | |||||
| } | |||||
| } | |||||
| } | |||||
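The per-column sums computed by this kernel are the standard parameter gradients (restated here, with i indexing slices and j indexing positions inside a slice):

```latex
\frac{\partial L}{\partial \gamma_j}
    = \sum_i \frac{\partial L}{\partial y_{ij}}\,(x_{ij}-\mu_i)\,\mathrm{rstd}_i,
\qquad
\frac{\partial L}{\partial \beta_j}
    = \sum_i \frac{\partial L}{\partial y_{ij}}
```

The simple kernel assigns one thread per column j and loops over all slices; the tiled kernel that follows blocks the loop over i and stages mean/rstd in shared memory for large slice_num.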
| template <typename T, typename T_ACC> | |||||
| __global__ void layer_norm_grad_weight_bias_kernel( | |||||
| int64_t slice_num, int64_t slice_len, const T* dY, const T* X, | |||||
| const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) { | |||||
| alignas(sizeof(double)) extern __shared__ char s_data1[]; | |||||
| T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1); | |||||
| const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; | |||||
| constexpr int unroll = 8; | |||||
| T dYs[unroll]; | |||||
| T Xs[unroll]; | |||||
| T_ACC* means = s_data_typed; | |||||
| T_ACC* rstds = s_data_typed + unroll * blockDim.y; | |||||
| T_ACC dg_sum = 0; | |||||
| T_ACC db_sum = 0; | |||||
| if (j < slice_len) { | |||||
| int bcounter; | |||||
| for (bcounter = 0; bcounter < slice_num / (blockDim.y * unroll); bcounter++) { | |||||
| int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; | |||||
| #pragma unroll | |||||
| for (int ii = 0; ii < unroll; ii++) { | |||||
| if (threadIdx.x == 0) { | |||||
| means[ii * blockDim.y + threadIdx.y] = mean[offset + ii]; | |||||
| rstds[ii * blockDim.y + threadIdx.y] = rstd[offset + ii]; | |||||
| } | |||||
| dYs[ii] = dY[(offset + ii) * slice_len + j]; | |||||
| Xs[ii] = X[(offset + ii) * slice_len + j]; | |||||
| } | |||||
| __syncthreads(); | |||||
| #pragma unroll | |||||
| for (int ii = 0; ii < unroll; ii++) { | |||||
| dg_sum += dYs[ii] * (Xs[ii] - means[ii * blockDim.y + threadIdx.y]) * | |||||
| rstds[ii * blockDim.y + threadIdx.y]; | |||||
| db_sum += dYs[ii]; | |||||
| } | |||||
| __syncthreads(); | |||||
| } | |||||
| int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; | |||||
| for (int ii = 0; ii < 8; ii++) { | |||||
| T_ACC mean_val, rstd_val; // we don't use smem in the tail to avoid awkward | |||||
| // synchronizations, perf penalty is negligible | |||||
| if ((offset + ii) < slice_num) { | |||||
| mean_val = mean[offset + ii]; | |||||
| rstd_val = rstd[offset + ii]; | |||||
| dYs[0] = dY[(offset + ii) * slice_len + j]; | |||||
| Xs[0] = X[(offset + ii) * slice_len + j]; | |||||
| dg_sum += dYs[0] * (Xs[0] - mean_val) * rstd_val; | |||||
| db_sum += dYs[0]; | |||||
| } | |||||
| } | |||||
| s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; | |||||
| s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] = | |||||
| db_sum; | |||||
| __syncthreads(); | |||||
| for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { | |||||
| if (threadIdx.y < offset) { | |||||
| s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] += | |||||
| s_data_typed[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; | |||||
| s_data_typed | |||||
| [blockDim.x * blockDim.y + threadIdx.y * blockDim.x + | |||||
| threadIdx.x] += s_data_typed | |||||
| [blockDim.x * blockDim.y + | |||||
| (threadIdx.y + offset) * blockDim.x + threadIdx.x]; | |||||
| } | |||||
| __syncthreads(); | |||||
| } | |||||
| if (threadIdx.y == 0) { | |||||
| if (dweight) { | |||||
| dweight[j] = s_data_typed[threadIdx.x]; | |||||
| } | |||||
| if (dbias) { | |||||
| dbias[j] = s_data_typed[threadIdx.x + blockDim.x * blockDim.y]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T, typename T_ACC> | |||||
| void backward( | |||||
| const T* dY_data, const T* X_data, const T_ACC* mean_data, | |||||
| const T_ACC* rstd_data, const T* weight_data, int64_t slice_num, | |||||
| int64_t slice_len, T* dX_data, T* dweight_data, T* dbias_data, | |||||
| cudaStream_t stream) { | |||||
| if (dX_data != nullptr) { | |||||
| const int num_threads = 128; | |||||
| const dim3 blocks(slice_num); | |||||
| int nshared = (num_threads / WARP_SIZE) * sizeof(T_ACC); | |||||
| layer_norm_grad_input_kernel<<<blocks, num_threads, nshared, stream>>>( | |||||
| dY_data, X_data, mean_data, rstd_data, weight_data, dX_data, slice_len); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| if (dweight_data || dbias_data) { | |||||
| if (slice_num < 512) { | |||||
| const int64_t B = (slice_len + kCUDANumThreads - 1) / kCUDANumThreads; | |||||
| layer_norm_grad_weight_bias_simple_kernel<T, T_ACC> | |||||
| <<<B, kCUDANumThreads, 0, stream>>>( | |||||
| slice_num, slice_len, dY_data, X_data, mean_data, rstd_data, | |||||
| dweight_data, dbias_data); | |||||
| after_kernel_launch(); | |||||
| } else { | |||||
| dim3 threads{16, 32}; | |||||
| int blocks = (slice_len + threads.x - 1) / threads.x; | |||||
| layer_norm_grad_weight_bias_kernel<T, T_ACC> | |||||
| <<<blocks, threads, 2 * sizeof(T_ACC) * threads.x * threads.y, | |||||
| stream>>>( | |||||
| slice_num, slice_len, dY_data, X_data, mean_data, rstd_data, | |||||
| dweight_data, dbias_data); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| } | |||||
| } | |||||
| #define INST(T, T_ACC) \ | |||||
| template void forward<T, T_ACC>( \ | |||||
| T*, T*, T*, int64_t, int64_t, T_ACC, T*, T_ACC*, T_ACC*, cudaStream_t); \ | |||||
| template void backward<T, T_ACC>( \ | |||||
| const T*, const T*, const T_ACC*, const T_ACC*, const T*, int64_t, \ | |||||
| int64_t, T*, T*, T*, cudaStream_t); | |||||
| INST(dt_float32, dt_float32) | |||||
| INST(dt_float16, dt_float32) | |||||
| INST(dt_bfloat16, dt_float32) | |||||
| #undef INST | |||||
| } // namespace layer_norm | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/layer_norm/layer_norm.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <cuda_runtime_api.h> | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace layer_norm { | |||||
| template <typename T, typename T_ACC> | |||||
| void forward( | |||||
| T* X, T* gamma, T* beta, int64_t M, int64_t N, T_ACC eps, T* Y, T_ACC* mean, | |||||
| T_ACC* rstd, cudaStream_t stream); | |||||
| template <typename T, typename T_ACC> | |||||
| void backward( | |||||
| const T* dY_data, const T* X_data, const T_ACC* mean_data, | |||||
| const T_ACC* rstd_data, const T* gamma_data, int64_t M, int64_t N, T* dX_data, | |||||
| T* dgamma_data, T* dbeta_data, cudaStream_t stream); | |||||
| } // namespace layer_norm | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,94 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/layer_norm/opr_impl.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/layer_norm/opr_impl.h" | |||||
| #include "src/cuda/layer_norm/layer_norm_cuda.cuh" | |||||
| #include "src/cuda/utils.h" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| void LayerNormForwardImpl::exec( | |||||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec( | |||||
| data.layout, weight.layout, bias.layout, dst.layout, mean.layout, | |||||
| rstd.layout, workspace.size); | |||||
| auto p = param(); | |||||
| float eps = p.eps; | |||||
| bool affine = p.affine; | |||||
| uint64_t slice_length = p.normalized_size; | |||||
| uint64_t slice_dim = p.normalized_dim; | |||||
| uint64_t n_slices = 1; | |||||
| for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { | |||||
| n_slices = n_slices * data.layout.shape[i]; | |||||
| } | |||||
| auto stream = cuda_stream(handle()); | |||||
| using namespace ::megdnn::cuda::layer_norm; | |||||
| #define cb(DType) \ | |||||
| if (data.layout.dtype == DType()) { \ | |||||
| using T = typename DTypeTrait<DType>::ctype; \ | |||||
| using T_ACC = float; \ | |||||
| forward<T, T_ACC>( \ | |||||
| data.ptr<T>(), affine ? weight.ptr<T>() : nullptr, \ | |||||
| affine ? bias.ptr<T>() : nullptr, static_cast<int64_t>(n_slices), \ | |||||
| static_cast<int64_t>(slice_length), static_cast<T_ACC>(eps), \ | |||||
| dst.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), stream); \ | |||||
| return; \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| megdnn_throw("bad dtype"); | |||||
| } | |||||
| void LayerNormBackwardImpl::exec( | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec( | |||||
| diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, | |||||
| ddata.layout, dweight.layout, dbias.layout, workspace.size); | |||||
| auto p = param(); | |||||
| bool affine = p.affine; | |||||
| uint64_t slice_length = p.normalized_size; | |||||
| uint64_t slice_dim = p.normalized_dim; | |||||
| uint64_t n_slices = 1; | |||||
| for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { | |||||
| n_slices = n_slices * data.layout.shape[i]; | |||||
| } | |||||
| auto stream = cuda_stream(handle()); | |||||
| using namespace ::megdnn::cuda::layer_norm; | |||||
| #define cb(DType) \ | |||||
| if (data.layout.dtype == DType()) { \ | |||||
| using T = typename DTypeTrait<DType>::ctype; \ | |||||
| using T_ACC = float; \ | |||||
| backward<T, T_ACC>( \ | |||||
| diff.ptr<T>(), data.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), \ | |||||
| affine ? weight.ptr<T>() : nullptr, n_slices, slice_length, \ | |||||
| ddata.ptr<T>(), affine ? dweight.ptr<T>() : nullptr, \ | |||||
| affine ? dbias.ptr<T>() : nullptr, stream); \ | |||||
| return; \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| megdnn_throw("bad dtype"); | |||||
| } | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/layer_norm/opr_impl.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/cuda/cudnn_wrapper.h" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| class LayerNormForwardImpl final : public LayerNormForward { | |||||
| public: | |||||
| using LayerNormForward::LayerNormForward; | |||||
| void exec( | |||||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, const TensorLayout&, const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class LayerNormBackwardImpl final : public LayerNormBackward { | |||||
| public: | |||||
| using LayerNormBackward::LayerNormBackward; | |||||
| void exec( | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -47,6 +47,7 @@ | |||||
| #include "src/naive/images2neibs/opr_impl.h" | #include "src/naive/images2neibs/opr_impl.h" | ||||
| #include "src/naive/indexing_multi_axis_vec/opr_impl.h" | #include "src/naive/indexing_multi_axis_vec/opr_impl.h" | ||||
| #include "src/naive/indexing_one_hot/opr_impl.h" | #include "src/naive/indexing_one_hot/opr_impl.h" | ||||
| #include "src/naive/layer_norm/opr_impl.h" | |||||
| #include "src/naive/linspace/opr_impl.h" | #include "src/naive/linspace/opr_impl.h" | ||||
| #include "src/naive/local/opr_impl.h" | #include "src/naive/local/opr_impl.h" | ||||
| #include "src/naive/local_share/opr_impl.h" | #include "src/naive/local_share/opr_impl.h" | ||||
| @@ -0,0 +1,170 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/layer_norm/opr_impl.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/naive/layer_norm/opr_impl.h" | |||||
| #include <algorithm> | |||||
| #include "src/common/utils.h" | |||||
| #include "src/naive/handle.h" | |||||
| using namespace megdnn; | |||||
| using namespace naive; | |||||
| namespace { | |||||
| using Param = megdnn::LayerNorm::Param; | |||||
| template <typename T, typename T_ACC = float> | |||||
| void forward( | |||||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
| const Param& param) { | |||||
| float eps = param.eps; | |||||
| bool affine = param.affine; | |||||
| uint64_t slice_length = param.normalized_size; | |||||
| uint64_t slice_dim = param.normalized_dim; | |||||
| uint64_t n_slices = 1; | |||||
| for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { | |||||
| n_slices = n_slices * data.layout.shape[i]; | |||||
| } | |||||
| for (size_t i = 0; i < n_slices; i++) { | |||||
| T_ACC slice_sum = static_cast<T>(0.0f); | |||||
| for (size_t j = 0; j < slice_length; j++) { | |||||
| auto value = data.ptr<T>()[i * slice_length + j]; | |||||
| slice_sum += value; | |||||
| } | |||||
| T_ACC slice_mean = static_cast<T_ACC>(slice_sum / slice_length); | |||||
| T_ACC slice_var = static_cast<T>(0.0f); | |||||
| for (size_t j = 0; j < slice_length; j++) { | |||||
| slice_var += (data.ptr<T>()[i * slice_length + j] - slice_mean) * | |||||
| (data.ptr<T>()[i * slice_length + j] - slice_mean); | |||||
| } | |||||
| slice_var = slice_var / slice_length; | |||||
| T_ACC slice_std = static_cast<T_ACC>(sqrt(slice_var + eps)); | |||||
| for (size_t j = 0; j < slice_length; j++) { | |||||
| dst.ptr<T>()[i * slice_length + j] = | |||||
| (data.ptr<T>()[i * slice_length + j] - slice_mean) / slice_std; | |||||
| if (affine) { | |||||
| dst.ptr<T>()[i * slice_length + j] = | |||||
| dst.ptr<T>()[i * slice_length + j] * weight.ptr<T>()[j] + | |||||
| bias.ptr<T>()[j]; | |||||
| } | |||||
| } | |||||
| mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean); | |||||
| rstd.ptr<T_ACC>()[i] = static_cast<T_ACC>(1.0 / slice_std); | |||||
| } | |||||
| } | |||||
| template <typename T, typename T_ACC = float> | |||||
| void backward( | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param) { | |||||
| bool affine = param.affine; | |||||
| uint64_t slice_length = param.normalized_size; | |||||
| uint64_t slice_dim = param.normalized_dim; | |||||
| uint64_t n_slices = 1; | |||||
| for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { | |||||
| n_slices = n_slices * data.layout.shape[i]; | |||||
| } | |||||
| if (affine) { | |||||
| for (size_t i = 0; i < slice_length; ++i) { | |||||
| dweight.ptr<T>()[i] = 0; | |||||
| dbias.ptr<T>()[i] = 0; | |||||
| } | |||||
| for (size_t i = 0; i < n_slices; ++i) { | |||||
| for (size_t j = 0; j < slice_length; ++j) { | |||||
| dweight.ptr<T>()[j] += | |||||
| (data.ptr<T>()[i * slice_length + j] - mean.ptr<T_ACC>()[i]) * | |||||
| rstd.ptr<T_ACC>()[i] * diff.ptr<T>()[i * slice_length + j]; | |||||
| dbias.ptr<T>()[j] += diff.ptr<T>()[i * slice_length + j]; | |||||
| } | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < n_slices; ++i) { | |||||
| T_ACC ds = static_cast<T_ACC>(0.0f); | |||||
| T_ACC db = static_cast<T_ACC>(0.0f); | |||||
| T_ACC a = static_cast<T_ACC>(0.0f); | |||||
| T_ACC b = static_cast<T_ACC>(0.0f); | |||||
| T_ACC c = static_cast<T_ACC>(0.0f); | |||||
| for (size_t j = 0; j < slice_length; ++j) { | |||||
| auto value = data.ptr<T>()[i * slice_length + j]; | |||||
| auto diff_v = diff.ptr<T>()[i * slice_length + j]; | |||||
| auto weight_v = affine ? weight.ptr<T>()[j] : static_cast<T>(1.0f); | |||||
| db += diff_v * weight_v; | |||||
| ds += diff_v * value * weight_v; | |||||
| } | |||||
| a = rstd.ptr<T_ACC>()[i]; | |||||
| b = (db * mean.ptr<T_ACC>()[i] - ds) * a * a * a / slice_length; | |||||
| c = -b * mean.ptr<T_ACC>()[i] - db * a / slice_length; | |||||
| for (uint64_t j = 0; j < slice_length; j++) { | |||||
| auto weight_v = affine ? weight.ptr<T>()[j] : static_cast<T>(1.0f); | |||||
| ddata.ptr<T>()[i * slice_length + j] = | |||||
| diff.ptr<T>()[i * slice_length + j] * a * weight_v + | |||||
| data.ptr<T>()[i * slice_length + j] * b + c; | |||||
| } | |||||
| } | |||||
| } | |||||
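The coefficients a, b, c in the loop above follow from differentiating the forward expression; written out (a restatement of the standard derivation, with H = normalized_size, w = weight, and sums over one slice):

```latex
a = \mathrm{rstd},\qquad
db = \sum_j \frac{\partial L}{\partial y_j}\,w_j,\qquad
ds = \sum_j \frac{\partial L}{\partial y_j}\,w_j\,x_j,
\\[4pt]
b = \frac{(db\,\mu - ds)\,a^3}{H},\qquad
c = -b\,\mu - \frac{db\,a}{H},\qquad
\frac{\partial L}{\partial x_j}
    = \frac{\partial L}{\partial y_j}\,w_j\,a + x_j\,b + c
    = \frac{\mathrm{rstd}}{H}\Bigl(H\,\frac{\partial L}{\partial y_j}\,w_j - db
      - (x_j-\mu)\,\mathrm{rstd}^2\,(ds-\mu\,db)\Bigr)
```

The last form is what the CUDA layer_norm_grad_input_kernel_impl computes, with stats_x1 = db and stats_x2 = (ds − μ·db)·rstd.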
| } // namespace | |||||
| namespace megdnn { | |||||
| namespace naive { | |||||
| void LayerNormForwardImpl::exec( | |||||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec( | |||||
| data.layout, weight.layout, bias.layout, dst.layout, mean.layout, | |||||
| rstd.layout, workspace.size); | |||||
| #define cb(DType) \ | |||||
| if (data.layout.dtype == DType()) { \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(forward<typename DTypeTrait<DType>::ctype>( \ | |||||
| data, weight, bias, dst, mean, rstd, param())); \ | |||||
| return; \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| megdnn_throw("bad dtype"); | |||||
| } | |||||
| void LayerNormBackwardImpl::exec( | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec( | |||||
| diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, | |||||
| ddata.layout, dweight.layout, dbias.layout, workspace.size); | |||||
| #define cb(DType) \ | |||||
| if (data.layout.dtype == DType()) { \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(backward<typename DTypeTrait<DType>::ctype>( \ | |||||
| diff, data, weight, mean, rstd, ddata, dweight, dbias, param())); \ | |||||
| return; \ | |||||
| } | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
| #undef cb | |||||
| megdnn_throw("bad dtype"); | |||||
| } | |||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/layer_norm/opr_impl.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs.h" | |||||
| namespace megdnn { | |||||
| namespace naive { | |||||
| class LayerNormForwardImpl final : public LayerNormForward { | |||||
| public: | |||||
| using LayerNormForward::LayerNormForward; | |||||
| void exec( | |||||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, const TensorLayout&, const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class LayerNormBackwardImpl final : public LayerNormBackward { | |||||
| public: | |||||
| using LayerNormBackward::LayerNormBackward; | |||||
| void exec( | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&, const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
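Both naive implementations report a zero workspace requirement, so the operators can be driven with an empty workspace. A minimal usage sketch, assuming a megdnn `handle` and pre-filled `TensorND`s named `data`, `weight`, `bias`, `dst`, `mean`, `rstd` (plus a `slice_len`), set up along the lines of the AutoOprChecker test further below:

```cpp
// Sketch only: `handle`, the TensorNDs and `slice_len` are assumed to be
// prepared by the caller; shapes follow the forward operator's contract
// (data/dst: {n_slices, slice_len}, weight/bias: {slice_len}, mean/rstd: {n_slices}).
auto opr = handle->create_operator<megdnn::LayerNormForward>();
opr->param().affine = true;
opr->param().eps = 1e-5f;
opr->param().normalized_dim = 1;
opr->param().normalized_size = slice_len;
opr->exec(data, weight, bias, dst, mean, rstd, {});  // empty workspace suffices
```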
| @@ -57,6 +57,15 @@ struct DeduceLayoutProxy<Opr, 5, true> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename Opr> | |||||
| struct DeduceLayoutProxy<Opr, 6, true> { | |||||
| static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { | |||||
| megdnn_assert(layouts.size() == 6); | |||||
| opr->deduce_layout( | |||||
| layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]); | |||||
| } | |||||
| }; | |||||
| template <typename Opr> | template <typename Opr> | ||||
| struct DeduceLayoutProxy<Opr, 5, false> { | struct DeduceLayoutProxy<Opr, 5, false> { | ||||
| static void deduce_layout(Opr*, TensorLayoutArray&) {} | static void deduce_layout(Opr*, TensorLayoutArray&) {} | ||||
| @@ -0,0 +1,94 @@ | |||||
| /** | |||||
| * \file dnn/test/cuda/layer_norm.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "test/cuda/fixture.h" | |||||
| #include "test/common/checker.h" | |||||
| namespace megdnn { | |||||
| namespace test { | |||||
| TEST_F(CUDA, LAYERNORM_FORWARD) { | |||||
| using Param = LayerNormForward::Param; | |||||
| Param param; | |||||
| param.affine = true; | |||||
| param.eps = 1e-6; | |||||
| param.normalized_dim = 1; | |||||
| Checker<LayerNormForward> checker(handle_cuda()); | |||||
| checker.set_epsilon(1e-2); | |||||
| auto run = [&](DType d) { | |||||
| for (size_t n_slices : {10, 30}) | |||||
| for (size_t slice_len : {10, 30}) { | |||||
| param.normalized_size = slice_len; | |||||
| checker.set_param(param) | |||||
| .set_dtype(0, d) | |||||
| .set_dtype(1, d) | |||||
| .set_dtype(2, d) | |||||
| .set_dtype(3, d) | |||||
| .set_dtype(4, dtype::Float32()) | |||||
| .set_dtype(5, dtype::Float32()) | |||||
| .execs({{n_slices, slice_len}, | |||||
| {slice_len}, | |||||
| {slice_len}, | |||||
| {n_slices, slice_len}, | |||||
| {n_slices}, | |||||
| {n_slices}}); | |||||
| } | |||||
| }; | |||||
| run(dtype::Float32()); | |||||
| run(dtype::Float16()); | |||||
| run(dtype::BFloat16()); | |||||
| } | |||||
| TEST_F(CUDA, LAYERNORM_BACKWARD) { | |||||
| using Param = LayerNormBackward::Param; | |||||
| Param param; | |||||
| param.affine = true; | |||||
| param.eps = 1e-6; | |||||
| param.normalized_dim = 1; | |||||
| Checker<LayerNormBackward> checker(handle_cuda()); | |||||
| checker.set_epsilon(1e-1); | |||||
| auto run = [&](DType d) { | |||||
| for (size_t n_slices : {10, 30}) | |||||
| for (size_t slice_len : {10, 30}) { | |||||
| param.normalized_size = slice_len; | |||||
| checker.set_param(param) | |||||
| .set_dtype(0, d) | |||||
| .set_dtype(1, d) | |||||
| .set_dtype(2, d) | |||||
| .set_dtype(3, dtype::Float32()) | |||||
| .set_dtype(4, dtype::Float32()) | |||||
| .set_dtype(5, d) | |||||
| .set_dtype(6, d) | |||||
| .set_dtype(7, d) | |||||
| .execs({{n_slices, slice_len}, | |||||
| {n_slices, slice_len}, | |||||
| {slice_len}, | |||||
| {n_slices}, | |||||
| {n_slices}, | |||||
| {n_slices, slice_len}, | |||||
| {slice_len}, | |||||
| {slice_len}}); | |||||
| } | |||||
| }; | |||||
| run(dtype::Float32()); | |||||
| run(dtype::Float16()); | |||||
| run(dtype::BFloat16()); | |||||
| } | |||||
| } // namespace test | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -1066,57 +1066,6 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: | |||||
| return cached / down | return cached / down | ||||
| @lru_cache(maxsize=None) | |||||
| def _get_layerNorm(device, dtype, dim, gopt_level=2): | |||||
| @subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level) | |||||
| def layerNormAffine(inputs, f, c): | |||||
| inp, eps, _flatten_shape, weight, bias = inputs | |||||
| inp_shape = f(GetVarShape(), inp) | |||||
| inp = f(Reshape(axis=dim), inp, _flatten_shape) | |||||
| mean = f(Reduce(mode="mean", axis=-1), inp) | |||||
| x2s = f(Reduce(mode="sum_sqr", axis=-1), inp) | |||||
| reduce_shape = f(GetVarShape(), x2s) | |||||
| reduce_size = f( | |||||
| "//", | |||||
| f(Reduce(mode="product", axis=0), inp_shape), | |||||
| f(Reduce(mode="product", axis=0), reduce_shape), | |||||
| ) | |||||
| reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) | |||||
| var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2))) | |||||
| inv_sqrt_var = f("**", f("+", var, eps), c(-0.5)) | |||||
| oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var)) | |||||
| affine_oup = f(Reshape(), oup, inp_shape) | |||||
| affine_oup = f("fma3", affine_oup, weight, bias) | |||||
| # NOTE: return oup make backward faster but take more memory | |||||
| return (affine_oup, oup, mean, x2s), (True, False, False, False) | |||||
| @subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level) | |||||
| def layerNorm(inputs, f, c): | |||||
| inp, eps, _flatten_shape = inputs | |||||
| inp_shape = f(GetVarShape(), inp) | |||||
| inp = f(Reshape(axis=dim), inp, _flatten_shape) | |||||
| mean = f(Reduce(mode="mean", axis=-1), inp) | |||||
| x2s = f(Reduce(mode="sum_sqr", axis=-1), inp) | |||||
| reduce_shape = f(GetVarShape(), x2s) | |||||
| reduce_size = f( | |||||
| "//", | |||||
| f(Reduce(mode="product", axis=0), inp_shape), | |||||
| f(Reduce(mode="product", axis=0), reduce_shape), | |||||
| ) | |||||
| reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) | |||||
| var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2))) | |||||
| inv_sqrt_var = f("**", f("+", var, eps), c(-0.5)) | |||||
| oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var)) | |||||
| oup = f(Reshape(), oup, inp_shape) | |||||
| return (oup,), (True,) | |||||
| return (layerNorm, layerNormAffine) | |||||
| def layer_norm( | def layer_norm( | ||||
| inp: Tensor, | inp: Tensor, | ||||
| normalized_shape: tuple, | normalized_shape: tuple, | ||||
| @@ -1133,32 +1082,34 @@ def layer_norm( | |||||
| normalized_shape: the shape that you want to be normalized | normalized_shape: the shape that you want to be normalized | ||||
| affine: whether to use weight and bias | affine: whether to use weight and bias | ||||
| weight: must not be None when the affine is true | weight: must not be None when the affine is true | ||||
| bias: must not be None when the bias is true | |||||
| bias: must not be None when the affine is true | |||||
| eps: a value added to the denominator for numerical stability. Default: 1e-5 | eps: a value added to the denominator for numerical stability. Default: 1e-5 | ||||
| """ | """ | ||||
| if amp._enabled: | if amp._enabled: | ||||
| inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) | inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) | ||||
| _device = inp.device | |||||
| _dtype = inp.dtype | |||||
| _dim = len(inp.shape) - len(normalized_shape) | |||||
| if isinstance(normalized_shape, int): | |||||
| normalized_shape = [normalized_shape] | |||||
| _flatten_shape = concat( | |||||
| ( | |||||
| convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device), | |||||
| convert_single_value(-1, dtype="int32", device=inp.device), | |||||
| ) | |||||
| ) | |||||
| (layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim) | |||||
| normalized_dim = len(normalized_shape) | |||||
| assert normalized_dim > 0 | |||||
| eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device) | |||||
| normalized_size = 1 | |||||
| for i in range(normalized_dim): | |||||
| normalized_size = normalized_size * normalized_shape[i] | |||||
| op = builtin.LayerNorm( | |||||
| affine=affine, | |||||
| eps=eps, | |||||
| normalized_dim=normalized_dim, | |||||
| normalized_size=normalized_size, | |||||
| ) | |||||
| if affine: | if affine: | ||||
| outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias) | |||||
| assert weight is not None and bias is not None | |||||
| return apply(op, inp, weight, bias)[0] | |||||
| else: | else: | ||||
| outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape) | |||||
| return outvar | |||||
| # assert weight is None and bias is None | |||||
| return apply(op, inp)[0] | |||||
| def batch_norm( | def batch_norm( | ||||
| @@ -865,61 +865,6 @@ def test_conv1d(): | |||||
| ) | ) | ||||
| def test_layer_norm(): | |||||
| def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5): | |||||
| __layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine) | |||||
| __layer_norm.weight = weight | |||||
| __layer_norm.bias = bias | |||||
| return __layer_norm(x) | |||||
| def _layer_norm_numpy( | |||||
| x, normalized_shape, affine, weight=None, bias=None, eps=1e-5 | |||||
| ): | |||||
| x_shape = x.shape | |||||
| dim_delta = len(x_shape) - len(normalized_shape) | |||||
| non_flatten_shape = x_shape[:dim_delta] | |||||
| x = x.reshape(*non_flatten_shape, -1) | |||||
| mean = x.mean(axis=-1, keepdims=True) | |||||
| var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean | |||||
| x = (x - mean) / F.sqrt(var + eps) | |||||
| x = x.reshape(x_shape) | |||||
| if affine: | |||||
| x = weight * x + bias | |||||
| return x | |||||
| normalized_shape = (28, 28) | |||||
| inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32") | |||||
| weight = Tensor(np.random.randn(28, 28), dtype="float32") | |||||
| bias = Tensor(np.random.randn(28, 28), dtype="float32") | |||||
| inp_feat = inp_feat + 1 | |||||
| weight = weight + 1 | |||||
| bias = bias | |||||
| affine = False | |||||
| outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias) | |||||
| targetvar = _layer_norm_numpy(inp_feat, normalized_shape, affine, weight, bias) | |||||
| assert abs(outvar - targetvar).mean() < 1e-7 | |||||
| # no random, affine True | |||||
| normalized_shape = (28, 28) | |||||
| inp_feat = Tensor(np.ones((32, 64, 28, 28)), dtype="float32") | |||||
| weight = Tensor(np.ones((28, 28)), dtype="float32") | |||||
| bias = Tensor(np.zeros((28, 28)), dtype="float32") | |||||
| affine = True | |||||
| outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias) | |||||
| targetvar = _layer_norm(inp_feat, normalized_shape, affine, weight, bias) | |||||
| assert abs((outvar - targetvar).mean()) < 1e-7 | |||||
| assert abs(outvar.mean()) < 1e-7 | |||||
| def test_batchnorm2d_autocast(): | def test_batchnorm2d_autocast(): | ||||
| """check amp's result is equal to manually converted result""" | """check amp's result is equal to manually converted result""" | ||||
| amp.enabled = True | amp.enabled = True | ||||
| @@ -43,7 +43,7 @@ def test_cross_entropy(): | |||||
| x = softmax(x) | x = softmax(x) | ||||
| l_ref = ref(x, y) | l_ref = ref(x, y) | ||||
| l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False) | l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False) | ||||
| np.testing.assert_allclose(l.numpy(), l_ref) | |||||
| np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6) | |||||
| def test_cross_entropy_reduction(): | def test_cross_entropy_reduction(): | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "megbrain/opr/dnn/correlation.h" | #include "megbrain/opr/dnn/correlation.h" | ||||
| #include "megbrain/opr/dnn/fake_quant.h" | #include "megbrain/opr/dnn/fake_quant.h" | ||||
| #include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
| #include "megbrain/opr/dnn/layer_norm.h" | |||||
| #include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
| #include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
| #include "megbrain/opr/dnn/lsq.h" | #include "megbrain/opr/dnn/lsq.h" | ||||
| @@ -636,4 +637,29 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| } | } | ||||
| OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback(); | OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback(); | ||||
| } // namespace lrn | } // namespace lrn | ||||
| namespace layer_norm { | |||||
| cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& op = static_cast<const LayerNorm&>(def); | |||||
| size_t nr_inp = inputs.size(); | |||||
| auto p = op.param(); | |||||
| mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine)); | |||||
| OperatorNodeConfig config{op.make_name()}; | |||||
| if (nr_inp == 3) { | |||||
| return opr::LayerNorm::make( | |||||
| inputs[0], inputs[1], inputs[2], op.param(), config)[0] | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| } else { | |||||
| return opr::LayerNorm::make(inputs[0], op.param(), config)[0] | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| } | |||||
| } | |||||
| OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback(); | |||||
| } // namespace layer_norm | |||||
| } // namespace mgb::imperative | } // namespace mgb::imperative | ||||
| @@ -431,4 +431,6 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | |||||
| def LRN: MgbHashableOp<"LRN", [LRNParam]>; | def LRN: MgbHashableOp<"LRN", [LRNParam]>; | ||||
| def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | |||||
| #endif // MGB_OPS | #endif // MGB_OPS | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "megbrain/opr/dnn/correlation.h" | #include "megbrain/opr/dnn/correlation.h" | ||||
| #include "megbrain/opr/dnn/fake_quant.h" | #include "megbrain/opr/dnn/fake_quant.h" | ||||
| #include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
| #include "megbrain/opr/dnn/layer_norm.h" | |||||
| #include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
| #include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
| #include "megbrain/opr/dnn/lsq.h" | #include "megbrain/opr/dnn/lsq.h" | ||||
| @@ -420,6 +421,47 @@ struct OprMaker<opr::BatchNormBackward, 6> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <> | |||||
| struct OprMaker<opr::LayerNorm, 0> { | |||||
| using Param = opr::LayerNorm::Param; | |||||
| static cg::OperatorNodeBase* make( | |||||
| const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||||
| const OperatorNodeConfig& config) { | |||||
| MGB_MARK_USED_VAR(graph); | |||||
| if (i.size() == 3) { | |||||
| return opr::LayerNorm::make(i[0], i[1], i[2], param, config)[0] | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| } else { | |||||
| mgb_assert(i.size() == 1); | |||||
| return opr::LayerNorm::make(i[0], param, config)[0].node()->owner_opr(); | |||||
| } | |||||
| } | |||||
| }; | |||||
| // OprMaker in MGB_SEREG_OPR only supports oprs with a unique output | |||||
| template <> | |||||
| struct OprMaker<opr::LayerNormBackward, 0> { | |||||
| using Param = opr::LayerNormBackward::Param; | |||||
| static cg::OperatorNodeBase* make( | |||||
| const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||||
| const OperatorNodeConfig& config) { | |||||
| MGB_MARK_USED_VAR(graph); | |||||
| if (i.size() == 5) { | |||||
| return opr::LayerNormBackward::make( | |||||
| i[0], i[1], i[2], i[3], i[4], param, config)[0] | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| } else { | |||||
| mgb_assert(i.size() == 4); | |||||
| return opr::LayerNormBackward::make( | |||||
| i[0], i[1], i[2], i[3], param, config)[0] | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <class MegDNNConv = megdnn::LocalShare> | template <class MegDNNConv = megdnn::LocalShare> | ||||
| struct MakeLocalShareCaller2 { | struct MakeLocalShareCaller2 { | ||||
| template <typename Opr> | template <typename Opr> | ||||
| @@ -641,6 +683,8 @@ MGB_SEREG_OPR(TQT, 2); | |||||
| MGB_SEREG_OPR(TQTBackward, 3); | MGB_SEREG_OPR(TQTBackward, 3); | ||||
| MGB_SEREG_OPR(LSQ, 4); | MGB_SEREG_OPR(LSQ, 4); | ||||
| MGB_SEREG_OPR(LSQBackward, 5); | MGB_SEREG_OPR(LSQBackward, 5); | ||||
| MGB_SEREG_OPR(LayerNorm, 0); | |||||
| MGB_SEREG_OPR(LayerNormBackward, 0); | |||||
| } // namespace opr | } // namespace opr | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -0,0 +1,248 @@ | |||||
| /** | |||||
| * \file src/opr/impl/dnn/layer_norm.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megbrain/opr/dnn/layer_norm.h" | |||||
| #include "megbrain/graph/grad_impl.h" | |||||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
| #include "megbrain/opr/utility.h" | |||||
| #include "../internal/megdnn_opr_wrapper.inl" | |||||
| using namespace mgb; | |||||
| using namespace opr; | |||||
| /* ==================== LayerNormForward ==================== */ | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormForward); | |||||
| LayerNormForward::LayerNormForward( | |||||
| VarNode* data, VarNode* weight, VarNode* bias, const Param& param, | |||||
| const OperatorNodeConfig& config) | |||||
| : Super{data->owner_graph(), config, "layer_norm", {data, weight, bias}} { | |||||
| init_megdnn_opr(*this, param); | |||||
| add_input({data, weight, bias}); | |||||
| output(0)->dtype(data->dtype()); | |||||
| output(1)->dtype(dtype::Float32()); | |||||
| output(2)->dtype(dtype::Float32()); | |||||
| } | |||||
| LayerNormForward::LayerNormForward( | |||||
| VarNode* data, const Param& param, const OperatorNodeConfig& config) | |||||
| : Super{data->owner_graph(), config, "layer_norm", {data}} { | |||||
| init_megdnn_opr(*this, param); | |||||
| add_input({data}); | |||||
| output(0)->dtype(data->dtype()); | |||||
| output(1)->dtype(dtype::Float32()); | |||||
| output(2)->dtype(dtype::Float32()); | |||||
| } | |||||
| SymbolVarArray LayerNormForward::make( | |||||
| SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param, | |||||
| const OperatorNodeConfig& config) { | |||||
| auto outs = data.node() | |||||
| ->owner_graph() | |||||
| ->insert_opr(std::make_unique<LayerNormForward>( | |||||
| data.node(), weight.node(), bias.node(), param, config)) | |||||
| ->output(); | |||||
| SymbolVarArray ret; | |||||
| for (auto&& out : outs) { | |||||
| ret.emplace_back(out); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| SymbolVarArray LayerNormForward::make( | |||||
| SymbolVar data, const Param& param, const OperatorNodeConfig& config) { | |||||
| auto outs = data.node() | |||||
| ->owner_graph() | |||||
| ->insert_opr(std::make_unique<LayerNormForward>( | |||||
| data.node(), param, config)) | |||||
| ->output(); | |||||
| SymbolVarArray ret; | |||||
| for (auto&& out : outs) { | |||||
| ret.emplace_back(out); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| void LayerNormForward::get_output_var_shape( | |||||
| const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { | |||||
| uint64_t normalized_dim = param().normalized_dim; | |||||
| out_shape[0] = inp_shape[0]; | |||||
| TensorShape unnormalized_shape; | |||||
| unnormalized_shape.ndim = inp_shape[0].ndim - normalized_dim; | |||||
| for (size_t i = 0; i < unnormalized_shape.ndim; ++i) { | |||||
| unnormalized_shape.shape[i] = inp_shape[0].shape[i]; | |||||
| } | |||||
| out_shape[1] = unnormalized_shape; | |||||
| out_shape[2] = unnormalized_shape; | |||||
| } | |||||
| size_t LayerNormForward::get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const { | |||||
| return 0; | |||||
| } | |||||
| void LayerNormForward::scn_do_execute() { | |||||
| if (param().affine) { | |||||
| megdnn_opr()->exec( | |||||
| input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||||
| input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||||
| output(1)->dev_tensor().as_megdnn(), | |||||
| output(2)->dev_tensor().as_megdnn(), {}); | |||||
| } else { | |||||
| megdnn_opr()->exec( | |||||
| input(0)->dev_tensor().as_megdnn(), {}, {}, | |||||
| output(0)->dev_tensor().as_megdnn(), | |||||
| output(1)->dev_tensor().as_megdnn(), | |||||
| output(2)->dev_tensor().as_megdnn(), {}); | |||||
| } | |||||
| } | |||||
| #if MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(LayerNormForward) { | |||||
| auto p = opr.param(); | |||||
| SymbolVarArray grad; | |||||
| VarNodeArray ret; | |||||
| if (p.affine) { | |||||
| mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx); | |||||
| grad = LayerNormBackward::make( | |||||
| out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2), | |||||
| opr.param()); | |||||
| } else { | |||||
| mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx); | |||||
| grad = LayerNormBackward::make( | |||||
| out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param()); | |||||
| } | |||||
| uint32_t nr_ret = p.affine ? 3 : 1; | |||||
| for (uint32_t i = 0; i < nr_ret; ++i) { | |||||
| ret.push_back(grad[i].node()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| #endif | |||||
| /* ==================== LayerNormBackward ==================== */ | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormBackward); | |||||
| LayerNormBackward::LayerNormBackward( | |||||
| VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, | |||||
| const Param& param, const OperatorNodeConfig& config) | |||||
| : Super({diff->owner_graph(), | |||||
| config, | |||||
| "layer_norm_backward", | |||||
| {diff, data, weight, mean, rstd}}, | |||||
| 0, true) { | |||||
| init_megdnn_opr(*this, param); | |||||
| add_input({diff, data, weight, mean, rstd}); | |||||
| } | |||||
| LayerNormBackward::LayerNormBackward( | |||||
| VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param, | |||||
| const OperatorNodeConfig& config) | |||||
| : Super({diff->owner_graph(), | |||||
| config, | |||||
| "layer_norm_backward", | |||||
| {diff, data, mean, rstd}}, | |||||
| 0, true) { | |||||
| init_megdnn_opr(*this, param); | |||||
| add_input({diff, data, mean, rstd}); | |||||
| auto mark_empty_var = [&](VarNode* var) { | |||||
| var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) | |||||
| .add_flag(VarNode::Flag::VOLATILE_CONTENT); | |||||
| }; | |||||
| mark_empty_var(output(1)); | |||||
| mark_empty_var(output(2)); | |||||
| } | |||||
| SymbolVarArray LayerNormBackward::make( | |||||
| SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, | |||||
| SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) { | |||||
| auto outs = diff.node() | |||||
| ->owner_graph() | |||||
| ->insert_opr(std::make_unique<LayerNormBackward>( | |||||
| diff.node(), data.node(), weight.node(), mean.node(), | |||||
| rstd.node(), param, config)) | |||||
| ->output(); | |||||
| SymbolVarArray ret; | |||||
| for (auto&& out : outs) { | |||||
| ret.emplace_back(out); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| SymbolVarArray LayerNormBackward::make( | |||||
| SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, | |||||
| const Param& param, const OperatorNodeConfig& config) { | |||||
| auto outs = diff.node() | |||||
| ->owner_graph() | |||||
| ->insert_opr(std::make_unique<LayerNormBackward>( | |||||
| diff.node(), data.node(), mean.node(), rstd.node(), | |||||
| param, config)) | |||||
| ->output(); | |||||
| SymbolVarArray ret; | |||||
| for (auto&& out : outs) { | |||||
| ret.emplace_back(out); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| void LayerNormBackward::init_output_static_infer_desc() { | |||||
| using namespace cg::static_infer; | |||||
| auto&& mgr = owner_graph()->static_infer_manager(); | |||||
| mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1))); | |||||
| if (param().affine) { | |||||
| mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); | |||||
| mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2))); | |||||
| } else { | |||||
| TensorShape empty; | |||||
| empty.ndim = 0; | |||||
| mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty)); | |||||
| mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty)); | |||||
| } | |||||
| this->init_output_static_infer_desc_workspace(false); | |||||
| } | |||||
| void LayerNormBackward::init_output_dtype() { | |||||
| output(0)->dtype(input(1)->dtype()); | |||||
| output(1)->dtype(input(2)->dtype()); | |||||
| output(2)->dtype(input(2)->dtype()); | |||||
| } | |||||
| size_t LayerNormBackward::get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const { | |||||
| return 0; | |||||
| } | |||||
| void LayerNormBackward::scn_do_execute() { | |||||
| if (param().affine) { | |||||
| megdnn_opr()->exec( | |||||
| input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||||
| input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), | |||||
| input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||||
| output(1)->dev_tensor().as_megdnn(), | |||||
| output(2)->dev_tensor().as_megdnn(), {}); | |||||
| } else { | |||||
| megdnn_opr()->exec( | |||||
| input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||||
| {}, input(2)->dev_tensor().as_megdnn(), | |||||
| input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||||
| {}, {}, {}); | |||||
| } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
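At the graph level, `LayerNormForward::make` returns three `SymbolVar`s: the normalized output plus the per-slice mean and rstd. A hedged wiring sketch, following the usual MegBrain patterns (`Host2DeviceCopy`, `make_callback_copy`); the host-tensor setup and variable names are assumptions for illustration, not part of the patch:

```cpp
// Sketch only: host_data/host_weight/host_bias are assumed to be pre-filled
// std::shared_ptr<HostTensorND> instances on a valid comp node, with shapes
// {n, k}, {k} and {k}; `using namespace mgb;` as in the sources above.
auto graph = ComputingGraph::make();
auto data = opr::Host2DeviceCopy::make(*graph, host_data);
auto weight = opr::Host2DeviceCopy::make(*graph, host_weight);
auto bias = opr::Host2DeviceCopy::make(*graph, host_bias);

opr::LayerNorm::Param param;
param.affine = true;
param.eps = 1e-5f;
param.normalized_dim = 1;
param.normalized_size = host_weight->shape()[0];

auto outs = opr::LayerNorm::make(data, weight, bias, param);  // {dst, mean, rstd}
HostTensorND host_dst;
auto func = graph->compile({make_callback_copy(outs[0], host_dst)});
func->execute();  // host_dst now holds the normalized output
```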
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * \file src/opr/include/megbrain/opr/dnn/layer_norm.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
| #include "megdnn/oprs.h" | |||||
| namespace mgb { | |||||
| namespace opr { | |||||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||||
| LayerNormForward, intl::MegDNNOprWrapperFwd<megdnn::LayerNormForward>) // { | |||||
| public: | |||||
| MGE_WIN_DECLSPEC_FUC LayerNormForward( | |||||
| VarNode* data, VarNode* weight, VarNode* bias, const Param& param, | |||||
| const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC LayerNormForward( | |||||
| VarNode* data, const Param& param, const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
| SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
| SymbolVar data, const Param& param = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| private: | |||||
| void get_output_var_shape( | |||||
| const TensorShapeArray& inp_shape, | |||||
| TensorShapeArray& out_shape) const override; | |||||
| size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const override; | |||||
| void scn_do_execute() override; | |||||
| }; | |||||
| using LayerNorm = LayerNormForward; | |||||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||||
| LayerNormBackward, intl::MegDNNOprWrapperBwd<megdnn::LayerNormBackward>) // { | |||||
| public: | |||||
| MGE_WIN_DECLSPEC_FUC LayerNormBackward( | |||||
| VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, | |||||
| const Param& param, const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC LayerNormBackward( | |||||
| VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, | |||||
| const Param& param, const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
| SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, | |||||
| SymbolVar rstd, const Param& param = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
| SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, | |||||
| const Param& param = {}, const OperatorNodeConfig& config = {}); | |||||
| private: | |||||
| void init_output_static_infer_desc() override; | |||||
| void init_output_dtype() override; | |||||
| size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const override; | |||||
| void scn_do_execute() override; | |||||
| }; | |||||
| } // namespace opr | |||||
| } // namespace mgb | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -0,0 +1,108 @@ | |||||
| /** | |||||
| * \file src/opr/test/dnn/layer_norm.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megbrain/opr/dnn/layer_norm.h" | |||||
| #include "megbrain/comp_node_env.h" | |||||
| #include "megbrain/test/autocheck.h" | |||||
| #include "megbrain/test/helper.h" | |||||
| #include "megbrain/test/megdnn_helper.h" | |||||
| #include "megdnn/oprs.h" | |||||
| #include <cmath> | |||||
| #include <iomanip> | |||||
| #include <random> | |||||
| #include <sstream> | |||||
| using namespace mgb; | |||||
| namespace { | |||||
| using Param = opr::LayerNormForward::Param; | |||||
| void run_forward(bool is_affine, size_t normalized_size) { | |||||
| using Checker = AutoOprChecker<3, 3>; | |||||
| Param param; | |||||
| param.eps = 1e-5; | |||||
| param.affine = is_affine; | |||||
| param.normalized_dim = 1; | |||||
| param.normalized_size = normalized_size; | |||||
| auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
| auto out = opr::LayerNormForward::make(inputs[0], inputs[1], inputs[2], param); | |||||
| return {out[0], out[1], out[2]}; | |||||
| }; | |||||
| auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
| auto opr = | |||||
| MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||||
| ->create_operator<megdnn::LayerNormForward>(); | |||||
| auto inp_shape = inp[0]->shape(); | |||||
| auto n_slices = inp_shape[0]; | |||||
| auto slice_len = inp_shape[1]; | |||||
| opr->param() = param; | |||||
| dest[0].dtype(dtype::Float32()) | |||||
| .comp_node(inp[0]->comp_node()) | |||||
| .resize({n_slices, slice_len}); | |||||
| dest[1].dtype(dtype::Float32()) | |||||
| .comp_node(inp[0]->comp_node()) | |||||
| .resize({n_slices}); | |||||
| dest[2].dtype(dtype::Float32()) | |||||
| .comp_node(inp[0]->comp_node()) | |||||
| .resize({n_slices}); | |||||
| opr->exec( | |||||
| inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||||
| dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(), {}); | |||||
| }; | |||||
| auto gen = [&](HostTensorND& src) { | |||||
| HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f); | |||||
| src = *src_gen(src.shape(), src.comp_node()); | |||||
| }; | |||||
| Checker::RunOptions option; | |||||
| option.numdiff_max_err = 1e-4; | |||||
| Checker checker{make_graph, fwd}; | |||||
| checker.set_input_generator(0, gen); | |||||
| checker.set_input_generator(1, gen); | |||||
| checker.set_input_generator(2, gen); | |||||
| checker.set_input_allow_grad(0, false); | |||||
| checker.set_input_allow_grad(1, false); | |||||
| checker.set_input_allow_grad(2, false); | |||||
| checker.set_output_allow_grad(0, false); | |||||
| checker.set_output_allow_grad(1, false); | |||||
| checker.set_output_allow_grad(2, false); | |||||
| checker.run({TensorShape{normalized_size, normalized_size}, | |||||
| TensorShape{normalized_size}, TensorShape{normalized_size}}, | |||||
| option) | |||||
| .run({TensorShape{normalized_size, normalized_size}, | |||||
| TensorShape{normalized_size}, TensorShape{normalized_size}}, | |||||
| option) | |||||
| .run({TensorShape{normalized_size, normalized_size}, | |||||
| TensorShape{normalized_size}, TensorShape{normalized_size}}, | |||||
| option); | |||||
| } | |||||
| TEST(TestOprDNN, LayerNormForwardAffine) { | |||||
| REQUIRE_GPU(1); | |||||
| run_forward(true, 1); | |||||
| run_forward(true, 16); | |||||
| run_forward(true, 17); | |||||
| } | |||||
| } // anonymous namespace | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -116,6 +116,7 @@ union OperatorParam { | |||||
| param.Padding = 82, | param.Padding = 82, | ||||
| param.ShuffleRNG = 83, | param.ShuffleRNG = 83, | ||||
| param.CheckNonFinite = 84, | param.CheckNonFinite = 84, | ||||
| param.LayerNorm = 85, | |||||
| } | } | ||||
| table Operator { | table Operator { | ||||