GitOrigin-RevId: 0cd484e753
tags/v1.8.0
| @@ -1936,6 +1936,75 @@ protected: | |||
| const TensorLayout& grad_s, size_t workspace_in_bytes); | |||
| }; | |||
| class LayerNormBase : public OperatorBase { | |||
| DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase); | |||
| DEF_OPR_PARAM(LayerNorm); | |||
| protected: | |||
| void deduce_layout_fwd( | |||
| const TensorLayout& data, const TensorLayout& weight, | |||
| const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, | |||
| TensorLayout& rstd); | |||
| void check_layout_fwd( | |||
| const TensorLayout& data, const TensorLayout& weight, | |||
| const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||
| const TensorLayout& rstd); | |||
| }; | |||
| class LayerNormForward : public LayerNormBase { | |||
| DEF_OPR_IMPL(LayerNormForward, LayerNormBase, 3, 3); | |||
| public: | |||
| virtual void exec( | |||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout( | |||
| const TensorLayout& data, const TensorLayout& weight, | |||
| const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, | |||
| TensorLayout& rstd); | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& data, const TensorLayout& weight, | |||
| const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||
| const TensorLayout& rstd) = 0; | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& data, const TensorLayout& weight, | |||
| const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||
| const TensorLayout& rstd, size_t workspace_in_bytes); | |||
| }; | |||
| using LayerNorm = LayerNormForward; | |||
| class LayerNormBackward : public LayerNormBase { | |||
| DEF_OPR_IMPL(LayerNormBackward, LayerNormBase, 5, 3); | |||
| public: | |||
| virtual void exec( | |||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout( | |||
| const TensorLayout& diff, const TensorLayout& data, | |||
| const TensorLayout& weight, const TensorLayout& mean, | |||
| const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight, | |||
| TensorLayout& dbias); | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& diff, const TensorLayout& data, | |||
| const TensorLayout& weight, const TensorLayout& mean, | |||
| const TensorLayout& rstd, const TensorLayout& ddata, | |||
| const TensorLayout& dweight, const TensorLayout& dbias) = 0; | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& diff, const TensorLayout& data, | |||
| const TensorLayout& weight, const TensorLayout& mean, | |||
| const TensorLayout& rstd, const TensorLayout& ddata, | |||
| const TensorLayout& dweight, const TensorLayout& dbias, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| } // namespace megdnn | |||
| #include "megdnn/internal/opr_header_epilogue.h" | |||
| @@ -1212,3 +1212,10 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | |||
| member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES] | |||
| ) | |||
| ) | |||
| (pdef('LayerNorm') | |||
| .add_fields('bool', 'affine', 'true') | |||
| .add_fields('float32', 'eps', '1e-5f') | |||
| .add_fields('uint64', 'normalized_dim', '1') | |||
| .add_fields('uint64', 'normalized_size', '1') | |||
| ) | |||
| @@ -209,7 +209,10 @@ private: | |||
| cb(LSQBackward) \ | |||
| cb(Fill) \ | |||
| cb(PaddingForward) \ | |||
| cb(PaddingBackward) | |||
| cb(PaddingBackward) \ | |||
| cb(LayerNormForward) \ | |||
| cb(LayerNormBackward) | |||
| // clang-format on | |||
| /*! | |||
| @@ -0,0 +1,180 @@ | |||
| /** | |||
| * \file dnn/src/common/layer_norm.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| void LayerNormBase::deduce_layout_fwd( | |||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
| TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||
| MEGDNN_MARK_USED_VAR(weight); | |||
| MEGDNN_MARK_USED_VAR(bias); | |||
| auto p = param(); | |||
| TensorShape unnormalized_shape; | |||
| unnormalized_shape.ndim = data.ndim - p.normalized_dim; | |||
| for (size_t i = 0; i < unnormalized_shape.ndim; ++i) { | |||
| unnormalized_shape.shape[i] = data.shape[i]; | |||
| } | |||
| TensorLayout unnormalized_layout = | |||
| TensorLayout(unnormalized_shape, dtype::Float32()); | |||
| dst = data; | |||
| mean = unnormalized_layout; | |||
| rstd = unnormalized_layout; | |||
| } | |||
| void LayerNormBase::check_layout_fwd( | |||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
| const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) { | |||
| megdnn_assert_contiguous(data); | |||
| megdnn_assert_contiguous(weight); | |||
| megdnn_assert_contiguous(bias); | |||
| megdnn_assert_contiguous(dst); | |||
| megdnn_assert_contiguous(mean); | |||
| megdnn_assert_contiguous(rstd); | |||
| auto errmsg = [&]() { | |||
| return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " + | |||
| megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " + | |||
| megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd); | |||
| }; | |||
| MEGDNN_MARK_USED_VAR(errmsg); | |||
| auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool { | |||
| if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype && | |||
| lhs.format == rhs.format)) | |||
| return false; | |||
| for (size_t i = 0; i < lhs.ndim; ++i) { | |||
| if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| }; | |||
| megdnn_assert(equal_layout(data, dst), "%s", errmsg().c_str()); | |||
| megdnn_assert(equal_layout(weight, bias), "%s", errmsg().c_str()); | |||
| megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str()); | |||
| auto p = param(); | |||
| uint64_t normalized_dim = p.normalized_dim; | |||
| size_t unnormalized_dim = data.ndim - normalized_dim; | |||
| megdnn_assert( | |||
| normalized_dim < data.ndim, | |||
| "the dims of normalized shape should smaller than input dims"); | |||
| for (size_t i = 0; i < unnormalized_dim; ++i) { | |||
| megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str()); | |||
| } | |||
| if (p.affine) { | |||
| for (size_t i = 0; i < normalized_dim; ++i) { | |||
| megdnn_assert( | |||
| data.shape[unnormalized_dim + i] == weight.shape[i], "%s", | |||
| errmsg().c_str()); | |||
| } | |||
| } | |||
| } | |||
| void LayerNormForward::deduce_layout( | |||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
| TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||
| deduce_layout_fwd(data, weight, bias, dst, mean, rstd); | |||
| } | |||
| void LayerNormForward::check_exec( | |||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
| const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd, | |||
| size_t workspace_in_bytes) { | |||
| check_layout_fwd(data, weight, bias, dst, mean, rstd); | |||
| auto required_workspace_in_bytes = | |||
| get_workspace_in_bytes(data, weight, bias, dst, mean, rstd); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| void LayerNormBackward::deduce_layout( | |||
| const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, | |||
| const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata, | |||
| TensorLayout& dweight, TensorLayout& dbias) { | |||
| MEGDNN_MARK_USED_VAR(diff); | |||
| MEGDNN_MARK_USED_VAR(mean); | |||
| MEGDNN_MARK_USED_VAR(rstd); | |||
| ddata = data; | |||
| dweight = weight; | |||
| dbias = weight; | |||
| } | |||
| void LayerNormBackward::check_exec( | |||
| const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, | |||
| const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata, | |||
| const TensorLayout& dweight, const TensorLayout& dbias, | |||
| size_t workspace_in_bytes) { | |||
| auto p = param(); | |||
| auto required_workspace_in_bytes = get_workspace_in_bytes( | |||
| diff, data, weight, mean, rstd, ddata, dweight, dbias); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| megdnn_assert_contiguous(diff); | |||
| megdnn_assert_contiguous(data); | |||
| megdnn_assert_contiguous(mean); | |||
| megdnn_assert_contiguous(rstd); | |||
| megdnn_assert_contiguous(ddata); | |||
| if (p.affine) { | |||
| megdnn_assert_contiguous(weight); | |||
| megdnn_assert_contiguous(dweight); | |||
| megdnn_assert_contiguous(dbias); | |||
| } | |||
| auto errmsg = [&]() { | |||
| return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " + | |||
| megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " + | |||
| megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " + | |||
| megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias); | |||
| }; | |||
| MEGDNN_MARK_USED_VAR(errmsg); | |||
| auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool { | |||
| if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype && | |||
| lhs.format == rhs.format)) | |||
| return false; | |||
| for (size_t i = 0; i < lhs.ndim; ++i) { | |||
| if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| }; | |||
| megdnn_assert(equal_layout(data, ddata), "%s", errmsg().c_str()); | |||
| megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str()); | |||
| if (p.affine) { | |||
| megdnn_assert(equal_layout(weight, dweight), "%s", errmsg().c_str()); | |||
| megdnn_assert(equal_layout(weight, dbias), "%s", errmsg().c_str()); | |||
| } | |||
| size_t normalized_dim = p.normalized_dim; | |||
| size_t unnormalized_dim = data.ndim - normalized_dim; | |||
| for (size_t i = 0; i < unnormalized_dim; ++i) { | |||
| megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str()); | |||
| } | |||
| if (p.affine) { | |||
| for (size_t i = 0; i < normalized_dim; ++i) { | |||
| megdnn_assert( | |||
| data.shape[unnormalized_dim + i] == weight.shape[i], "%s", | |||
| errmsg().c_str()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -135,6 +135,8 @@ DEF(CheckNonFinite, 2, true, true); | |||
| DEF(LSQForward, 5, true, true); | |||
| DEF(LSQBackward, 7, true, false); | |||
| DEF(Fill, 1, true, false); | |||
| DEF(LayerNormForward, 6, true, true); | |||
| DEF(LayerNormBackward, 8, true, true); | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -45,6 +45,7 @@ | |||
| #include "src/cuda/images2neibs/opr_impl.h" | |||
| #include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | |||
| #include "src/cuda/indexing_one_hot/opr_impl.h" | |||
| #include "src/cuda/layer_norm/opr_impl.h" | |||
| #include "src/cuda/linspace/opr_impl.h" | |||
| #include "src/cuda/local/opr_impl.h" | |||
| #include "src/cuda/local_share/opr_impl.h" | |||
| @@ -0,0 +1,664 @@ | |||
| /** | |||
| * \file dnn/src/cuda/layer_norm/layer_norm_cuda.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include <thrust/pair.h> | |||
| #include <thrust/tuple.h> | |||
| #include <cfloat> | |||
| #include "megdnn/arch.h" | |||
| #include "megdnn/dtype.h" | |||
| #include "src/cuda/cuda_shfl_compat.cuh" | |||
| #include "src/cuda/layer_norm/layer_norm_cuda.cuh" | |||
| #include "src/cuda/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace layer_norm { | |||
| constexpr int kCUDANumThreads = 256; | |||
| constexpr int vec_size = 4; | |||
| // warp size may be used as array length, or used in host function, | |||
| // so we define WARP_SIZE rather than using warpSize | |||
| #define WARP_SIZE 32 | |||
| #if defined(__clang__) | |||
| #define __ubsan_ignore_float_divide_by_zero__ \ | |||
| __attribute__((no_sanitize("float-divide-by-zero"))) | |||
| #else | |||
| #define __ubsan_ignore_float_divide_by_zero__ | |||
| #endif | |||
| struct WelfordStat { | |||
| float mean; | |||
| float sigma2; | |||
| float count; | |||
| MEGDNN_HOST MEGDNN_DEVICE WelfordStat() : mean(0.f), sigma2(0.f), count(0.f) {} | |||
| MEGDNN_HOST MEGDNN_DEVICE WelfordStat(float mean, float sigma2, float count) | |||
| : mean(mean), sigma2(sigma2), count(count) {} | |||
| }; | |||
| template <typename T, typename combine_t> | |||
| struct WelfordData { | |||
| T mean; | |||
| T sigma2; | |||
| combine_t count; | |||
| MEGDNN_HOST MEGDNN_DEVICE WelfordData() : mean(0), sigma2(0), count(0) {} | |||
| MEGDNN_HOST MEGDNN_DEVICE WelfordData(T mean, T sigma2, combine_t count) | |||
| : mean(mean), sigma2(sigma2), count(count) {} | |||
| }; | |||
| template <typename T, typename combine_t, typename res_t> | |||
| struct WelfordOps { | |||
| public: | |||
| using WelfordData_T = WelfordData<T, combine_t>; | |||
| inline MEGDNN_DEVICE WelfordData_T reduce(WelfordData_T acc, T data) const { | |||
| T delta = data - acc.mean; | |||
| T new_mean = static_cast<T>(acc.mean + delta / (acc.count + 1)); | |||
| T new_delta = static_cast<T>(data - new_mean); | |||
| return { | |||
| new_mean, | |||
| acc.sigma2 + delta * new_delta, | |||
| combine_t(acc.count + 1), | |||
| }; | |||
| } | |||
| inline MEGDNN_DEVICE WelfordData_T | |||
| combine(WelfordData_T lhs, WelfordData_T rhs) const { | |||
| if (lhs.count != 0 && rhs.count != 0) { | |||
| T delta = rhs.mean - lhs.mean; | |||
| combine_t new_count = lhs.count + rhs.count; | |||
| T nb_over_n = rhs.count / new_count; | |||
| return {lhs.mean + delta * nb_over_n, | |||
| lhs.sigma2 + rhs.sigma2 + delta * delta * lhs.count * nb_over_n, | |||
| new_count}; | |||
| } else { | |||
| return (lhs.count != 0) ? lhs : rhs; | |||
| } | |||
| } | |||
| inline MEGDNN_DEVICE res_t | |||
| project(WelfordData_T acc) const __ubsan_ignore_float_divide_by_zero__ { | |||
| const auto mean = static_cast<T>(acc.mean); | |||
| const combine_t divisor = static_cast<combine_t>(acc.count); | |||
| const auto var = acc.sigma2 / divisor; | |||
| res_t results(var, mean); | |||
| return results; | |||
| } | |||
| #if defined(__CUDACC__) || defined(__HIPCC__) | |||
| inline MEGDNN_DEVICE WelfordData_T | |||
| warp_shfl_down(WelfordData_T acc, int offset) const { | |||
| return {__shfl_down(acc.mean, offset, warpSize), | |||
| __shfl_down(acc.sigma2, offset, warpSize), | |||
| __shfl_down(acc.count, offset, warpSize)}; | |||
| } | |||
| #endif | |||
| MEGDNN_HOST MEGDNN_DEVICE WelfordOps() {} | |||
| }; | |||
| template <typename T, int vec_size> | |||
| struct alignas(sizeof(T) * vec_size) aligned_vector { | |||
| T val[vec_size]; | |||
| }; | |||
| template <typename T, bool is_cuda> | |||
| using acc_type = T; | |||
| template <typename U> | |||
| MEGDNN_DEVICE WelfordStat | |||
| update_welford_stat_online(const U val, const WelfordStat& curr_sum) { | |||
| U delta = static_cast<U>(val - curr_sum.mean); | |||
| U new_count = static_cast<U>(curr_sum.count + 1.f); | |||
| U new_mean = static_cast<U>(curr_sum.mean + delta * (1.f / new_count)); | |||
| return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; | |||
| } | |||
| MEGDNN_DEVICE WelfordStat | |||
| combine_welford_stat(const WelfordStat lhs, const WelfordStat rhs) { | |||
| using U = decltype(lhs.count); | |||
| U delta = lhs.mean - rhs.mean; | |||
| U count = rhs.count + lhs.count; | |||
| U mean, sigma2; | |||
| if (count > decltype(lhs.count){0}) { | |||
| auto coef = 1.f / count; | |||
| auto nA = rhs.count * coef; | |||
| auto nB = lhs.count * coef; | |||
| mean = nA * rhs.mean + nB * lhs.mean; | |||
| sigma2 = rhs.sigma2 + lhs.sigma2 + delta * delta * rhs.count * nB; | |||
| } else { | |||
| mean = U(0); | |||
| sigma2 = U(0); | |||
| } | |||
| return {mean, sigma2, count}; | |||
| } | |||
| template <typename T> | |||
| MEGDNN_DEVICE WelfordStat | |||
| compute_stats(const T* __restrict__ X, const int slice_len, float* buf) { | |||
| using vec_t = aligned_vector<T, vec_size>; | |||
| using acc_t = acc_type<T, true>; | |||
| const vec_t* X_vec = reinterpret_cast<const vec_t*>(X); | |||
| const int numx = blockDim.x * blockDim.y; | |||
| const int thrx = threadIdx.x + threadIdx.y * blockDim.x; | |||
| const int n_vec_to_read = slice_len / vec_size; | |||
| WelfordStat w_stat(0.f, 0.f, 0.f); | |||
| for (int i = thrx; i < n_vec_to_read; i += numx) { | |||
| vec_t data = X_vec[i]; | |||
| #pragma unroll | |||
| for (int ii = 0; ii < vec_size; ii++) { | |||
| w_stat = update_welford_stat_online( | |||
| static_cast<acc_t>(data.val[ii]), w_stat); | |||
| } | |||
| } | |||
| // intra-warp reduction | |||
| #pragma unroll | |||
| for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { | |||
| WelfordStat w_tmp{ | |||
| __shfl_down(w_stat.mean, offset, warpSize), | |||
| __shfl_down(w_stat.sigma2, offset, warpSize), | |||
| __shfl_down(w_stat.count, offset, warpSize)}; | |||
| w_stat = combine_welford_stat(w_stat, w_tmp); | |||
| } | |||
| // threadIdx.x == 0 has correct values for each warp | |||
| // inter-warp reductions | |||
| if (blockDim.y > 1) { | |||
| float* mean_sigma_buf = buf; | |||
| float* count_buf = buf + blockDim.y; | |||
| for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { | |||
| // upper half of warps write to shared | |||
| if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { | |||
| const int wrt_y = threadIdx.y - offset; | |||
| mean_sigma_buf[2 * wrt_y] = w_stat.mean; | |||
| mean_sigma_buf[2 * wrt_y + 1] = w_stat.sigma2; | |||
| count_buf[wrt_y] = w_stat.count; | |||
| } | |||
| __syncthreads(); | |||
| // lower half merges | |||
| if (threadIdx.x == 0 && threadIdx.y < offset) { | |||
| WelfordStat w_tmp{ | |||
| mean_sigma_buf[2 * threadIdx.y], | |||
| mean_sigma_buf[2 * threadIdx.y + 1], count_buf[threadIdx.y]}; | |||
| w_stat = combine_welford_stat(w_stat, w_tmp); | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| if (threadIdx.x == 0 && threadIdx.y == 0) { | |||
| mean_sigma_buf[0] = w_stat.mean; | |||
| mean_sigma_buf[1] = w_stat.sigma2 / float(slice_len); | |||
| } | |||
| __syncthreads(); | |||
| return WelfordStat{mean_sigma_buf[0], mean_sigma_buf[1], 0.f}; | |||
| } else { | |||
| return WelfordStat{ | |||
| __shfl(w_stat.mean, 0, warpSize), | |||
| __shfl(w_stat.sigma2, 0, warpSize) / float(slice_len), 0.f}; | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __global__ void vectorized_layer_norm_forward_affine_kernel( | |||
| const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight, | |||
| const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) { | |||
| // if we made smem WelfordStat type, there would be bank conflicts, | |||
| // as one thread would have to write 3 consecutive floats | |||
| extern __shared__ float s_data[]; | |||
| auto slice_id = blockIdx.x; | |||
| const T* slice = X + slice_id * slice_len; | |||
| WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data); | |||
| using vec_t = aligned_vector<T, vec_size>; | |||
| const vec_t* X_vec = reinterpret_cast<const vec_t*>(slice); | |||
| vec_t* Y_vec = reinterpret_cast<vec_t*>(Y + slice_id * slice_len); | |||
| const int numx = blockDim.x * blockDim.y; | |||
| const int thrx = threadIdx.x + threadIdx.y * blockDim.x; | |||
| const int n_vec_to_read = slice_len / vec_size; | |||
| T_ACC rstd_val = static_cast<T_ACC>(rsqrt(slice_w_stat.sigma2 + eps)); | |||
| for (int i = thrx; i < n_vec_to_read; i += numx) { | |||
| vec_t data = X_vec[i]; | |||
| vec_t out; | |||
| // computation is performed in T_ACC, X is cast to T_ACC and result is | |||
| // implicitly cast to T | |||
| #pragma unroll | |||
| for (int ii = 0; ii < vec_size; ii++) { | |||
| out.val[ii] = static_cast<T_ACC>(weight[i * vec_size + ii]) * | |||
| (rstd_val * (static_cast<T_ACC>(data.val[ii]) - | |||
| slice_w_stat.mean)) + | |||
| static_cast<T_ACC>(bias[i * vec_size + ii]); | |||
| } | |||
| Y_vec[i] = out; | |||
| } | |||
| if (thrx == 0) { | |||
| mean[slice_id] = slice_w_stat.mean; | |||
| rstd[slice_id] = rstd_val; | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __global__ void vectorized_layer_norm_forward_kernel( | |||
| const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight, | |||
| const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) { | |||
| extern __shared__ float s_data[]; | |||
| auto slice_id = blockIdx.x; | |||
| const T* slice = X + slice_id * slice_len; | |||
| WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data); | |||
| using vec_t = aligned_vector<T, vec_size>; | |||
| const vec_t* X_vec = reinterpret_cast<const vec_t*>(slice); | |||
| vec_t* Y_vec = reinterpret_cast<vec_t*>(Y + slice_id * slice_len); | |||
| const int numx = blockDim.x * blockDim.y; | |||
| const int thrx = threadIdx.x + threadIdx.y * blockDim.x; | |||
| const int n_vec_to_read = slice_len / vec_size; | |||
| T_ACC rstd_val = static_cast<T_ACC>(rsqrt(slice_w_stat.sigma2 + eps)); | |||
| for (int i = thrx; i < n_vec_to_read; i += numx) { | |||
| vec_t data = X_vec[i]; | |||
| vec_t out; | |||
| #pragma unroll | |||
| for (int ii = 0; ii < vec_size; ii++) { | |||
| out.val[ii] = | |||
| rstd_val * (static_cast<T_ACC>(data.val[ii]) - slice_w_stat.mean); | |||
| } | |||
| Y_vec[i] = out; | |||
| } | |||
| if (thrx == 0) { | |||
| mean[slice_id] = slice_w_stat.mean; | |||
| rstd[slice_id] = rstd_val; | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| void launch_vectorized_layer_norm_forward_kernel( | |||
| int64_t slice_len, int64_t slice_num, T_ACC eps, const T* X_data, | |||
| const T* weight_data, const T* bias_data, T* Y_data, T_ACC* mean_data, | |||
| T_ACC* rstd_data, cudaStream_t stream) { | |||
| const int num_threads = 128; | |||
| const dim3 threads(WARP_SIZE, num_threads / WARP_SIZE, 1); | |||
| const dim3 blocks(slice_num); | |||
| int nshared = threads.y > 1 ? threads.y * 3 / 2 * sizeof(T_ACC) : 0; | |||
| if (weight_data == nullptr && bias_data == nullptr) { | |||
| vectorized_layer_norm_forward_kernel<<<blocks, threads, nshared, stream>>>( | |||
| slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data, | |||
| Y_data); | |||
| } else { | |||
| vectorized_layer_norm_forward_affine_kernel<<< | |||
| blocks, threads, nshared, stream>>>( | |||
| slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data, | |||
| Y_data); | |||
| } | |||
| after_kernel_launch(); | |||
| } | |||
| template <typename T, class ReduceOp> | |||
| __inline__ MEGDNN_DEVICE T welford_warp_reduce(T val, const ReduceOp& op) { | |||
| #pragma unroll | |||
| for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { | |||
| val = op.combine(val, op.warp_shfl_down(val, offset)); | |||
| } | |||
| return val; | |||
| } | |||
| template <typename T, class ReduceOp> | |||
| __inline__ MEGDNN_DEVICE T | |||
| welford_block_reduce(T val, const ReduceOp& op, const T& identity_element, T* shared) { | |||
| const int lid = threadIdx.x % warpSize; | |||
| const int wid = threadIdx.x / warpSize; | |||
| val = welford_warp_reduce(val, op); | |||
| __syncthreads(); | |||
| if (lid == 0) { | |||
| shared[wid] = val; | |||
| } | |||
| __syncthreads(); | |||
| val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : identity_element; | |||
| if (wid == 0) { | |||
| val = welford_warp_reduce(val, op); | |||
| } | |||
| return val; | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __global__ void get_input_mean_and_rstd_kernel( | |||
| int64_t slice_len, T_ACC eps, const T* X, T_ACC* mean, T_ACC* rstd) { | |||
| using WelfordType = WelfordData<T_ACC, T_ACC>; | |||
| using WelfordOp = WelfordOps<T_ACC, T_ACC, thrust::pair<T_ACC, T_ACC>>; | |||
| __shared__ typename std::aligned_storage< | |||
| sizeof(WelfordType), alignof(WelfordType)>::type val_shared[WARP_SIZE]; | |||
| WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared); | |||
| const int64_t i = blockIdx.x; | |||
| WelfordOp welford_op; | |||
| WelfordType val( | |||
| static_cast<T_ACC>(0), static_cast<T_ACC>(0), static_cast<T_ACC>(0)); | |||
| for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) { | |||
| const int64_t index = i * slice_len + j; | |||
| val = welford_op.reduce(val, static_cast<T_ACC>(X[index])); | |||
| } | |||
| val = welford_block_reduce( | |||
| val, welford_op, | |||
| WelfordType( | |||
| static_cast<T_ACC>(0), static_cast<T_ACC>(0), | |||
| static_cast<T_ACC>(0)), | |||
| val_shared_ptr); | |||
| if (threadIdx.x == 0) { | |||
| T_ACC slice_mean; | |||
| T_ACC slice_sigma2; | |||
| thrust::tie(slice_sigma2, slice_mean) = welford_op.project(val); | |||
| mean[i] = slice_mean; | |||
| rstd[i] = rsqrt(slice_sigma2 + eps); | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __global__ void layer_norm_forward_kernel( | |||
| int64_t slice_len, const T* X, const T_ACC* mean, const T_ACC* rstd, | |||
| const T* weight, const T* bias, T* Y) { | |||
| const int64_t i = blockIdx.x; | |||
| for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) { | |||
| const int64_t index = i * slice_len + j; | |||
| const T_ACC weight_v = | |||
| weight == nullptr ? T_ACC(1) : static_cast<T_ACC>(weight[j]); | |||
| const T_ACC bias_v = bias == nullptr ? T_ACC(0) : static_cast<T_ACC>(bias[j]); | |||
| Y[index] = (static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) * | |||
| static_cast<T_ACC>(rstd[i]) * weight_v + | |||
| bias_v; | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| void forward( | |||
| T* X, T* weight, T* bias, int64_t slice_num, int64_t slice_len, T_ACC eps, T* Y, | |||
| T_ACC* mean, T_ACC* rstd, cudaStream_t stream) { | |||
| auto can_vectorize = [&](const T* ptr, int alignment) { | |||
| uint64_t addr = reinterpret_cast<uint64_t>(ptr); | |||
| return addr % alignment == 0; | |||
| }; | |||
| constexpr int num_vec_elems = vec_size; | |||
| constexpr int alignment = num_vec_elems * sizeof(T); | |||
| if ((std::is_same<T, dt_float32>::value || std::is_same<T, dt_float16>::value || | |||
| std::is_same<T, dt_bfloat16>::value) && | |||
| slice_len <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && | |||
| slice_len % num_vec_elems == 0 && can_vectorize(X, alignment) && | |||
| can_vectorize(Y, alignment)) { | |||
| launch_vectorized_layer_norm_forward_kernel<T, T_ACC>( | |||
| slice_len, slice_num, static_cast<T_ACC>(eps), X, weight, bias, Y, mean, | |||
| rstd, stream); | |||
| after_kernel_launch(); | |||
| } else { | |||
| get_input_mean_and_rstd_kernel<T, T_ACC> | |||
| <<<slice_num, 512, 0, stream>>>(slice_len, eps, X, mean, rstd); | |||
| after_kernel_launch(); | |||
| layer_norm_forward_kernel<T, T_ACC><<<slice_num, kCUDANumThreads, 0, stream>>>( | |||
| slice_len, X, mean, rstd, weight, bias, Y); | |||
| after_kernel_launch(); | |||
| } | |||
| } | |||
| template <typename T> | |||
| __inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) { | |||
| #pragma unroll | |||
| for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { | |||
| val += __shfl_down(val, offset, warpSize); | |||
| } | |||
| return val; | |||
| } | |||
| template <typename T> | |||
| __inline__ MEGDNN_DEVICE T block_reduce_sum(T val, T* shared) { | |||
| const int lid = threadIdx.x % warpSize; | |||
| const int wid = threadIdx.x / warpSize; | |||
| val = warp_reduce_sum(val); | |||
| __syncthreads(); | |||
| if (lid == 0) { | |||
| shared[wid] = val; | |||
| } | |||
| __syncthreads(); | |||
| val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0); | |||
| if (wid == 0) { | |||
| val = warp_reduce_sum(val); | |||
| } | |||
| return val; | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __inline__ MEGDNN_DEVICE void layer_norm_grad_input_kernel_impl( | |||
| const T* __restrict__ dY, const T* __restrict__ X, | |||
| const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd, | |||
| const T* __restrict__ weight, T* dX, const int slice_len, T_ACC* buf) { | |||
| const auto slice_id = blockIdx.x; | |||
| const T_ACC mean_val = mean[slice_id]; | |||
| const T_ACC rstd_val = rstd[slice_id]; | |||
| T_ACC stats_x1{0}, stats_x2{0}; | |||
| constexpr int unroll = 4; | |||
| auto l = unroll * threadIdx.x; | |||
| const T* X_i = X + slice_id * slice_len; | |||
| const T* dY_i = dY + slice_id * slice_len; | |||
| T* dX_i = dX + slice_id * slice_len; | |||
| // vectorized reads don't improve perf, so use regular unrolling | |||
| for (; l + unroll - 1 < slice_len; l += blockDim.x * unroll) { | |||
| #pragma unroll | |||
| for (int k = 0; k < unroll; k++) { | |||
| T_ACC weight_val = | |||
| (weight != nullptr) ? static_cast<T_ACC>(weight[l + k]) : T_ACC(1); | |||
| const T_ACC c_h = static_cast<T_ACC>(X_i[l + k]); | |||
| const T_ACC c_loss = static_cast<T_ACC>(dY_i[l + k]); | |||
| stats_x1 += c_loss * weight_val; | |||
| stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val; | |||
| } | |||
| } | |||
| for (; l < slice_len; l++) { | |||
| T_ACC weight_val = | |||
| (weight != nullptr) ? static_cast<T_ACC>(weight[l]) : T_ACC(1); | |||
| const T_ACC c_h = static_cast<T_ACC>(X_i[l]); | |||
| const T_ACC c_loss = static_cast<T_ACC>(dY_i[l]); | |||
| stats_x1 += c_loss * weight_val; | |||
| stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val; | |||
| } | |||
| stats_x1 = block_reduce_sum(stats_x1, buf); | |||
| stats_x2 = block_reduce_sum(stats_x2, buf); | |||
| if (threadIdx.x == 0) { | |||
| buf[0] = stats_x1; | |||
| buf[1] = stats_x2; | |||
| } | |||
| __syncthreads(); | |||
| stats_x1 = buf[0]; | |||
| stats_x2 = buf[1]; | |||
| T_ACC fH = slice_len; | |||
| T_ACC term1 = (T_ACC(1) / fH) * rstd_val; | |||
| for (int l = threadIdx.x; l < slice_len; l += blockDim.x) { | |||
| const T_ACC x = X_i[l]; | |||
| const T_ACC dy = dY_i[l]; | |||
| T_ACC weight_val = | |||
| (weight != nullptr) ? static_cast<T_ACC>(weight[l]) : T_ACC(1); | |||
| T_ACC f_grad_input = fH * weight_val * dy; | |||
| f_grad_input -= (x - mean_val) * rstd_val * stats_x2; | |||
| f_grad_input -= stats_x1; | |||
| f_grad_input *= term1; | |||
| dX_i[l] = f_grad_input; | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __global__ void layer_norm_grad_input_kernel( | |||
| const T* __restrict__ dY, const T* __restrict__ X, | |||
| const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd, | |||
| const T* __restrict__ weight, T* dX, const int slice_len) { | |||
| alignas(sizeof(double)) extern __shared__ char s_data1[]; | |||
| T_ACC* buf = reinterpret_cast<T_ACC*>(&s_data1); | |||
| layer_norm_grad_input_kernel_impl(dY, X, mean, rstd, weight, dX, slice_len, buf); | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __global__ void layer_norm_grad_weight_bias_simple_kernel( | |||
| int64_t slice_num, int64_t slice_len, const T* dY, const T* X, | |||
| const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) { | |||
| const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; | |||
| if (j < slice_len) { | |||
| T_ACC sum1 = 0; | |||
| T_ACC sum2 = 0; | |||
| for (int64_t i = 0; i < slice_num; ++i) { | |||
| const int64_t index = i * slice_len + j; | |||
| sum1 += dweight == nullptr ? T_ACC(0) | |||
| : static_cast<T_ACC>(dY[index]) * | |||
| (static_cast<T_ACC>(X[index]) - | |||
| static_cast<T_ACC>(mean[i])) * | |||
| static_cast<T_ACC>(rstd[i]); | |||
| sum2 += dbias == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index]); | |||
| } | |||
| if (dweight != nullptr) { | |||
| dweight[j] = sum1; | |||
| } | |||
| if (dbias != nullptr) { | |||
| dbias[j] = sum2; | |||
| } | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __global__ void layer_norm_grad_weight_bias_kernel( | |||
| int64_t slice_num, int64_t slice_len, const T* dY, const T* X, | |||
| const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) { | |||
| alignas(sizeof(double)) extern __shared__ char s_data1[]; | |||
| T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1); | |||
| const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; | |||
| constexpr int unroll = 8; | |||
| T dYs[unroll]; | |||
| T Xs[unroll]; | |||
| T_ACC* means = s_data_typed; | |||
| T_ACC* rstds = s_data_typed + unroll * blockDim.y; | |||
| T_ACC dg_sum = 0; | |||
| T_ACC db_sum = 0; | |||
| if (j < slice_len) { | |||
| int bcounter; | |||
| for (bcounter = 0; bcounter < slice_num / (blockDim.y * unroll); bcounter++) { | |||
| int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; | |||
| #pragma unroll | |||
| for (int ii = 0; ii < unroll; ii++) { | |||
| if (threadIdx.x == 0) { | |||
| means[ii * blockDim.y + threadIdx.y] = mean[offset + ii]; | |||
| rstds[ii * blockDim.y + threadIdx.y] = rstd[offset + ii]; | |||
| } | |||
| dYs[ii] = dY[(offset + ii) * slice_len + j]; | |||
| Xs[ii] = X[(offset + ii) * slice_len + j]; | |||
| } | |||
| __syncthreads(); | |||
| #pragma unroll | |||
| for (int ii = 0; ii < unroll; ii++) { | |||
| dg_sum += dYs[ii] * (Xs[ii] - means[ii * blockDim.y + threadIdx.y]) * | |||
| rstds[ii * blockDim.y + threadIdx.y]; | |||
| db_sum += dYs[ii]; | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; | |||
| for (int ii = 0; ii < 8; ii++) { | |||
| T_ACC mean_val, rstd_val; // we don't use smem in the tail to avoid awkward | |||
| // synchronizations, perf penalty is negligible | |||
| if ((offset + ii) < slice_num) { | |||
| mean_val = mean[offset + ii]; | |||
| rstd_val = rstd[offset + ii]; | |||
| dYs[0] = dY[(offset + ii) * slice_len + j]; | |||
| Xs[0] = X[(offset + ii) * slice_len + j]; | |||
| dg_sum += dYs[0] * (Xs[0] - mean_val) * rstd_val; | |||
| db_sum += dYs[0]; | |||
| } | |||
| } | |||
| s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; | |||
| s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] = | |||
| db_sum; | |||
| __syncthreads(); | |||
| for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { | |||
| if (threadIdx.y < offset) { | |||
| s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] += | |||
| s_data_typed[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; | |||
| s_data_typed | |||
| [blockDim.x * blockDim.y + threadIdx.y * blockDim.x + | |||
| threadIdx.x] += s_data_typed | |||
| [blockDim.x * blockDim.y + | |||
| (threadIdx.y + offset) * blockDim.x + threadIdx.x]; | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| if (threadIdx.y == 0) { | |||
| if (dweight) { | |||
| dweight[j] = s_data_typed[threadIdx.x]; | |||
| } | |||
| if (dbias) { | |||
| dbias[j] = s_data_typed[threadIdx.x + blockDim.x * blockDim.y]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| void backward( | |||
| const T* dY_data, const T* X_data, const T_ACC* mean_data, | |||
| const T_ACC* rstd_data, const T* weight_data, int64_t slice_num, | |||
| int64_t slice_len, T* dX_data, T* dweight_data, T* dbias_data, | |||
| cudaStream_t stream) { | |||
| if (dX_data != nullptr) { | |||
| const int num_threads = 128; | |||
| const dim3 blocks(slice_num); | |||
| int nshared = (num_threads / WARP_SIZE) * sizeof(T_ACC); | |||
| layer_norm_grad_input_kernel<<<blocks, num_threads, nshared, stream>>>( | |||
| dY_data, X_data, mean_data, rstd_data, weight_data, dX_data, slice_len); | |||
| after_kernel_launch(); | |||
| } | |||
| if (dweight_data || dbias_data) { | |||
| if (slice_num < 512) { | |||
| const int64_t B = (slice_len + kCUDANumThreads - 1) / kCUDANumThreads; | |||
| layer_norm_grad_weight_bias_simple_kernel<T, T_ACC> | |||
| <<<B, kCUDANumThreads, 0, stream>>>( | |||
| slice_num, slice_len, dY_data, X_data, mean_data, rstd_data, | |||
| dweight_data, dbias_data); | |||
| after_kernel_launch(); | |||
| } else { | |||
| dim3 threads{16, 32}; | |||
| int blocks = (slice_len + threads.x - 1) / threads.x; | |||
| layer_norm_grad_weight_bias_kernel<T, T_ACC> | |||
| <<<blocks, threads, 2 * sizeof(T_ACC) * threads.x * threads.y, | |||
| stream>>>( | |||
| slice_num, slice_len, dY_data, X_data, mean_data, rstd_data, | |||
| dweight_data, dbias_data); | |||
| after_kernel_launch(); | |||
| } | |||
| } | |||
| } | |||
| #define INST(T, T_ACC) \ | |||
| template void forward<T, T_ACC>( \ | |||
| T*, T*, T*, int64_t, int64_t, T_ACC, T*, T_ACC*, T_ACC*, cudaStream_t); \ | |||
| template void backward<T, T_ACC>( \ | |||
| const T*, const T*, const T_ACC*, const T_ACC*, const T*, int64_t, \ | |||
| int64_t, T*, T*, T*, cudaStream_t); | |||
| INST(dt_float32, dt_float32) | |||
| INST(dt_float16, dt_float32) | |||
| INST(dt_bfloat16, dt_float32) | |||
| #undef INST | |||
| } // namespace layer_norm | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * \file dnn/src/cuda/layer_norm/layer_norm.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <cuda_runtime_api.h> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace layer_norm { | |||
| template <typename T, typename T_ACC> | |||
| void forward( | |||
| T* X, T* gamma, T* beta, int64_t M, int64_t N, T_ACC eps, T* Y, T_ACC* mean, | |||
| T_ACC* rstd, cudaStream_t stream); | |||
| template <typename T, typename T_ACC> | |||
| void backward( | |||
| const T* dY_data, const T* X_data, const T_ACC* mean_data, | |||
| const T_ACC* rstd_data, const T* gamma_data, int64_t M, int64_t N, T* dX_data, | |||
| T* dgamma_data, T* dbeta_data, cudaStream_t stream); | |||
| } // namespace layer_norm | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,94 @@ | |||
| /** | |||
| * \file dnn/src/cuda/layer_norm/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/layer_norm/opr_impl.h" | |||
| #include "src/cuda/layer_norm/layer_norm_cuda.cuh" | |||
| #include "src/cuda/utils.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| void LayerNormForwardImpl::exec( | |||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
| _megdnn_workspace workspace) { | |||
| check_exec( | |||
| data.layout, weight.layout, bias.layout, dst.layout, mean.layout, | |||
| rstd.layout, workspace.size); | |||
| auto p = param(); | |||
| float eps = p.eps; | |||
| bool affine = p.affine; | |||
| uint64_t slice_length = p.normalized_size; | |||
| uint64_t slice_dim = p.normalized_dim; | |||
| uint64_t n_slices = 1; | |||
| for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { | |||
| n_slices = n_slices * data.layout.shape[i]; | |||
| } | |||
| auto stream = cuda_stream(handle()); | |||
| using namespace ::megdnn::cuda::layer_norm; | |||
| #define cb(DType) \ | |||
| if (data.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| using T_ACC = float; \ | |||
| forward<T, T_ACC>( \ | |||
| data.ptr<T>(), affine ? weight.ptr<T>() : nullptr, \ | |||
| affine ? bias.ptr<T>() : nullptr, static_cast<int64_t>(n_slices), \ | |||
| static_cast<int64_t>(slice_length), static_cast<T_ACC>(eps), \ | |||
| dst.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), stream); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| void LayerNormBackwardImpl::exec( | |||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
| _megdnn_workspace workspace) { | |||
| check_exec( | |||
| diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, | |||
| ddata.layout, dweight.layout, dbias.layout, workspace.size); | |||
| auto p = param(); | |||
| bool affine = p.affine; | |||
| uint64_t slice_length = p.normalized_size; | |||
| uint64_t slice_dim = p.normalized_dim; | |||
| uint64_t n_slices = 1; | |||
| for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { | |||
| n_slices = n_slices * data.layout.shape[i]; | |||
| } | |||
| auto stream = cuda_stream(handle()); | |||
| using namespace ::megdnn::cuda::layer_norm; | |||
| #define cb(DType) \ | |||
| if (data.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| using T_ACC = float; \ | |||
| backward<T, T_ACC>( \ | |||
| diff.ptr<T>(), data.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), \ | |||
| affine ? weight.ptr<T>() : nullptr, n_slices, slice_length, \ | |||
| ddata.ptr<T>(), affine ? dweight.ptr<T>() : nullptr, \ | |||
| affine ? dbias.ptr<T>() : nullptr, stream); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * \file dnn/src/cuda/layer_norm/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/cuda/cudnn_wrapper.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| class LayerNormForwardImpl final : public LayerNormForward { | |||
| public: | |||
| using LayerNormForward::LayerNormForward; | |||
| void exec( | |||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class LayerNormBackwardImpl final : public LayerNormBackward { | |||
| public: | |||
| using LayerNormBackward::LayerNormBackward; | |||
| void exec( | |||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -47,6 +47,7 @@ | |||
| #include "src/naive/images2neibs/opr_impl.h" | |||
| #include "src/naive/indexing_multi_axis_vec/opr_impl.h" | |||
| #include "src/naive/indexing_one_hot/opr_impl.h" | |||
| #include "src/naive/layer_norm/opr_impl.h" | |||
| #include "src/naive/linspace/opr_impl.h" | |||
| #include "src/naive/local/opr_impl.h" | |||
| #include "src/naive/local_share/opr_impl.h" | |||
| @@ -0,0 +1,170 @@ | |||
| /** | |||
| * \file dnn/src/naive/layer_norm/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/naive/layer_norm/opr_impl.h" | |||
| #include <algorithm> | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| using namespace megdnn; | |||
| using namespace naive; | |||
| namespace { | |||
| using Param = megdnn::LayerNorm::Param; | |||
| template <typename T, typename T_ACC = float> | |||
| void forward( | |||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
| const Param& param) { | |||
| float eps = param.eps; | |||
| bool affine = param.affine; | |||
| uint64_t slice_length = param.normalized_size; | |||
| uint64_t slice_dim = param.normalized_dim; | |||
| uint64_t n_slices = 1; | |||
| for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { | |||
| n_slices = n_slices * data.layout.shape[i]; | |||
| } | |||
| for (size_t i = 0; i < n_slices; i++) { | |||
| T_ACC slice_sum = static_cast<T>(0.0f); | |||
| for (size_t j = 0; j < slice_length; j++) { | |||
| auto value = data.ptr<T>()[i * slice_length + j]; | |||
| slice_sum += value; | |||
| } | |||
| T_ACC slice_mean = static_cast<T>(slice_sum / slice_length); | |||
| T_ACC slice_var = static_cast<T>(0.0f); | |||
| for (size_t j = 0; j < slice_length; j++) { | |||
| slice_var += (data.ptr<T>()[i * slice_length + j] - slice_mean) * | |||
| (data.ptr<T>()[i * slice_length + j] - slice_mean); | |||
| } | |||
| slice_var = slice_var / slice_length; | |||
| T_ACC slice_std = static_cast<T>(sqrt(slice_var + eps)); | |||
| for (size_t j = 0; j < slice_length; j++) { | |||
| dst.ptr<T>()[i * slice_length + j] = | |||
| (data.ptr<T>()[i * slice_length + j] - slice_mean) / slice_std; | |||
| if (affine) { | |||
| dst.ptr<T>()[i * slice_length + j] = | |||
| dst.ptr<T>()[i * slice_length + j] * weight.ptr<T>()[j] + | |||
| bias.ptr<T>()[j]; | |||
| } | |||
| } | |||
| mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean); | |||
| rstd.ptr<T_ACC>()[i] = static_cast<T_ACC>(1.0 / slice_std); | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC = float> | |||
| void backward( | |||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param) { | |||
| bool affine = param.affine; | |||
| uint64_t slice_length = param.normalized_size; | |||
| uint64_t slice_dim = param.normalized_dim; | |||
| uint64_t n_slices = 1; | |||
| for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { | |||
| n_slices = n_slices * data.layout.shape[i]; | |||
| } | |||
| if (affine) { | |||
| for (size_t i = 0; i < slice_length; ++i) { | |||
| dweight.ptr<T>()[i] = 0; | |||
| dbias.ptr<T>()[i] = 0; | |||
| } | |||
| for (size_t i = 0; i < n_slices; ++i) { | |||
| for (size_t j = 0; j < slice_length; ++j) { | |||
| dweight.ptr<T>()[j] += | |||
| (data.ptr<T>()[i * slice_length + j] - mean.ptr<T_ACC>()[i]) * | |||
| rstd.ptr<T_ACC>()[i] * diff.ptr<T>()[i * slice_length + j]; | |||
| dbias.ptr<T>()[j] += diff.ptr<T>()[i * slice_length + j]; | |||
| } | |||
| } | |||
| } | |||
| for (size_t i = 0; i < n_slices; ++i) { | |||
| T_ACC ds = static_cast<T_ACC>(0.0f); | |||
| T_ACC db = static_cast<T_ACC>(0.0f); | |||
| T_ACC a = static_cast<T_ACC>(0.0f); | |||
| T_ACC b = static_cast<T_ACC>(0.0f); | |||
| T_ACC c = static_cast<T_ACC>(0.0f); | |||
| for (size_t j = 0; j < slice_length; ++j) { | |||
| auto value = data.ptr<T>()[i * slice_length + j]; | |||
| auto diff_v = diff.ptr<T>()[i * slice_length + j]; | |||
| auto weight_v = affine ? weight.ptr<T>()[j] : static_cast<T>(1.0f); | |||
| db += diff_v * weight_v; | |||
| ds += diff_v * value * weight_v; | |||
| } | |||
| a = rstd.ptr<T_ACC>()[i]; | |||
| b = (db * mean.ptr<T_ACC>()[i] - ds) * a * a * a / slice_length; | |||
| c = -b * mean.ptr<T_ACC>()[i] - db * a / slice_length; | |||
| for (uint64_t j = 0; j < slice_length; j++) { | |||
| auto weight_v = affine ? weight.ptr<T>()[j] : static_cast<T>(1.0f); | |||
| ddata.ptr<T>()[i * slice_length + j] = | |||
| diff.ptr<T>()[i * slice_length + j] * a * weight_v + | |||
| data.ptr<T>()[i * slice_length + j] * b + c; | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| namespace megdnn { | |||
| namespace naive { | |||
| void LayerNormForwardImpl::exec( | |||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
| _megdnn_workspace workspace) { | |||
| check_exec( | |||
| data.layout, weight.layout, bias.layout, dst.layout, mean.layout, | |||
| rstd.layout, workspace.size); | |||
| #define cb(DType) \ | |||
| if (data.layout.dtype == DType()) { \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(forward<typename DTypeTrait<DType>::ctype>( \ | |||
| data, weight, bias, dst, mean, rstd, param())); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| void LayerNormBackwardImpl::exec( | |||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
| _megdnn_workspace workspace) { | |||
| check_exec( | |||
| diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, | |||
| ddata.layout, dweight.layout, dbias.layout, workspace.size); | |||
| #define cb(DType) \ | |||
| if (data.layout.dtype == DType()) { \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(backward<typename DTypeTrait<DType>::ctype>( \ | |||
| diff, data, weight, mean, rstd, ddata, dweight, dbias, param())); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * \file dnn/src/naive/layer_norm/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| namespace megdnn { | |||
| namespace naive { | |||
| class LayerNormForwardImpl final : public LayerNormForward { | |||
| public: | |||
| using LayerNormForward::LayerNormForward; | |||
| void exec( | |||
| _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
| _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class LayerNormBackwardImpl final : public LayerNormBackward { | |||
| public: | |||
| using LayerNormBackward::LayerNormBackward; | |||
| void exec( | |||
| _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
| _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
| _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -57,6 +57,15 @@ struct DeduceLayoutProxy<Opr, 5, true> { | |||
| } | |||
| }; | |||
| template <typename Opr> | |||
| struct DeduceLayoutProxy<Opr, 6, true> { | |||
| static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { | |||
| megdnn_assert(layouts.size() == 6); | |||
| opr->deduce_layout( | |||
| layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]); | |||
| } | |||
| }; | |||
| template <typename Opr> | |||
| struct DeduceLayoutProxy<Opr, 5, false> { | |||
| static void deduce_layout(Opr*, TensorLayoutArray&) {} | |||
| @@ -0,0 +1,94 @@ | |||
| /** | |||
| * \file dnn/test/cuda/layer_norm.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/common/checker.h" | |||
| namespace megdnn { | |||
| namespace test { | |||
| TEST_F(CUDA, LAYERNORM_FORWARD) { | |||
| using Param = LayerNormForward::Param; | |||
| Param param; | |||
| param.affine = true; | |||
| param.eps = 1e-6; | |||
| param.normalized_dim = 1; | |||
| Checker<LayerNormForward> checker(handle_cuda()); | |||
| checker.set_epsilon(1e-2); | |||
| auto run = [&](DType d) { | |||
| for (size_t n_slices : {10, 30}) | |||
| for (size_t slice_len : {10, 30}) { | |||
| param.normalized_size = slice_len; | |||
| checker.set_param(param) | |||
| .set_dtype(0, d) | |||
| .set_dtype(1, d) | |||
| .set_dtype(2, d) | |||
| .set_dtype(3, d) | |||
| .set_dtype(4, dtype::Float32()) | |||
| .set_dtype(5, dtype::Float32()) | |||
| .execs({{n_slices, slice_len}, | |||
| {slice_len}, | |||
| {slice_len}, | |||
| {n_slices, slice_len}, | |||
| {n_slices}, | |||
| {n_slices}}); | |||
| } | |||
| }; | |||
| run(dtype::Float32()); | |||
| run(dtype::Float16()); | |||
| run(dtype::BFloat16()); | |||
| } | |||
| TEST_F(CUDA, LAYERNORM_BACKWARD) { | |||
| using Param = LayerNormBackward::Param; | |||
| Param param; | |||
| param.affine = true; | |||
| param.eps = 1e-6; | |||
| param.normalized_dim = 1; | |||
| Checker<LayerNormBackward> checker(handle_cuda()); | |||
| checker.set_epsilon(1e-1); | |||
| auto run = [&](DType d) { | |||
| for (size_t n_slices : {10, 30}) | |||
| for (size_t slice_len : {10, 30}) { | |||
| param.normalized_size = slice_len; | |||
| checker.set_param(param) | |||
| .set_dtype(0, d) | |||
| .set_dtype(1, d) | |||
| .set_dtype(2, d) | |||
| .set_dtype(3, dtype::Float32()) | |||
| .set_dtype(4, dtype::Float32()) | |||
| .set_dtype(5, d) | |||
| .set_dtype(6, d) | |||
| .set_dtype(7, d) | |||
| .execs({{n_slices, slice_len}, | |||
| {n_slices, slice_len}, | |||
| {slice_len}, | |||
| {n_slices}, | |||
| {n_slices}, | |||
| {n_slices, slice_len}, | |||
| {slice_len}, | |||
| {slice_len}}); | |||
| } | |||
| }; | |||
| run(dtype::Float32()); | |||
| run(dtype::Float16()); | |||
| run(dtype::BFloat16()); | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1066,57 +1066,6 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: | |||
| return cached / down | |||
| @lru_cache(maxsize=None) | |||
| def _get_layerNorm(device, dtype, dim, gopt_level=2): | |||
| @subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level) | |||
| def layerNormAffine(inputs, f, c): | |||
| inp, eps, _flatten_shape, weight, bias = inputs | |||
| inp_shape = f(GetVarShape(), inp) | |||
| inp = f(Reshape(axis=dim), inp, _flatten_shape) | |||
| mean = f(Reduce(mode="mean", axis=-1), inp) | |||
| x2s = f(Reduce(mode="sum_sqr", axis=-1), inp) | |||
| reduce_shape = f(GetVarShape(), x2s) | |||
| reduce_size = f( | |||
| "//", | |||
| f(Reduce(mode="product", axis=0), inp_shape), | |||
| f(Reduce(mode="product", axis=0), reduce_shape), | |||
| ) | |||
| reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) | |||
| var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2))) | |||
| inv_sqrt_var = f("**", f("+", var, eps), c(-0.5)) | |||
| oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var)) | |||
| affine_oup = f(Reshape(), oup, inp_shape) | |||
| affine_oup = f("fma3", affine_oup, weight, bias) | |||
| # NOTE: return oup make backward faster but take more memory | |||
| return (affine_oup, oup, mean, x2s), (True, False, False, False) | |||
| @subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level) | |||
| def layerNorm(inputs, f, c): | |||
| inp, eps, _flatten_shape = inputs | |||
| inp_shape = f(GetVarShape(), inp) | |||
| inp = f(Reshape(axis=dim), inp, _flatten_shape) | |||
| mean = f(Reduce(mode="mean", axis=-1), inp) | |||
| x2s = f(Reduce(mode="sum_sqr", axis=-1), inp) | |||
| reduce_shape = f(GetVarShape(), x2s) | |||
| reduce_size = f( | |||
| "//", | |||
| f(Reduce(mode="product", axis=0), inp_shape), | |||
| f(Reduce(mode="product", axis=0), reduce_shape), | |||
| ) | |||
| reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) | |||
| var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2))) | |||
| inv_sqrt_var = f("**", f("+", var, eps), c(-0.5)) | |||
| oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var)) | |||
| oup = f(Reshape(), oup, inp_shape) | |||
| return (oup,), (True,) | |||
| return (layerNorm, layerNormAffine) | |||
| def layer_norm( | |||
| inp: Tensor, | |||
| normalized_shape: tuple, | |||
| @@ -1133,32 +1082,34 @@ def layer_norm( | |||
| normalized_shape: the shape that you want to be normalizated | |||
| affine: whether to use weight and bias | |||
| weight: must not be None when the affine is true | |||
| bias: must not be None when the bias is true | |||
| bias: must not be None when the affine is true | |||
| eps: a value added to the denominator for numerical stability. Default: 1e-5 | |||
| """ | |||
| if amp._enabled: | |||
| inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) | |||
| _device = inp.device | |||
| _dtype = inp.dtype | |||
| _dim = len(inp.shape) - len(normalized_shape) | |||
| if isinstance(normalized_shape, int): | |||
| normalized_shape = [normalized_shape] | |||
| _flatten_shape = concat( | |||
| ( | |||
| convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device), | |||
| convert_single_value(-1, dtype="int32", device=inp.device), | |||
| ) | |||
| ) | |||
| (layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim) | |||
| normalized_dim = len(normalized_shape) | |||
| assert normalized_dim > 0 | |||
| eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device) | |||
| normalized_size = 1 | |||
| for i in range(normalized_dim): | |||
| normalized_size = normalized_size * normalized_shape[i] | |||
| op = builtin.LayerNorm( | |||
| affine=affine, | |||
| eps=eps, | |||
| normalized_dim=normalized_dim, | |||
| normalized_size=normalized_size, | |||
| ) | |||
| if affine: | |||
| outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias) | |||
| assert weight is not None and bias is not None | |||
| return apply(op, inp, weight, bias)[0] | |||
| else: | |||
| outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape) | |||
| return outvar | |||
| # assert weight is None and bias is None | |||
| return apply(op, inp)[0] | |||
| def batch_norm( | |||
| @@ -865,61 +865,6 @@ def test_conv1d(): | |||
| ) | |||
| def test_layer_norm(): | |||
| def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5): | |||
| __layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine) | |||
| __layer_norm.weight = weight | |||
| __layer_norm.bias = bias | |||
| return __layer_norm(x) | |||
| def _layer_norm_numpy( | |||
| x, normalized_shape, affine, weight=None, bias=None, eps=1e-5 | |||
| ): | |||
| x_shape = x.shape | |||
| dim_delta = len(x_shape) - len(normalized_shape) | |||
| non_flatten_shape = x_shape[:dim_delta] | |||
| x = x.reshape(*non_flatten_shape, -1) | |||
| mean = x.mean(axis=-1, keepdims=True) | |||
| var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean | |||
| x = (x - mean) / F.sqrt(var + eps) | |||
| x = x.reshape(x_shape) | |||
| if affine: | |||
| x = weight * x + bias | |||
| return x | |||
| normalized_shape = (28, 28) | |||
| inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32") | |||
| weight = Tensor(np.random.randn(28, 28), dtype="float32") | |||
| bias = Tensor(np.random.randn(28, 28), dtype="float32") | |||
| inp_feat = inp_feat + 1 | |||
| weight = weight + 1 | |||
| bias = bias | |||
| affine = False | |||
| outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias) | |||
| targetvar = _layer_norm_numpy(inp_feat, normalized_shape, affine, weight, bias) | |||
| assert abs(outvar - targetvar).mean() < 1e-7 | |||
| # no random, affine True | |||
| normalized_shape = (28, 28) | |||
| inp_feat = Tensor(np.ones((32, 64, 28, 28)), dtype="float32") | |||
| weight = Tensor(np.ones((28, 28)), dtype="float32") | |||
| bias = Tensor(np.zeros((28, 28)), dtype="float32") | |||
| affine = True | |||
| outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias) | |||
| targetvar = _layer_norm(inp_feat, normalized_shape, affine, weight, bias) | |||
| assert abs((outvar - targetvar).mean()) < 1e-7 | |||
| assert abs(outvar.mean()) < 1e-7 | |||
| def test_batchnorm2d_autocast(): | |||
| """check amp's result is equal to manually converted result""" | |||
| amp.enabled = True | |||
| @@ -43,7 +43,7 @@ def test_cross_entropy(): | |||
| x = softmax(x) | |||
| l_ref = ref(x, y) | |||
| l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False) | |||
| np.testing.assert_allclose(l.numpy(), l_ref) | |||
| np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6) | |||
| def test_cross_entropy_reduction(): | |||
| @@ -20,6 +20,7 @@ | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/layer_norm.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/lrn.h" | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| @@ -636,4 +637,29 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| } | |||
| OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace lrn | |||
| namespace layer_norm { | |||
| cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const LayerNorm&>(def); | |||
| size_t nr_inp = inputs.size(); | |||
| auto p = op.param(); | |||
| mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine)); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| if (nr_inp == 3) { | |||
| return opr::LayerNorm::make( | |||
| inputs[0], inputs[1], inputs[2], op.param(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } else { | |||
| return opr::LayerNorm::make(inputs[0], op.param(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace layer_norm | |||
| } // namespace mgb::imperative | |||
| @@ -431,4 +431,6 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | |||
| def LRN: MgbHashableOp<"LRN", [LRNParam]>; | |||
| def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | |||
| #endif // MGB_OPS | |||
| @@ -16,6 +16,7 @@ | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/layer_norm.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/lrn.h" | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| @@ -420,6 +421,47 @@ struct OprMaker<opr::BatchNormBackward, 6> { | |||
| } | |||
| }; | |||
| template <> | |||
| struct OprMaker<opr::LayerNorm, 0> { | |||
| using Param = opr::LayerNorm::Param; | |||
| static cg::OperatorNodeBase* make( | |||
| const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||
| const OperatorNodeConfig& config) { | |||
| MGB_MARK_USED_VAR(graph); | |||
| if (i.size() == 3) { | |||
| return opr::LayerNorm::make(i[0], i[1], i[2], param, config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } else { | |||
| mgb_assert(i.size() == 1); | |||
| return opr::LayerNorm::make(i[0], param, config)[0].node()->owner_opr(); | |||
| } | |||
| } | |||
| }; | |||
| // OprMaker in MGB_SEREG_OPR only support unique output opr | |||
| template <> | |||
| struct OprMaker<opr::LayerNormBackward, 0> { | |||
| using Param = opr::LayerNormBackward::Param; | |||
| static cg::OperatorNodeBase* make( | |||
| const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||
| const OperatorNodeConfig& config) { | |||
| MGB_MARK_USED_VAR(graph); | |||
| if (i.size() == 5) { | |||
| return opr::LayerNormBackward::make( | |||
| i[0], i[1], i[2], i[3], i[4], param, config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } else { | |||
| mgb_assert(i.size() == 4); | |||
| return opr::LayerNormBackward::make( | |||
| i[0], i[1], i[2], i[3], param, config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| } | |||
| }; | |||
| template <class MegDNNConv = megdnn::LocalShare> | |||
| struct MakeLocalShareCaller2 { | |||
| template <typename Opr> | |||
| @@ -641,6 +683,8 @@ MGB_SEREG_OPR(TQT, 2); | |||
| MGB_SEREG_OPR(TQTBackward, 3); | |||
| MGB_SEREG_OPR(LSQ, 4); | |||
| MGB_SEREG_OPR(LSQBackward, 5); | |||
| MGB_SEREG_OPR(LayerNorm, 0); | |||
| MGB_SEREG_OPR(LayerNormBackward, 0); | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| @@ -0,0 +1,248 @@ | |||
| /** | |||
| * \file src/opr/impl/dnn/layer_norm.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/dnn/layer_norm.h" | |||
| #include "megbrain/graph/grad_impl.h" | |||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "../internal/megdnn_opr_wrapper.inl" | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| /* ==================== LayerNormForward ==================== */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormForward); | |||
| LayerNormForward::LayerNormForward( | |||
| VarNode* data, VarNode* weight, VarNode* bias, const Param& param, | |||
| const OperatorNodeConfig& config) | |||
| : Super{data->owner_graph(), config, "layer_norm", {data, weight, bias}} { | |||
| init_megdnn_opr(*this, param); | |||
| add_input({data, weight, bias}); | |||
| output(0)->dtype(data->dtype()); | |||
| output(1)->dtype(dtype::Float32()); | |||
| output(2)->dtype(dtype::Float32()); | |||
| } | |||
| LayerNormForward::LayerNormForward( | |||
| VarNode* data, const Param& param, const OperatorNodeConfig& config) | |||
| : Super{data->owner_graph(), config, "layer_norm", {data}} { | |||
| init_megdnn_opr(*this, param); | |||
| add_input({data}); | |||
| output(0)->dtype(data->dtype()); | |||
| output(1)->dtype(dtype::Float32()); | |||
| output(2)->dtype(dtype::Float32()); | |||
| } | |||
| SymbolVarArray LayerNormForward::make( | |||
| SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param, | |||
| const OperatorNodeConfig& config) { | |||
| auto outs = data.node() | |||
| ->owner_graph() | |||
| ->insert_opr(std::make_unique<LayerNormForward>( | |||
| data.node(), weight.node(), bias.node(), param, config)) | |||
| ->output(); | |||
| SymbolVarArray ret; | |||
| for (auto&& out : outs) { | |||
| ret.emplace_back(out); | |||
| } | |||
| return ret; | |||
| } | |||
| SymbolVarArray LayerNormForward::make( | |||
| SymbolVar data, const Param& param, const OperatorNodeConfig& config) { | |||
| auto outs = data.node() | |||
| ->owner_graph() | |||
| ->insert_opr(std::make_unique<LayerNormForward>( | |||
| data.node(), param, config)) | |||
| ->output(); | |||
| SymbolVarArray ret; | |||
| for (auto&& out : outs) { | |||
| ret.emplace_back(out); | |||
| } | |||
| return ret; | |||
| } | |||
| void LayerNormForward::get_output_var_shape( | |||
| const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { | |||
| uint64_t normalized_dim = param().normalized_dim; | |||
| out_shape[0] = inp_shape[0]; | |||
| TensorShape unnormalized_shape; | |||
| unnormalized_shape.ndim = inp_shape[0].ndim - normalized_dim; | |||
| for (size_t i = 0; i < unnormalized_shape.ndim; ++i) { | |||
| unnormalized_shape.shape[i] = inp_shape[0].shape[i]; | |||
| } | |||
| out_shape[1] = unnormalized_shape; | |||
| out_shape[2] = unnormalized_shape; | |||
| } | |||
| size_t LayerNormForward::get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) const { | |||
| return 0; | |||
| } | |||
| void LayerNormForward::scn_do_execute() { | |||
| if (param().affine) { | |||
| megdnn_opr()->exec( | |||
| input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||
| input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||
| output(1)->dev_tensor().as_megdnn(), | |||
| output(2)->dev_tensor().as_megdnn(), {}); | |||
| } else { | |||
| megdnn_opr()->exec( | |||
| input(0)->dev_tensor().as_megdnn(), {}, {}, | |||
| output(0)->dev_tensor().as_megdnn(), | |||
| output(1)->dev_tensor().as_megdnn(), | |||
| output(2)->dev_tensor().as_megdnn(), {}); | |||
| } | |||
| } | |||
| #if MGB_ENABLE_GRAD | |||
| MGB_IMPL_OPR_GRAD(LayerNormForward) { | |||
| auto p = opr.param(); | |||
| SymbolVarArray grad; | |||
| VarNodeArray ret; | |||
| if (p.affine) { | |||
| mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx); | |||
| grad = LayerNormBackward::make( | |||
| out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2), | |||
| opr.param()); | |||
| } else { | |||
| mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx); | |||
| grad = LayerNormBackward::make( | |||
| out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param()); | |||
| } | |||
| uint32_t nr_ret = p.affine ? 3 : 1; | |||
| for (uint32_t i = 0; i < nr_ret; ++i) { | |||
| ret.push_back(grad[i].node()); | |||
| } | |||
| return ret; | |||
| } | |||
| #endif | |||
| /* ==================== LayerNormBackward ==================== */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormBackward); | |||
| LayerNormBackward::LayerNormBackward( | |||
| VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, | |||
| const Param& param, const OperatorNodeConfig& config) | |||
| : Super({diff->owner_graph(), | |||
| config, | |||
| "layer_norm_backward", | |||
| {diff, data, weight, mean, rstd}}, | |||
| 0, true) { | |||
| init_megdnn_opr(*this, param); | |||
| add_input({diff, data, weight, mean, rstd}); | |||
| } | |||
| LayerNormBackward::LayerNormBackward( | |||
| VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param, | |||
| const OperatorNodeConfig& config) | |||
| : Super({diff->owner_graph(), | |||
| config, | |||
| "layer_norm_backward", | |||
| {diff, data, mean, rstd}}, | |||
| 0, true) { | |||
| init_megdnn_opr(*this, param); | |||
| add_input({diff, data, mean, rstd}); | |||
| auto mark_empty_var = [&](VarNode* var) { | |||
| var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) | |||
| .add_flag(VarNode::Flag::VOLATILE_CONTENT); | |||
| }; | |||
| mark_empty_var(output(1)); | |||
| mark_empty_var(output(2)); | |||
| } | |||
| SymbolVarArray LayerNormBackward::make( | |||
| SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, | |||
| SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) { | |||
| auto outs = diff.node() | |||
| ->owner_graph() | |||
| ->insert_opr(std::make_unique<LayerNormBackward>( | |||
| diff.node(), data.node(), weight.node(), mean.node(), | |||
| rstd.node(), param, config)) | |||
| ->output(); | |||
| SymbolVarArray ret; | |||
| for (auto&& out : outs) { | |||
| ret.emplace_back(out); | |||
| } | |||
| return ret; | |||
| } | |||
| SymbolVarArray LayerNormBackward::make( | |||
| SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, | |||
| const Param& param, const OperatorNodeConfig& config) { | |||
| auto outs = diff.node() | |||
| ->owner_graph() | |||
| ->insert_opr(std::make_unique<LayerNormBackward>( | |||
| diff.node(), data.node(), mean.node(), rstd.node(), | |||
| param, config)) | |||
| ->output(); | |||
| SymbolVarArray ret; | |||
| for (auto&& out : outs) { | |||
| ret.emplace_back(out); | |||
| } | |||
| return ret; | |||
| } | |||
| void LayerNormBackward::init_output_static_infer_desc() { | |||
| using namespace cg::static_infer; | |||
| auto&& mgr = owner_graph()->static_infer_manager(); | |||
| mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1))); | |||
| if (param().affine) { | |||
| mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); | |||
| mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2))); | |||
| } else { | |||
| TensorShape empty; | |||
| empty.ndim = 0; | |||
| mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty)); | |||
| mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty)); | |||
| } | |||
| this->init_output_static_infer_desc_workspace(false); | |||
| } | |||
| void LayerNormBackward::init_output_dtype() { | |||
| output(0)->dtype(input(1)->dtype()); | |||
| output(1)->dtype(input(2)->dtype()); | |||
| output(2)->dtype(input(2)->dtype()); | |||
| } | |||
| size_t LayerNormBackward::get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) const { | |||
| return 0; | |||
| } | |||
| void LayerNormBackward::scn_do_execute() { | |||
| if (param().affine) { | |||
| megdnn_opr()->exec( | |||
| input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||
| input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), | |||
| input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||
| output(1)->dev_tensor().as_megdnn(), | |||
| output(2)->dev_tensor().as_megdnn(), {}); | |||
| } else { | |||
| megdnn_opr()->exec( | |||
| input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||
| {}, input(2)->dev_tensor().as_megdnn(), | |||
| input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||
| {}, {}, {}); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * \file src/opr/include/megbrain/opr/dnn/layer_norm.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megdnn/oprs.h" | |||
| namespace mgb { | |||
| namespace opr { | |||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
| LayerNormForward, intl::MegDNNOprWrapperFwd<megdnn::LayerNormForward>) // { | |||
| public: | |||
| MGE_WIN_DECLSPEC_FUC LayerNormForward( | |||
| VarNode* data, VarNode* weight, VarNode* bias, const Param& param, | |||
| const OperatorNodeConfig& config); | |||
| MGE_WIN_DECLSPEC_FUC LayerNormForward( | |||
| VarNode* data, const Param& param, const OperatorNodeConfig& config); | |||
| MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||
| SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||
| SymbolVar data, const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| private: | |||
| void get_output_var_shape( | |||
| const TensorShapeArray& inp_shape, | |||
| TensorShapeArray& out_shape) const override; | |||
| size_t get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) const override; | |||
| void scn_do_execute() override; | |||
| }; | |||
| using LayerNorm = LayerNormForward; | |||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
| LayerNormBackward, intl::MegDNNOprWrapperBwd<megdnn::LayerNormBackward>) // { | |||
| public: | |||
| MGE_WIN_DECLSPEC_FUC LayerNormBackward( | |||
| VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, | |||
| const Param& param, const OperatorNodeConfig& config); | |||
| MGE_WIN_DECLSPEC_FUC LayerNormBackward( | |||
| VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, | |||
| const Param& param, const OperatorNodeConfig& config); | |||
| MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||
| SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, | |||
| SymbolVar rstd, const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||
| SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, | |||
| const Param& param = {}, const OperatorNodeConfig& config = {}); | |||
| private: | |||
| void init_output_static_infer_desc() override; | |||
| void init_output_dtype() override; | |||
| size_t get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) const override; | |||
| void scn_do_execute() override; | |||
| }; | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,108 @@ | |||
| /** | |||
| * \file src/opr/test/dnn/layer_norm.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/dnn/layer_norm.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/test/autocheck.h" | |||
| #include "megbrain/test/helper.h" | |||
| #include "megbrain/test/megdnn_helper.h" | |||
| #include "megdnn/oprs.h" | |||
| #include <cmath> | |||
| #include <iomanip> | |||
| #include <random> | |||
| #include <sstream> | |||
| using namespace mgb; | |||
| namespace { | |||
| using Param = opr::LayerNormForward::Param; | |||
| void run_forward(bool is_affine, size_t normalized_size) { | |||
| using Checker = AutoOprChecker<3, 3>; | |||
| Param param; | |||
| param.eps = 1e-5; | |||
| param.affine = is_affine; | |||
| param.normalized_dim = 1; | |||
| param.normalized_size = normalized_size; | |||
| auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||
| auto out = opr::LayerNormForward::make(inputs[0], inputs[1], inputs[2], param); | |||
| return {out[0], out[1], out[2]}; | |||
| }; | |||
| auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||
| auto opr = | |||
| MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||
| ->create_operator<megdnn::LayerNormForward>(); | |||
| auto inp_shape = inp[0]->shape(); | |||
| auto n_slices = inp_shape[0]; | |||
| auto slice_len = inp_shape[1]; | |||
| opr->param() = param; | |||
| dest[0].dtype(dtype::Float32()) | |||
| .comp_node(inp[0]->comp_node()) | |||
| .resize({n_slices, slice_len}); | |||
| dest[1].dtype(dtype::Float32()) | |||
| .comp_node(inp[0]->comp_node()) | |||
| .resize({n_slices}); | |||
| dest[2].dtype(dtype::Float32()) | |||
| .comp_node(inp[0]->comp_node()) | |||
| .resize({n_slices}); | |||
| opr->exec( | |||
| inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||
| dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(), {}); | |||
| }; | |||
| auto gen = [&](HostTensorND& src) { | |||
| HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f); | |||
| src = *src_gen(src.shape(), src.comp_node()); | |||
| }; | |||
| Checker::RunOptions option; | |||
| option.numdiff_max_err = 1e-4; | |||
| Checker checker{make_graph, fwd}; | |||
| checker.set_input_generator(0, gen); | |||
| checker.set_input_generator(1, gen); | |||
| checker.set_input_generator(2, gen); | |||
| checker.set_input_allow_grad(0, false); | |||
| checker.set_input_allow_grad(1, false); | |||
| checker.set_input_allow_grad(2, false); | |||
| checker.set_output_allow_grad(0, false); | |||
| checker.set_output_allow_grad(1, false); | |||
| checker.set_output_allow_grad(2, false); | |||
| checker.run({TensorShape{normalized_size, normalized_size}, | |||
| TensorShape{normalized_size}, TensorShape{normalized_size}}, | |||
| option) | |||
| .run({TensorShape{normalized_size, normalized_size}, | |||
| TensorShape{normalized_size}, TensorShape{normalized_size}}, | |||
| option) | |||
| .run({TensorShape{normalized_size, normalized_size}, | |||
| TensorShape{normalized_size}, TensorShape{normalized_size}}, | |||
| option); | |||
| } | |||
| TEST(TestOprDNN, LayerNormForwardAffine) { | |||
| REQUIRE_GPU(1); | |||
| run_forward(true, 1); | |||
| run_forward(true, 16); | |||
| run_forward(true, 17); | |||
| } | |||
| } // anonymous namespace | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -116,6 +116,7 @@ union OperatorParam { | |||
| param.Padding = 82, | |||
| param.ShuffleRNG = 83, | |||
| param.CheckNonFinite = 84, | |||
| param.LayerNorm = 85, | |||
| } | |||
| table Operator { | |||