
feat(mgb/opr): add layernorm forward and backward kernel

GitOrigin-RevId: 0cd484e753
tags/v1.8.0
Megvii Engine Team · committed 4 years ago · parent commit a93741815b
25 changed files with 1960 additions and 125 deletions
  1. +69 -0    dnn/include/megdnn/oprs/nn.h
  2. +7 -0     dnn/scripts/opr_param_defs.py
  3. +4 -1     dnn/src/common/handle_impl.h
  4. +180 -0   dnn/src/common/layer_norm.cpp
  5. +2 -0     dnn/src/common/opr_trait.h
  6. +1 -0     dnn/src/cuda/handle_create.cpp
  7. +664 -0   dnn/src/cuda/layer_norm/layer_norm_cuda.cu
  8. +34 -0    dnn/src/cuda/layer_norm/layer_norm_cuda.cuh
  9. +94 -0    dnn/src/cuda/layer_norm/opr_impl.cpp
  10. +53 -0   dnn/src/cuda/layer_norm/opr_impl.h
  11. +1 -0    dnn/src/naive/handle.cpp
  12. +170 -0  dnn/src/naive/layer_norm/opr_impl.cpp
  13. +51 -0   dnn/src/naive/layer_norm/opr_impl.h
  14. +9 -0    dnn/test/common/deduce_layout_proxy.h
  15. +94 -0   dnn/test/cuda/layer_norm.cpp
  16. +19 -68  imperative/python/megengine/functional/nn.py
  17. +0 -55   imperative/python/test/unit/functional/test_functional.py
  18. +1 -1    imperative/python/test/unit/functional/test_loss.py
  19. +26 -0   imperative/src/impl/ops/specializations.cpp
  20. +2 -0    src/core/include/megbrain/ir/ops.td
  21. +44 -0   src/opr/impl/dnn/dnn.sereg.h
  22. +248 -0  src/opr/impl/dnn/layer_norm.cpp
  23. +78 -0   src/opr/include/megbrain/opr/dnn/layer_norm.h
  24. +108 -0  src/opr/test/dnn/layer_norm.cpp
  25. +1 -0    src/serialization/impl/schema.fbs

+69 -0  dnn/include/megdnn/oprs/nn.h

@@ -1936,6 +1936,75 @@ protected:
const TensorLayout& grad_s, size_t workspace_in_bytes);
};

class LayerNormBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase);
DEF_OPR_PARAM(LayerNorm);

protected:
void deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
TensorLayout& rstd);
void check_layout_fwd(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd);
};

class LayerNormForward : public LayerNormBase {
DEF_OPR_IMPL(LayerNormForward, LayerNormBase, 3, 3);

public:
virtual void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
TensorLayout& rstd);
virtual size_t get_workspace_in_bytes(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd) = 0;

protected:
void check_exec(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd, size_t workspace_in_bytes);
};
using LayerNorm = LayerNormForward;

class LayerNormBackward : public LayerNormBase {
DEF_OPR_IMPL(LayerNormBackward, LayerNormBase, 5, 3);

public:
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight,
TensorLayout& dbias);
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias) = 0;

protected:
void check_exec(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias,
size_t workspace_in_bytes);
};

} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"



+7 -0  dnn/scripts/opr_param_defs.py

@@ -1212,3 +1212,10 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES]
)
)

(pdef('LayerNorm')
.add_fields('bool', 'affine', 'true')
.add_fields('float32', 'eps', '1e-5f')
.add_fields('uint64', 'normalized_dim', '1')
.add_fields('uint64', 'normalized_size', '1')
)
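For orientation, a minimal NumPy sketch (illustration only, not part of this commit) of how these fields are interpreted by the kernels below: `normalized_dim` counts the trailing axes being normalized, `normalized_size` is their flattened length, and `eps`/`affine` follow the usual LayerNorm definition. The function name and defaults here are assumptions.

import numpy as np

def layer_norm_ref(x, weight=None, bias=None, normalized_dim=1, eps=1e-5, affine=True):
    outer = x.shape[: x.ndim - normalized_dim]
    xs = x.reshape(outer + (-1,))             # (..., normalized_size)
    mean = xs.mean(axis=-1, keepdims=True)
    rstd = 1.0 / np.sqrt(xs.var(axis=-1, keepdims=True) + eps)
    y = ((xs - mean) * rstd).reshape(x.shape)
    if affine:
        y = y * weight + bias                 # weight/bias broadcast over the normalized axes
    return y, mean.reshape(outer), rstd.reshape(outer)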

+4 -1  dnn/src/common/handle_impl.h

@@ -209,7 +209,10 @@ private:
cb(LSQBackward) \
cb(Fill) \
cb(PaddingForward) \
cb(PaddingBackward)
cb(PaddingBackward) \
cb(LayerNormForward) \
cb(LayerNormBackward)

// clang-format on

/*!


+180 -0  dnn/src/common/layer_norm.cpp

@@ -0,0 +1,180 @@
/**
* \file dnn/src/common/layer_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"

#include "src/common/utils.h"

namespace megdnn {

void LayerNormBase::deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
MEGDNN_MARK_USED_VAR(weight);
MEGDNN_MARK_USED_VAR(bias);
auto p = param();
TensorShape unnormalized_shape;
unnormalized_shape.ndim = data.ndim - p.normalized_dim;
for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
unnormalized_shape.shape[i] = data.shape[i];
}
TensorLayout unnormalized_layout =
TensorLayout(unnormalized_shape, dtype::Float32());
dst = data;
mean = unnormalized_layout;
rstd = unnormalized_layout;
}

void LayerNormBase::check_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) {
megdnn_assert_contiguous(data);
megdnn_assert_contiguous(weight);
megdnn_assert_contiguous(bias);
megdnn_assert_contiguous(dst);
megdnn_assert_contiguous(mean);
megdnn_assert_contiguous(rstd);
auto errmsg = [&]() {
return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " +
megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " +
megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd);
};
MEGDNN_MARK_USED_VAR(errmsg);

auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool {
if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype &&
lhs.format == rhs.format))
return false;
for (size_t i = 0; i < lhs.ndim; ++i) {
if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) {
return false;
}
}
return true;
};

megdnn_assert(equal_layout(data, dst), "%s", errmsg().c_str());
megdnn_assert(equal_layout(weight, bias), "%s", errmsg().c_str());
megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str());

auto p = param();
uint64_t normalized_dim = p.normalized_dim;
size_t unnormalized_dim = data.ndim - normalized_dim;
megdnn_assert(
normalized_dim < data.ndim,
"the dims of normalized shape should smaller than input dims");

for (size_t i = 0; i < unnormalized_dim; ++i) {
megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str());
}
if (p.affine) {
for (size_t i = 0; i < normalized_dim; ++i) {
megdnn_assert(
data.shape[unnormalized_dim + i] == weight.shape[i], "%s",
errmsg().c_str());
}
}
}

void LayerNormForward::deduce_layout(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
deduce_layout_fwd(data, weight, bias, dst, mean, rstd);
}

void LayerNormForward::check_exec(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd,
size_t workspace_in_bytes) {
check_layout_fwd(data, weight, bias, dst, mean, rstd);
auto required_workspace_in_bytes =
get_workspace_in_bytes(data, weight, bias, dst, mean, rstd);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void LayerNormBackward::deduce_layout(
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata,
TensorLayout& dweight, TensorLayout& dbias) {
MEGDNN_MARK_USED_VAR(diff);
MEGDNN_MARK_USED_VAR(mean);
MEGDNN_MARK_USED_VAR(rstd);
ddata = data;
dweight = weight;
dbias = weight;
}

void LayerNormBackward::check_exec(
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias,
size_t workspace_in_bytes) {
auto p = param();
auto required_workspace_in_bytes = get_workspace_in_bytes(
diff, data, weight, mean, rstd, ddata, dweight, dbias);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);

megdnn_assert_contiguous(diff);
megdnn_assert_contiguous(data);
megdnn_assert_contiguous(mean);
megdnn_assert_contiguous(rstd);
megdnn_assert_contiguous(ddata);
if (p.affine) {
megdnn_assert_contiguous(weight);
megdnn_assert_contiguous(dweight);
megdnn_assert_contiguous(dbias);
}

auto errmsg = [&]() {
return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " +
megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " +
megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " +
megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias);
};
MEGDNN_MARK_USED_VAR(errmsg);

auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool {
if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype &&
lhs.format == rhs.format))
return false;
for (size_t i = 0; i < lhs.ndim; ++i) {
if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) {
return false;
}
}
return true;
};

megdnn_assert(equal_layout(data, ddata), "%s", errmsg().c_str());
megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str());
if (p.affine) {
megdnn_assert(equal_layout(weight, dweight), "%s", errmsg().c_str());
megdnn_assert(equal_layout(weight, dbias), "%s", errmsg().c_str());
}

size_t normalized_dim = p.normalized_dim;
size_t unnormalized_dim = data.ndim - normalized_dim;

for (size_t i = 0; i < unnormalized_dim; ++i) {
megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str());
}
if (p.affine) {
for (size_t i = 0; i < normalized_dim; ++i) {
megdnn_assert(
data.shape[unnormalized_dim + i] == weight.shape[i], "%s",
errmsg().c_str());
}
}
}

} // namespace megdnn

// vim: syntax=cpp.doxygen
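A quick shape example of the layout deduction above (hedged, with assumed sizes): mean and rstd drop the trailing normalized_dim axes and are always float32, dst mirrors data, and weight/bias cover only the normalized axes.

# hypothetical shapes illustrating deduce_layout_fwd / check_layout_fwd
data_shape = (32, 64, 28, 28)
normalized_dim = 2
unnormalized = data_shape[: len(data_shape) - normalized_dim]
assert unnormalized == (32, 64)          # deduced layout of mean and rstd (float32)
dst_shape = data_shape                   # dst inherits the data layout
weight_shape = data_shape[len(data_shape) - normalized_dim:]   # (28, 28), checked only when affine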

+2 -0  dnn/src/common/opr_trait.h

@@ -135,6 +135,8 @@ DEF(CheckNonFinite, 2, true, true);
DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false);
DEF(Fill, 1, true, false);
DEF(LayerNormForward, 6, true, true);
DEF(LayerNormBackward, 8, true, true);
} // namespace megdnn

// vim: syntax=cpp.doxygen

+1 -0  dnn/src/cuda/handle_create.cpp

@@ -45,6 +45,7 @@
#include "src/cuda/images2neibs/opr_impl.h"
#include "src/cuda/indexing_multi_axis_vec/opr_impl.h"
#include "src/cuda/indexing_one_hot/opr_impl.h"
#include "src/cuda/layer_norm/opr_impl.h"
#include "src/cuda/linspace/opr_impl.h"
#include "src/cuda/local/opr_impl.h"
#include "src/cuda/local_share/opr_impl.h"


+664 -0  dnn/src/cuda/layer_norm/layer_norm_cuda.cu

@@ -0,0 +1,664 @@
/**
* \file dnn/src/cuda/layer_norm/layer_norm_cuda.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <thrust/pair.h>
#include <thrust/tuple.h>
#include <cfloat>
#include "megdnn/arch.h"
#include "megdnn/dtype.h"
#include "src/cuda/cuda_shfl_compat.cuh"
#include "src/cuda/layer_norm/layer_norm_cuda.cuh"
#include "src/cuda/utils.cuh"

namespace megdnn {
namespace cuda {
namespace layer_norm {

constexpr int kCUDANumThreads = 256;
constexpr int vec_size = 4;

// warp size may be used as array length, or used in host function,
// so we define WARP_SIZE rather than using warpSize
#define WARP_SIZE 32

#if defined(__clang__)
#define __ubsan_ignore_float_divide_by_zero__ \
__attribute__((no_sanitize("float-divide-by-zero")))
#else
#define __ubsan_ignore_float_divide_by_zero__
#endif

struct WelfordStat {
float mean;
float sigma2;
float count;
MEGDNN_HOST MEGDNN_DEVICE WelfordStat() : mean(0.f), sigma2(0.f), count(0.f) {}
MEGDNN_HOST MEGDNN_DEVICE WelfordStat(float mean, float sigma2, float count)
: mean(mean), sigma2(sigma2), count(count) {}
};

template <typename T, typename combine_t>
struct WelfordData {
T mean;
T sigma2;
combine_t count;

MEGDNN_HOST MEGDNN_DEVICE WelfordData() : mean(0), sigma2(0), count(0) {}

MEGDNN_HOST MEGDNN_DEVICE WelfordData(T mean, T sigma2, combine_t count)
: mean(mean), sigma2(sigma2), count(count) {}
};

template <typename T, typename combine_t, typename res_t>
struct WelfordOps {
public:
using WelfordData_T = WelfordData<T, combine_t>;
inline MEGDNN_DEVICE WelfordData_T reduce(WelfordData_T acc, T data) const {
T delta = data - acc.mean;
T new_mean = static_cast<T>(acc.mean + delta / (acc.count + 1));
T new_delta = static_cast<T>(data - new_mean);
return {
new_mean,
acc.sigma2 + delta * new_delta,
combine_t(acc.count + 1),
};
}
inline MEGDNN_DEVICE WelfordData_T
combine(WelfordData_T lhs, WelfordData_T rhs) const {
if (lhs.count != 0 && rhs.count != 0) {
T delta = rhs.mean - lhs.mean;
combine_t new_count = lhs.count + rhs.count;
T nb_over_n = rhs.count / new_count;
return {lhs.mean + delta * nb_over_n,
lhs.sigma2 + rhs.sigma2 + delta * delta * lhs.count * nb_over_n,
new_count};
} else {
return (lhs.count != 0) ? lhs : rhs;
}
}
inline MEGDNN_DEVICE res_t
project(WelfordData_T acc) const __ubsan_ignore_float_divide_by_zero__ {
const auto mean = static_cast<T>(acc.mean);
const combine_t divisor = static_cast<combine_t>(acc.count);
const auto var = acc.sigma2 / divisor;
res_t results(var, mean);
return results;
}

#if defined(__CUDACC__) || defined(__HIPCC__)
inline MEGDNN_DEVICE WelfordData_T
warp_shfl_down(WelfordData_T acc, int offset) const {
return {__shfl_down(acc.mean, offset, warpSize),
__shfl_down(acc.sigma2, offset, warpSize),
__shfl_down(acc.count, offset, warpSize)};
}
#endif
MEGDNN_HOST MEGDNN_DEVICE WelfordOps() {}
};

template <typename T, int vec_size>
struct alignas(sizeof(T) * vec_size) aligned_vector {
T val[vec_size];
};

template <typename T, bool is_cuda>
using acc_type = T;

template <typename U>
MEGDNN_DEVICE WelfordStat
update_welford_stat_online(const U val, const WelfordStat& curr_sum) {
U delta = static_cast<U>(val - curr_sum.mean);
U new_count = static_cast<U>(curr_sum.count + 1.f);
U new_mean = static_cast<U>(curr_sum.mean + delta * (1.f / new_count));
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
}

MEGDNN_DEVICE WelfordStat
combine_welford_stat(const WelfordStat lhs, const WelfordStat rhs) {
using U = decltype(lhs.count);
U delta = lhs.mean - rhs.mean;
U count = rhs.count + lhs.count;
U mean, sigma2;
if (count > decltype(lhs.count){0}) {
auto coef = 1.f / count;
auto nA = rhs.count * coef;
auto nB = lhs.count * coef;
mean = nA * rhs.mean + nB * lhs.mean;
sigma2 = rhs.sigma2 + lhs.sigma2 + delta * delta * rhs.count * nB;
} else {
mean = U(0);
sigma2 = U(0);
}
return {mean, sigma2, count};
}

template <typename T>
MEGDNN_DEVICE WelfordStat
compute_stats(const T* __restrict__ X, const int slice_len, float* buf) {
using vec_t = aligned_vector<T, vec_size>;
using acc_t = acc_type<T, true>;
const vec_t* X_vec = reinterpret_cast<const vec_t*>(X);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const int n_vec_to_read = slice_len / vec_size;
WelfordStat w_stat(0.f, 0.f, 0.f);
for (int i = thrx; i < n_vec_to_read; i += numx) {
vec_t data = X_vec[i];
#pragma unroll
for (int ii = 0; ii < vec_size; ii++) {
w_stat = update_welford_stat_online(
static_cast<acc_t>(data.val[ii]), w_stat);
}
}
// intra-warp reduction
#pragma unroll
for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {
WelfordStat w_tmp{
__shfl_down(w_stat.mean, offset, warpSize),
__shfl_down(w_stat.sigma2, offset, warpSize),
__shfl_down(w_stat.count, offset, warpSize)};
w_stat = combine_welford_stat(w_stat, w_tmp);
}

// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* mean_sigma_buf = buf;
float* count_buf = buf + blockDim.y;
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
mean_sigma_buf[2 * wrt_y] = w_stat.mean;
mean_sigma_buf[2 * wrt_y + 1] = w_stat.sigma2;
count_buf[wrt_y] = w_stat.count;
}
__syncthreads();

// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
WelfordStat w_tmp{
mean_sigma_buf[2 * threadIdx.y],
mean_sigma_buf[2 * threadIdx.y + 1], count_buf[threadIdx.y]};
w_stat = combine_welford_stat(w_stat, w_tmp);
}
__syncthreads();
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean_sigma_buf[0] = w_stat.mean;
mean_sigma_buf[1] = w_stat.sigma2 / float(slice_len);
}
__syncthreads();
return WelfordStat{mean_sigma_buf[0], mean_sigma_buf[1], 0.f};

} else {
return WelfordStat{
__shfl(w_stat.mean, 0, warpSize),
__shfl(w_stat.sigma2, 0, warpSize) / float(slice_len), 0.f};
}
}

template <typename T, typename T_ACC>
__global__ void vectorized_layer_norm_forward_affine_kernel(
const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight,
const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) {
// if we made smem WelfordStat type, there would be bank conflicts,
// as one thread would have to write 3 consecutive floats
extern __shared__ float s_data[];

auto slice_id = blockIdx.x;
const T* slice = X + slice_id * slice_len;
WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data);
using vec_t = aligned_vector<T, vec_size>;
const vec_t* X_vec = reinterpret_cast<const vec_t*>(slice);
vec_t* Y_vec = reinterpret_cast<vec_t*>(Y + slice_id * slice_len);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const int n_vec_to_read = slice_len / vec_size;
T_ACC rstd_val = static_cast<T_ACC>(rsqrt(slice_w_stat.sigma2 + eps));

for (int i = thrx; i < n_vec_to_read; i += numx) {
vec_t data = X_vec[i];
vec_t out;
// computation is performed in T_ACC, X is cast to T_ACC and result is
// implicitly cast to T

#pragma unroll
for (int ii = 0; ii < vec_size; ii++) {
out.val[ii] = static_cast<T_ACC>(weight[i * vec_size + ii]) *
(rstd_val * (static_cast<T_ACC>(data.val[ii]) -
slice_w_stat.mean)) +
static_cast<T_ACC>(bias[i * vec_size + ii]);
}
Y_vec[i] = out;
}
if (thrx == 0) {
mean[slice_id] = slice_w_stat.mean;
rstd[slice_id] = rstd_val;
}
}

template <typename T, typename T_ACC>
__global__ void vectorized_layer_norm_forward_kernel(
const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight,
const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) {
extern __shared__ float s_data[];

auto slice_id = blockIdx.x;
const T* slice = X + slice_id * slice_len;
WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data);
using vec_t = aligned_vector<T, vec_size>;
const vec_t* X_vec = reinterpret_cast<const vec_t*>(slice);
vec_t* Y_vec = reinterpret_cast<vec_t*>(Y + slice_id * slice_len);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const int n_vec_to_read = slice_len / vec_size;
T_ACC rstd_val = static_cast<T_ACC>(rsqrt(slice_w_stat.sigma2 + eps));

for (int i = thrx; i < n_vec_to_read; i += numx) {
vec_t data = X_vec[i];
vec_t out;

#pragma unroll
for (int ii = 0; ii < vec_size; ii++) {
out.val[ii] =
rstd_val * (static_cast<T_ACC>(data.val[ii]) - slice_w_stat.mean);
}
Y_vec[i] = out;
}
if (thrx == 0) {
mean[slice_id] = slice_w_stat.mean;
rstd[slice_id] = rstd_val;
}
}

template <typename T, typename T_ACC>
void launch_vectorized_layer_norm_forward_kernel(
int64_t slice_len, int64_t slice_num, T_ACC eps, const T* X_data,
const T* weight_data, const T* bias_data, T* Y_data, T_ACC* mean_data,
T_ACC* rstd_data, cudaStream_t stream) {
const int num_threads = 128;
const dim3 threads(WARP_SIZE, num_threads / WARP_SIZE, 1);
const dim3 blocks(slice_num);
int nshared = threads.y > 1 ? threads.y * 3 / 2 * sizeof(T_ACC) : 0;

if (weight_data == nullptr && bias_data == nullptr) {
vectorized_layer_norm_forward_kernel<<<blocks, threads, nshared, stream>>>(
slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data,
Y_data);
} else {
vectorized_layer_norm_forward_affine_kernel<<<
blocks, threads, nshared, stream>>>(
slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data,
Y_data);
}
after_kernel_launch();
}

template <typename T, class ReduceOp>
__inline__ MEGDNN_DEVICE T welford_warp_reduce(T val, const ReduceOp& op) {
#pragma unroll
for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {
val = op.combine(val, op.warp_shfl_down(val, offset));
}
return val;
}

template <typename T, class ReduceOp>
__inline__ MEGDNN_DEVICE T
welford_block_reduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
const int lid = threadIdx.x % warpSize;
const int wid = threadIdx.x / warpSize;
val = welford_warp_reduce(val, op);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : identity_element;
if (wid == 0) {
val = welford_warp_reduce(val, op);
}
return val;
}

template <typename T, typename T_ACC>
__global__ void get_input_mean_and_rstd_kernel(
int64_t slice_len, T_ACC eps, const T* X, T_ACC* mean, T_ACC* rstd) {
using WelfordType = WelfordData<T_ACC, T_ACC>;
using WelfordOp = WelfordOps<T_ACC, T_ACC, thrust::pair<T_ACC, T_ACC>>;

__shared__ typename std::aligned_storage<
sizeof(WelfordType), alignof(WelfordType)>::type val_shared[WARP_SIZE];
WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);

const int64_t i = blockIdx.x;
WelfordOp welford_op;
WelfordType val(
static_cast<T_ACC>(0), static_cast<T_ACC>(0), static_cast<T_ACC>(0));

for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) {
const int64_t index = i * slice_len + j;
val = welford_op.reduce(val, static_cast<T_ACC>(X[index]));
}
val = welford_block_reduce(
val, welford_op,
WelfordType(
static_cast<T_ACC>(0), static_cast<T_ACC>(0),
static_cast<T_ACC>(0)),
val_shared_ptr);

if (threadIdx.x == 0) {
T_ACC slice_mean;
T_ACC slice_sigma2;
thrust::tie(slice_sigma2, slice_mean) = welford_op.project(val);
mean[i] = slice_mean;
rstd[i] = rsqrt(slice_sigma2 + eps);
}
}

template <typename T, typename T_ACC>
__global__ void layer_norm_forward_kernel(
int64_t slice_len, const T* X, const T_ACC* mean, const T_ACC* rstd,
const T* weight, const T* bias, T* Y) {
const int64_t i = blockIdx.x;
for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) {
const int64_t index = i * slice_len + j;
const T_ACC weight_v =
weight == nullptr ? T_ACC(1) : static_cast<T_ACC>(weight[j]);
const T_ACC bias_v = bias == nullptr ? T_ACC(0) : static_cast<T_ACC>(bias[j]);
Y[index] = (static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i])) *
static_cast<T_ACC>(rstd[i]) * weight_v +
bias_v;
}
}

template <typename T, typename T_ACC>
void forward(
T* X, T* weight, T* bias, int64_t slice_num, int64_t slice_len, T_ACC eps, T* Y,
T_ACC* mean, T_ACC* rstd, cudaStream_t stream) {
auto can_vectorize = [&](const T* ptr, int alignment) {
uint64_t addr = reinterpret_cast<uint64_t>(ptr);
return addr % alignment == 0;
};
constexpr int num_vec_elems = vec_size;
constexpr int alignment = num_vec_elems * sizeof(T);
if ((std::is_same<T, dt_float32>::value || std::is_same<T, dt_float16>::value ||
std::is_same<T, dt_bfloat16>::value) &&
slice_len <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) &&
slice_len % num_vec_elems == 0 && can_vectorize(X, alignment) &&
can_vectorize(Y, alignment)) {
launch_vectorized_layer_norm_forward_kernel<T, T_ACC>(
slice_len, slice_num, static_cast<T_ACC>(eps), X, weight, bias, Y, mean,
rstd, stream);
after_kernel_launch();
} else {
get_input_mean_and_rstd_kernel<T, T_ACC>
<<<slice_num, 512, 0, stream>>>(slice_len, eps, X, mean, rstd);
after_kernel_launch();
layer_norm_forward_kernel<T, T_ACC><<<slice_num, kCUDANumThreads, 0, stream>>>(
slice_len, X, mean, rstd, weight, bias, Y);
after_kernel_launch();
}
}

template <typename T>
__inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) {
#pragma unroll
for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {
val += __shfl_down(val, offset, warpSize);
}
return val;
}

template <typename T>
__inline__ MEGDNN_DEVICE T block_reduce_sum(T val, T* shared) {
const int lid = threadIdx.x % warpSize;
const int wid = threadIdx.x / warpSize;
val = warp_reduce_sum(val);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0);
if (wid == 0) {
val = warp_reduce_sum(val);
}
return val;
}

template <typename T, typename T_ACC>
__inline__ MEGDNN_DEVICE void layer_norm_grad_input_kernel_impl(
const T* __restrict__ dY, const T* __restrict__ X,
const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd,
const T* __restrict__ weight, T* dX, const int slice_len, T_ACC* buf) {
const auto slice_id = blockIdx.x;
const T_ACC mean_val = mean[slice_id];
const T_ACC rstd_val = rstd[slice_id];
T_ACC stats_x1{0}, stats_x2{0};
constexpr int unroll = 4;
auto l = unroll * threadIdx.x;
const T* X_i = X + slice_id * slice_len;
const T* dY_i = dY + slice_id * slice_len;
T* dX_i = dX + slice_id * slice_len;
// vectorized reads don't improve perf, so use regular unrolling

for (; l + unroll - 1 < slice_len; l += blockDim.x * unroll) {
#pragma unroll
for (int k = 0; k < unroll; k++) {
T_ACC weight_val =
(weight != nullptr) ? static_cast<T_ACC>(weight[l + k]) : T_ACC(1);
const T_ACC c_h = static_cast<T_ACC>(X_i[l + k]);
const T_ACC c_loss = static_cast<T_ACC>(dY_i[l + k]);
stats_x1 += c_loss * weight_val;
stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val;
}
}
for (; l < slice_len; l++) {
T_ACC weight_val =
(weight != nullptr) ? static_cast<T_ACC>(weight[l]) : T_ACC(1);
const T_ACC c_h = static_cast<T_ACC>(X_i[l]);
const T_ACC c_loss = static_cast<T_ACC>(dY_i[l]);
stats_x1 += c_loss * weight_val;
stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val;
}

stats_x1 = block_reduce_sum(stats_x1, buf);
stats_x2 = block_reduce_sum(stats_x2, buf);
if (threadIdx.x == 0) {
buf[0] = stats_x1;
buf[1] = stats_x2;
}
__syncthreads();
stats_x1 = buf[0];
stats_x2 = buf[1];
T_ACC fH = slice_len;
T_ACC term1 = (T_ACC(1) / fH) * rstd_val;

for (int l = threadIdx.x; l < slice_len; l += blockDim.x) {
const T_ACC x = X_i[l];
const T_ACC dy = dY_i[l];
T_ACC weight_val =
(weight != nullptr) ? static_cast<T_ACC>(weight[l]) : T_ACC(1);
T_ACC f_grad_input = fH * weight_val * dy;
f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
f_grad_input -= stats_x1;
f_grad_input *= term1;
dX_i[l] = f_grad_input;
}
}

template <typename T, typename T_ACC>
__global__ void layer_norm_grad_input_kernel(
const T* __restrict__ dY, const T* __restrict__ X,
const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd,
const T* __restrict__ weight, T* dX, const int slice_len) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* buf = reinterpret_cast<T_ACC*>(&s_data1);

layer_norm_grad_input_kernel_impl(dY, X, mean, rstd, weight, dX, slice_len, buf);
}

template <typename T, typename T_ACC>
__global__ void layer_norm_grad_weight_bias_simple_kernel(
int64_t slice_num, int64_t slice_len, const T* dY, const T* X,
const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
if (j < slice_len) {
T_ACC sum1 = 0;
T_ACC sum2 = 0;
for (int64_t i = 0; i < slice_num; ++i) {
const int64_t index = i * slice_len + j;
sum1 += dweight == nullptr ? T_ACC(0)
: static_cast<T_ACC>(dY[index]) *
(static_cast<T_ACC>(X[index]) -
static_cast<T_ACC>(mean[i])) *
static_cast<T_ACC>(rstd[i]);
sum2 += dbias == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index]);
}
if (dweight != nullptr) {
dweight[j] = sum1;
}
if (dbias != nullptr) {
dbias[j] = sum2;
}
}
}

template <typename T, typename T_ACC>
__global__ void layer_norm_grad_weight_bias_kernel(
int64_t slice_num, int64_t slice_len, const T* dY, const T* X,
const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
constexpr int unroll = 8;
T dYs[unroll];
T Xs[unroll];
T_ACC* means = s_data_typed;
T_ACC* rstds = s_data_typed + unroll * blockDim.y;
T_ACC dg_sum = 0;
T_ACC db_sum = 0;
if (j < slice_len) {
int bcounter;
for (bcounter = 0; bcounter < slice_num / (blockDim.y * unroll); bcounter++) {
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll;
#pragma unroll
for (int ii = 0; ii < unroll; ii++) {
if (threadIdx.x == 0) {
means[ii * blockDim.y + threadIdx.y] = mean[offset + ii];
rstds[ii * blockDim.y + threadIdx.y] = rstd[offset + ii];
}
dYs[ii] = dY[(offset + ii) * slice_len + j];
Xs[ii] = X[(offset + ii) * slice_len + j];
}
__syncthreads();
#pragma unroll
for (int ii = 0; ii < unroll; ii++) {
dg_sum += dYs[ii] * (Xs[ii] - means[ii * blockDim.y + threadIdx.y]) *
rstds[ii * blockDim.y + threadIdx.y];
db_sum += dYs[ii];
}
__syncthreads();
}
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll;
for (int ii = 0; ii < unroll; ii++) {
T_ACC mean_val, rstd_val; // we don't use smem in the tail to avoid awkward
// synchronizations, perf penalty is negligible
if ((offset + ii) < slice_num) {
mean_val = mean[offset + ii];
rstd_val = rstd[offset + ii];
dYs[0] = dY[(offset + ii) * slice_len + j];
Xs[0] = X[(offset + ii) * slice_len + j];
dg_sum += dYs[0] * (Xs[0] - mean_val) * rstd_val;
db_sum += dYs[0];
}
}
s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] =
db_sum;
__syncthreads();
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
if (threadIdx.y < offset) {
s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] +=
s_data_typed[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
s_data_typed
[blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
threadIdx.x] += s_data_typed
[blockDim.x * blockDim.y +
(threadIdx.y + offset) * blockDim.x + threadIdx.x];
}
__syncthreads();
}
if (threadIdx.y == 0) {
if (dweight) {
dweight[j] = s_data_typed[threadIdx.x];
}
if (dbias) {
dbias[j] = s_data_typed[threadIdx.x + blockDim.x * blockDim.y];
}
}
}
}

template <typename T, typename T_ACC>
void backward(
const T* dY_data, const T* X_data, const T_ACC* mean_data,
const T_ACC* rstd_data, const T* weight_data, int64_t slice_num,
int64_t slice_len, T* dX_data, T* dweight_data, T* dbias_data,
cudaStream_t stream) {
if (dX_data != nullptr) {
const int num_threads = 128;
const dim3 blocks(slice_num);
int nshared = (num_threads / WARP_SIZE) * sizeof(T_ACC);
layer_norm_grad_input_kernel<<<blocks, num_threads, nshared, stream>>>(
dY_data, X_data, mean_data, rstd_data, weight_data, dX_data, slice_len);
after_kernel_launch();
}
if (dweight_data || dbias_data) {
if (slice_num < 512) {
const int64_t B = (slice_len + kCUDANumThreads - 1) / kCUDANumThreads;
layer_norm_grad_weight_bias_simple_kernel<T, T_ACC>
<<<B, kCUDANumThreads, 0, stream>>>(
slice_num, slice_len, dY_data, X_data, mean_data, rstd_data,
dweight_data, dbias_data);
after_kernel_launch();
} else {
dim3 threads{16, 32};
int blocks = (slice_len + threads.x - 1) / threads.x;
layer_norm_grad_weight_bias_kernel<T, T_ACC>
<<<blocks, threads, 2 * sizeof(T_ACC) * threads.x * threads.y,
stream>>>(
slice_num, slice_len, dY_data, X_data, mean_data, rstd_data,
dweight_data, dbias_data);
after_kernel_launch();
}
}
}

#define INST(T, T_ACC) \
template void forward<T, T_ACC>( \
T*, T*, T*, int64_t, int64_t, T_ACC, T*, T_ACC*, T_ACC*, cudaStream_t); \
template void backward<T, T_ACC>( \
const T*, const T*, const T_ACC*, const T_ACC*, const T*, int64_t, \
int64_t, T*, T*, T*, cudaStream_t);

INST(dt_float32, dt_float32)
INST(dt_float16, dt_float32)
INST(dt_bfloat16, dt_float32)
#undef INST

} // namespace layer_norm
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen
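The forward kernels above accumulate per-slice mean and variance with Welford's online algorithm and merge partial accumulators across lanes, warps, and blocks. A short Python restatement of the update and merge rules used by update_welford_stat_online / combine_welford_stat (illustration only, not part of the commit):

def welford_update(mean, m2, count, x):
    # online update with one new sample x
    count += 1.0
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)          # running sum of squared deviations
    return mean, m2, count

def welford_combine(a, b):
    # merge two partial (mean, m2, count) accumulators
    mean_a, m2_a, n_a = a
    mean_b, m2_b, n_b = b
    n = n_a + n_b
    if n == 0:
        return 0.0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * (n_b / n)
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return mean, m2, n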

+34 -0  dnn/src/cuda/layer_norm/layer_norm_cuda.cuh

@@ -0,0 +1,34 @@
/**
* \file dnn/src/cuda/layer_norm/layer_norm_cuda.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <cuda_runtime_api.h>

namespace megdnn {
namespace cuda {
namespace layer_norm {

template <typename T, typename T_ACC>
void forward(
T* X, T* gamma, T* beta, int64_t M, int64_t N, T_ACC eps, T* Y, T_ACC* mean,
T_ACC* rstd, cudaStream_t stream);

template <typename T, typename T_ACC>
void backward(
const T* dY_data, const T* X_data, const T_ACC* mean_data,
const T_ACC* rstd_data, const T* gamma_data, int64_t M, int64_t N, T* dX_data,
T* dgamma_data, T* dbeta_data, cudaStream_t stream);

} // namespace layer_norm
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen

+94 -0  dnn/src/cuda/layer_norm/opr_impl.cpp

@@ -0,0 +1,94 @@
/**
* \file dnn/src/cuda/layer_norm/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/cuda/layer_norm/opr_impl.h"
#include "src/cuda/layer_norm/layer_norm_cuda.cuh"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

void LayerNormForwardImpl::exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) {
check_exec(
data.layout, weight.layout, bias.layout, dst.layout, mean.layout,
rstd.layout, workspace.size);

auto p = param();
float eps = p.eps;
bool affine = p.affine;
uint64_t slice_length = p.normalized_size;
uint64_t slice_dim = p.normalized_dim;
uint64_t n_slices = 1;
for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) {
n_slices = n_slices * data.layout.shape[i];
}

auto stream = cuda_stream(handle());
using namespace ::megdnn::cuda::layer_norm;

#define cb(DType) \
if (data.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
using T_ACC = float; \
forward<T, T_ACC>( \
data.ptr<T>(), affine ? weight.ptr<T>() : nullptr, \
affine ? bias.ptr<T>() : nullptr, static_cast<int64_t>(n_slices), \
static_cast<int64_t>(slice_length), static_cast<T_ACC>(eps), \
dst.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), stream); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}

void LayerNormBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) {
check_exec(
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout,
ddata.layout, dweight.layout, dbias.layout, workspace.size);
auto p = param();
bool affine = p.affine;
uint64_t slice_length = p.normalized_size;
uint64_t slice_dim = p.normalized_dim;
uint64_t n_slices = 1;
for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) {
n_slices = n_slices * data.layout.shape[i];
}

auto stream = cuda_stream(handle());
using namespace ::megdnn::cuda::layer_norm;
#define cb(DType) \
if (data.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
using T_ACC = float; \
backward<T, T_ACC>( \
diff.ptr<T>(), data.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), \
affine ? weight.ptr<T>() : nullptr, n_slices, slice_length, \
ddata.ptr<T>(), affine ? dweight.ptr<T>() : nullptr, \
affine ? dbias.ptr<T>() : nullptr, stream); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}

} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen

+53 -0  dnn/src/cuda/layer_norm/opr_impl.h

@@ -0,0 +1,53 @@
/**
* \file dnn/src/cuda/layer_norm/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"

#include "src/cuda/cudnn_wrapper.h"

namespace megdnn {
namespace cuda {

class LayerNormForwardImpl final : public LayerNormForward {
public:
using LayerNormForward::LayerNormForward;
void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};

class LayerNormBackwardImpl final : public LayerNormBackward {
public:
using LayerNormBackward::LayerNormBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override {
return 0;
}
};

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen

+1 -0  dnn/src/naive/handle.cpp

@@ -47,6 +47,7 @@
#include "src/naive/images2neibs/opr_impl.h"
#include "src/naive/indexing_multi_axis_vec/opr_impl.h"
#include "src/naive/indexing_one_hot/opr_impl.h"
#include "src/naive/layer_norm/opr_impl.h"
#include "src/naive/linspace/opr_impl.h"
#include "src/naive/local/opr_impl.h"
#include "src/naive/local_share/opr_impl.h"


+170 -0  dnn/src/naive/layer_norm/opr_impl.cpp

@@ -0,0 +1,170 @@
/**
* \file dnn/src/naive/layer_norm/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/layer_norm/opr_impl.h"
#include <algorithm>
#include "src/common/utils.h"
#include "src/naive/handle.h"

using namespace megdnn;
using namespace naive;

namespace {

using Param = megdnn::LayerNorm::Param;

template <typename T, typename T_ACC = float>
void forward(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
const Param& param) {
float eps = param.eps;
bool affine = param.affine;
uint64_t slice_length = param.normalized_size;
uint64_t slice_dim = param.normalized_dim;
uint64_t n_slices = 1;
for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) {
n_slices = n_slices * data.layout.shape[i];
}

for (size_t i = 0; i < n_slices; i++) {
T_ACC slice_sum = static_cast<T_ACC>(0.0f);
for (size_t j = 0; j < slice_length; j++) {
auto value = data.ptr<T>()[i * slice_length + j];
slice_sum += value;
}
T_ACC slice_mean = static_cast<T_ACC>(slice_sum / slice_length);

T_ACC slice_var = static_cast<T_ACC>(0.0f);
for (size_t j = 0; j < slice_length; j++) {
slice_var += (data.ptr<T>()[i * slice_length + j] - slice_mean) *
(data.ptr<T>()[i * slice_length + j] - slice_mean);
}
slice_var = slice_var / slice_length;

T_ACC slice_std = static_cast<T_ACC>(sqrt(slice_var + eps));
for (size_t j = 0; j < slice_length; j++) {
dst.ptr<T>()[i * slice_length + j] =
(data.ptr<T>()[i * slice_length + j] - slice_mean) / slice_std;
if (affine) {
dst.ptr<T>()[i * slice_length + j] =
dst.ptr<T>()[i * slice_length + j] * weight.ptr<T>()[j] +
bias.ptr<T>()[j];
}
}
mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean);
rstd.ptr<T_ACC>()[i] = static_cast<T_ACC>(1.0 / slice_std);
}
}

template <typename T, typename T_ACC = float>
void backward(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param) {
bool affine = param.affine;
uint64_t slice_length = param.normalized_size;
uint64_t slice_dim = param.normalized_dim;
uint64_t n_slices = 1;
for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) {
n_slices = n_slices * data.layout.shape[i];
}

if (affine) {
for (size_t i = 0; i < slice_length; ++i) {
dweight.ptr<T>()[i] = 0;
dbias.ptr<T>()[i] = 0;
}

for (size_t i = 0; i < n_slices; ++i) {
for (size_t j = 0; j < slice_length; ++j) {
dweight.ptr<T>()[j] +=
(data.ptr<T>()[i * slice_length + j] - mean.ptr<T_ACC>()[i]) *
rstd.ptr<T_ACC>()[i] * diff.ptr<T>()[i * slice_length + j];

dbias.ptr<T>()[j] += diff.ptr<T>()[i * slice_length + j];
}
}
}

for (size_t i = 0; i < n_slices; ++i) {
T_ACC ds = static_cast<T_ACC>(0.0f);
T_ACC db = static_cast<T_ACC>(0.0f);
T_ACC a = static_cast<T_ACC>(0.0f);
T_ACC b = static_cast<T_ACC>(0.0f);
T_ACC c = static_cast<T_ACC>(0.0f);

for (size_t j = 0; j < slice_length; ++j) {
auto value = data.ptr<T>()[i * slice_length + j];
auto diff_v = diff.ptr<T>()[i * slice_length + j];
auto weight_v = affine ? weight.ptr<T>()[j] : static_cast<T>(1.0f);
db += diff_v * weight_v;
ds += diff_v * value * weight_v;
}

a = rstd.ptr<T_ACC>()[i];
b = (db * mean.ptr<T_ACC>()[i] - ds) * a * a * a / slice_length;
c = -b * mean.ptr<T_ACC>()[i] - db * a / slice_length;

for (uint64_t j = 0; j < slice_length; j++) {
auto weight_v = affine ? weight.ptr<T>()[j] : static_cast<T>(1.0f);
ddata.ptr<T>()[i * slice_length + j] =
diff.ptr<T>()[i * slice_length + j] * a * weight_v +
data.ptr<T>()[i * slice_length + j] * b + c;
}
}
}

} // namespace

namespace megdnn {
namespace naive {

void LayerNormForwardImpl::exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) {
check_exec(
data.layout, weight.layout, bias.layout, dst.layout, mean.layout,
rstd.layout, workspace.size);
#define cb(DType) \
if (data.layout.dtype == DType()) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(forward<typename DTypeTrait<DType>::ctype>( \
data, weight, bias, dst, mean, rstd, param())); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}

void LayerNormBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) {
check_exec(
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout,
ddata.layout, dweight.layout, dbias.layout, workspace.size);
#define cb(DType) \
if (data.layout.dtype == DType()) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(backward<typename DTypeTrait<DType>::ctype>( \
diff, data, weight, mean, rstd, ddata, dweight, dbias, param())); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}

} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
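For reference, a NumPy restatement (hedged sketch, not part of the commit) of the per-slice gradient formula implemented by the naive backward above; it assumes affine weights and inputs flattened to shape (n_slices, slice_len):

import numpy as np

def layer_norm_backward_ref(dy, x, weight, mean, rstd):
    # dy, x: (n_slices, slice_len); weight: (slice_len,); mean, rstd: (n_slices,)
    n = x.shape[1]
    a = rstd[:, None]
    db = (dy * weight).sum(axis=1, keepdims=True)
    ds = (dy * x * weight).sum(axis=1, keepdims=True)
    b = (db * mean[:, None] - ds) * a ** 3 / n
    c = -b * mean[:, None] - db * a / n
    dx = dy * a * weight + x * b + c
    dweight = ((x - mean[:, None]) * rstd[:, None] * dy).sum(axis=0)
    dbias = dy.sum(axis=0)
    return dx, dweight, dbias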

+51 -0  dnn/src/naive/layer_norm/opr_impl.h

@@ -0,0 +1,51 @@
/**
* \file dnn/src/naive/layer_norm/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

class LayerNormForwardImpl final : public LayerNormForward {
public:
using LayerNormForward::LayerNormForward;
void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};

class LayerNormBackwardImpl final : public LayerNormBackward {
public:
using LayerNormBackward::LayerNormBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override {
return 0;
}
};

} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen

+9 -0  dnn/test/common/deduce_layout_proxy.h

@@ -57,6 +57,15 @@ struct DeduceLayoutProxy<Opr, 5, true> {
}
};

template <typename Opr>
struct DeduceLayoutProxy<Opr, 6, true> {
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 6);
opr->deduce_layout(
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]);
}
};

template <typename Opr>
struct DeduceLayoutProxy<Opr, 5, false> {
static void deduce_layout(Opr*, TensorLayoutArray&) {}


+94 -0  dnn/test/cuda/layer_norm.cpp

@@ -0,0 +1,94 @@
/**
* \file dnn/test/cuda/layer_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/cuda/fixture.h"

#include "test/common/checker.h"

namespace megdnn {
namespace test {

TEST_F(CUDA, LAYERNORM_FORWARD) {
using Param = LayerNormForward::Param;
Param param;
param.affine = true;
param.eps = 1e-6;
param.normalized_dim = 1;
Checker<LayerNormForward> checker(handle_cuda());
checker.set_epsilon(1e-2);

auto run = [&](DType d) {
for (size_t n_slices : {10, 30})
for (size_t slice_len : {10, 30}) {
param.normalized_size = slice_len;
checker.set_param(param)
.set_dtype(0, d)
.set_dtype(1, d)
.set_dtype(2, d)
.set_dtype(3, d)
.set_dtype(4, dtype::Float32())
.set_dtype(5, dtype::Float32())
.execs({{n_slices, slice_len},
{slice_len},
{slice_len},
{n_slices, slice_len},
{n_slices},
{n_slices}});
}
};

run(dtype::Float32());
run(dtype::Float16());
run(dtype::BFloat16());
}

TEST_F(CUDA, LAYERNORM_BACKWARD) {
using Param = LayerNormBackward::Param;
Param param;
param.affine = true;
param.eps = 1e-6;
param.normalized_dim = 1;
Checker<LayerNormBackward> checker(handle_cuda());
checker.set_epsilon(1e-1);

auto run = [&](DType d) {
for (size_t n_slices : {10, 30})
for (size_t slice_len : {10, 30}) {
param.normalized_size = slice_len;
checker.set_param(param)
.set_dtype(0, d)
.set_dtype(1, d)
.set_dtype(2, d)
.set_dtype(3, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_dtype(5, d)
.set_dtype(6, d)
.set_dtype(7, d)
.execs({{n_slices, slice_len},
{n_slices, slice_len},
{slice_len},
{n_slices},
{n_slices},
{n_slices, slice_len},
{slice_len},
{slice_len}});
}
};

run(dtype::Float32());
run(dtype::Float16());
run(dtype::BFloat16());
}

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen

+19 -68  imperative/python/megengine/functional/nn.py

@@ -1066,57 +1066,6 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
return cached / down


@lru_cache(maxsize=None)
def _get_layerNorm(device, dtype, dim, gopt_level=2):
@subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level)
def layerNormAffine(inputs, f, c):
inp, eps, _flatten_shape, weight, bias = inputs
inp_shape = f(GetVarShape(), inp)

inp = f(Reshape(axis=dim), inp, _flatten_shape)
mean = f(Reduce(mode="mean", axis=-1), inp)
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
reduce_shape = f(GetVarShape(), x2s)
reduce_size = f(
"//",
f(Reduce(mode="product", axis=0), inp_shape),
f(Reduce(mode="product", axis=0), reduce_shape),
)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
affine_oup = f(Reshape(), oup, inp_shape)
affine_oup = f("fma3", affine_oup, weight, bias)

# NOTE: return oup make backward faster but take more memory
return (affine_oup, oup, mean, x2s), (True, False, False, False)

@subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level)
def layerNorm(inputs, f, c):
inp, eps, _flatten_shape = inputs
inp_shape = f(GetVarShape(), inp)

inp = f(Reshape(axis=dim), inp, _flatten_shape)
mean = f(Reduce(mode="mean", axis=-1), inp)
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
reduce_shape = f(GetVarShape(), x2s)
reduce_size = f(
"//",
f(Reduce(mode="product", axis=0), inp_shape),
f(Reduce(mode="product", axis=0), reduce_shape),
)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
oup = f(Reshape(), oup, inp_shape)

return (oup,), (True,)

return (layerNorm, layerNormAffine)


def layer_norm(
inp: Tensor,
normalized_shape: tuple,
@@ -1133,32 +1082,34 @@ def layer_norm(
normalized_shape: the shape that you want to be normalized
affine: whether to use weight and bias
weight: must not be None when the affine is true
bias: must not be None when the bias is true
bias: must not be None when the affine is true
eps: a value added to the denominator for numerical stability. Default: 1e-5
"""

if amp._enabled:
inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)

_device = inp.device
_dtype = inp.dtype
_dim = len(inp.shape) - len(normalized_shape)
if isinstance(normalized_shape, int):
normalized_shape = [normalized_shape]

_flatten_shape = concat(
(
convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device),
convert_single_value(-1, dtype="int32", device=inp.device),
)
)
(layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim)
normalized_dim = len(normalized_shape)
assert normalized_dim > 0

eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
normalized_size = 1
for i in range(normalized_dim):
normalized_size = normalized_size * normalized_shape[i]

op = builtin.LayerNorm(
affine=affine,
eps=eps,
normalized_dim=normalized_dim,
normalized_size=normalized_size,
)
if affine:
outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias)
assert weight is not None and bias is not None
return apply(op, inp, weight, bias)[0]
else:
outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape)

return outvar
# assert weight is None and bias is None
return apply(op, inp)[0]
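# A minimal usage sketch of the rewritten functional (illustration only, not part
# of this diff), mirroring the call pattern of the removed functional test:
#
#     x = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
#     w = Tensor(np.ones((28, 28)), dtype="float32")
#     b = Tensor(np.zeros((28, 28)), dtype="float32")
#     y = layer_norm(x, (28, 28), True, w, b)   # dispatches builtin.LayerNorm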


def batch_norm(


+0 -55  imperative/python/test/unit/functional/test_functional.py

@@ -865,61 +865,6 @@ def test_conv1d():
)


def test_layer_norm():
def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5):
__layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine)
__layer_norm.weight = weight
__layer_norm.bias = bias
return __layer_norm(x)

def _layer_norm_numpy(
x, normalized_shape, affine, weight=None, bias=None, eps=1e-5
):
x_shape = x.shape
dim_delta = len(x_shape) - len(normalized_shape)
non_flatten_shape = x_shape[:dim_delta]
x = x.reshape(*non_flatten_shape, -1)

mean = x.mean(axis=-1, keepdims=True)
var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean

x = (x - mean) / F.sqrt(var + eps)
x = x.reshape(x_shape)
if affine:
x = weight * x + bias

return x

normalized_shape = (28, 28)
inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
weight = Tensor(np.random.randn(28, 28), dtype="float32")
bias = Tensor(np.random.randn(28, 28), dtype="float32")

inp_feat = inp_feat + 1
weight = weight + 1
bias = bias

affine = False

outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
targetvar = _layer_norm_numpy(inp_feat, normalized_shape, affine, weight, bias)

assert abs(outvar - targetvar).mean() < 1e-7

# no random, affine True
normalized_shape = (28, 28)
inp_feat = Tensor(np.ones((32, 64, 28, 28)), dtype="float32")
weight = Tensor(np.ones((28, 28)), dtype="float32")
bias = Tensor(np.zeros((28, 28)), dtype="float32")

affine = True

outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
targetvar = _layer_norm(inp_feat, normalized_shape, affine, weight, bias)
assert abs((outvar - targetvar).mean()) < 1e-7
assert abs(outvar.mean()) < 1e-7


def test_batchnorm2d_autocast():
"""check amp's result is equal to manually converted result"""
amp.enabled = True


+1 -1  imperative/python/test/unit/functional/test_loss.py

@@ -43,7 +43,7 @@ def test_cross_entropy():
x = softmax(x)
l_ref = ref(x, y)
l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
np.testing.assert_allclose(l.numpy(), l_ref)
np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6)


def test_cross_entropy_reduction():


+26 -0  imperative/src/impl/ops/specializations.cpp

@@ -20,6 +20,7 @@
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
@@ -636,4 +637,29 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
} // namespace lrn

namespace layer_norm {

cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const LayerNorm&>(def);
size_t nr_inp = inputs.size();
auto p = op.param();
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
OperatorNodeConfig config{op.make_name()};
if (nr_inp == 3) {
return opr::LayerNorm::make(
inputs[0], inputs[1], inputs[2], op.param(), config)[0]
.node()
->owner_opr();
} else {
return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
.node()
->owner_opr();
}
}

OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();

} // namespace layer_norm

} // namespace mgb::imperative

+2 -0  src/core/include/megbrain/ir/ops.td

@@ -431,4 +431,6 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>;

def LRN: MgbHashableOp<"LRN", [LRNParam]>;

def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>;

#endif // MGB_OPS

+44 -0  src/opr/impl/dnn/dnn.sereg.h

@@ -16,6 +16,7 @@
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
@@ -420,6 +421,47 @@ struct OprMaker<opr::BatchNormBackward, 6> {
}
};

template <>
struct OprMaker<opr::LayerNorm, 0> {
using Param = opr::LayerNorm::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 3) {
return opr::LayerNorm::make(i[0], i[1], i[2], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 1);
return opr::LayerNorm::make(i[0], param, config)[0].node()->owner_opr();
}
}
};

// OprMaker in MGB_SEREG_OPR only support unique output opr
template <>
struct OprMaker<opr::LayerNormBackward, 0> {
using Param = opr::LayerNormBackward::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 5) {
return opr::LayerNormBackward::make(
i[0], i[1], i[2], i[3], i[4], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 4);
return opr::LayerNormBackward::make(
i[0], i[1], i[2], i[3], param, config)[0]
.node()
->owner_opr();
}
}
};

template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 {
template <typename Opr>
@@ -641,6 +683,8 @@ MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
MGB_SEREG_OPR(LSQ, 4);
MGB_SEREG_OPR(LSQBackward, 5);
MGB_SEREG_OPR(LayerNorm, 0);
MGB_SEREG_OPR(LayerNormBackward, 0);
} // namespace opr

} // namespace mgb


+ 248
- 0
src/opr/impl/dnn/layer_norm.cpp View File

@@ -0,0 +1,248 @@
/**
* \file src/opr/impl/dnn/layer_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/opr/dnn/layer_norm.h"

#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"

#include "../internal/megdnn_opr_wrapper.inl"

using namespace mgb;
using namespace opr;

/* ==================== LayerNormForward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormForward);

LayerNormForward::LayerNormForward(
VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
const OperatorNodeConfig& config)
: Super{data->owner_graph(), config, "layer_norm", {data, weight, bias}} {
init_megdnn_opr(*this, param);

add_input({data, weight, bias});
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
}

LayerNormForward::LayerNormForward(
VarNode* data, const Param& param, const OperatorNodeConfig& config)
: Super{data->owner_graph(), config, "layer_norm", {data}} {
init_megdnn_opr(*this, param);

add_input({data});
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
}

SymbolVarArray LayerNormForward::make(
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param,
const OperatorNodeConfig& config) {
auto outs = data.node()
->owner_graph()
->insert_opr(std::make_unique<LayerNormForward>(
data.node(), weight.node(), bias.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}

SymbolVarArray LayerNormForward::make(
SymbolVar data, const Param& param, const OperatorNodeConfig& config) {
auto outs = data.node()
->owner_graph()
->insert_opr(std::make_unique<LayerNormForward>(
data.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}

void LayerNormForward::get_output_var_shape(
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
uint64_t normalized_dim = param().normalized_dim;
out_shape[0] = inp_shape[0];
TensorShape unnormalized_shape;
unnormalized_shape.ndim = inp_shape[0].ndim - normalized_dim;
for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
unnormalized_shape.shape[i] = inp_shape[0].shape[i];
}
out_shape[1] = unnormalized_shape;
out_shape[2] = unnormalized_shape;
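// Illustrative example: data of shape (32, 64, 28, 28) with normalized_dim = 2
// normalizes the trailing two axes, so dst keeps the full input shape while
// mean and rstd are deduced as (32, 64).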
}

size_t LayerNormForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return 0;
}

void LayerNormForward::scn_do_execute() {
if (param().affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(), {});
} else {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), {}, {},
output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(), {});
}
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LayerNormForward) {
auto p = opr.param();
SymbolVarArray grad;
VarNodeArray ret;
if (p.affine) {
mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx);
grad = LayerNormBackward::make(
out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2),
opr.param());
} else {
mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx);
grad = LayerNormBackward::make(
out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param());
}

uint32_t nr_ret = p.affine ? 3 : 1;
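// LayerNormBackward produces (ddata, dweight, dbias); when affine is disabled
// only the gradient w.r.t. data is meaningful, so only that one is returned.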
for (uint32_t i = 0; i < nr_ret; ++i) {
ret.push_back(grad[i].node());
}
return ret;
}
#endif

/* ==================== LayerNormBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormBackward);

LayerNormBackward::LayerNormBackward(
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config)
: Super({diff->owner_graph(),
config,
"layer_norm_backward",
{diff, data, weight, mean, rstd}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, data, weight, mean, rstd});
}

LayerNormBackward::LayerNormBackward(
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param,
const OperatorNodeConfig& config)
: Super({diff->owner_graph(),
config,
"layer_norm_backward",
{diff, data, mean, rstd}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, data, mean, rstd});
auto mark_empty_var = [&](VarNode* var) {
var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::VOLATILE_CONTENT);
};
mark_empty_var(output(1));
mark_empty_var(output(2));
}

SymbolVarArray LayerNormBackward::make(
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<LayerNormBackward>(
diff.node(), data.node(), weight.node(), mean.node(),
rstd.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}

SymbolVarArray LayerNormBackward::make(
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
const Param& param, const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<LayerNormBackward>(
diff.node(), data.node(), mean.node(), rstd.node(),
param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}

void LayerNormBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1)));
if (param().affine) {
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2)));
} else {
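// Without affine parameters there are no dweight/dbias outputs to fill, so
// their shapes are statically fixed to an empty tensor shape.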
TensorShape empty;
empty.ndim = 0;
mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty));
}
this->init_output_static_infer_desc_workspace(false);
}

void LayerNormBackward::init_output_dtype() {
output(0)->dtype(input(1)->dtype());
output(1)->dtype(input(2)->dtype());
output(2)->dtype(input(2)->dtype());
}

size_t LayerNormBackward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return 0;
}

void LayerNormBackward::scn_do_execute() {
if (param().affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(), {});
} else {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
{}, input(2)->dev_tensor().as_megdnn(),
input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
{}, {}, {});
}
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 78
- 0
src/opr/include/megbrain/opr/dnn/layer_norm.h View File

@@ -0,0 +1,78 @@
/**
* \file src/opr/include/megbrain/opr/dnn/layer_norm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"

namespace mgb {
namespace opr {

MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
LayerNormForward, intl::MegDNNOprWrapperFwd<megdnn::LayerNormForward>) // {
public:
MGE_WIN_DECLSPEC_FUC LayerNormForward(
VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC LayerNormForward(
VarNode* data, const Param& param, const OperatorNodeConfig& config);

MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar data, const Param& param = {},
const OperatorNodeConfig& config = {});

private:
void get_output_var_shape(
const TensorShapeArray& inp_shape,
TensorShapeArray& out_shape) const override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};
using LayerNorm = LayerNormForward;

MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
LayerNormBackward, intl::MegDNNOprWrapperBwd<megdnn::LayerNormBackward>) // {
public:
MGE_WIN_DECLSPEC_FUC LayerNormBackward(
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config);

MGE_WIN_DECLSPEC_FUC LayerNormBackward(
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config);

MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
SymbolVar rstd, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
const Param& param = {}, const OperatorNodeConfig& config = {});

private:
void init_output_static_infer_desc() override;
void init_output_dtype() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};

} // namespace opr
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 108
- 0
src/opr/test/dnn/layer_norm.cpp View File

@@ -0,0 +1,108 @@
/**
* \file src/opr/test/dnn/layer_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"

#include "megdnn/oprs.h"

#include <cmath>
#include <iomanip>
#include <random>
#include <sstream>

using namespace mgb;

namespace {
using Param = opr::LayerNormForward::Param;

void run_forward(bool is_affine, size_t normalized_size) {
using Checker = AutoOprChecker<3, 3>;

Param param;
param.eps = 1e-5;
param.affine = is_affine;
param.normalized_dim = 1;
param.normalized_size = normalized_size;
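// With normalized_dim = 1, only the trailing axis (of length normalized_size)
// is normalized; the leading axis is treated as independent slices.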

auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
auto out = opr::LayerNormForward::make(inputs[0], inputs[1], inputs[2], param);
return {out[0], out[1], out[2]};
};

auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr =
MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
->create_operator<megdnn::LayerNormForward>();
auto inp_shape = inp[0]->shape();
auto n_slices = inp_shape[0];
auto slice_len = inp_shape[1];

opr->param() = param;

dest[0].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices, slice_len});
dest[1].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices});
dest[2].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices});
opr->exec(
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(), {});
};

auto gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f);
src = *src_gen(src.shape(), src.comp_node());
};

Checker::RunOptions option;
option.numdiff_max_err = 1e-4;
Checker checker{make_graph, fwd};

checker.set_input_generator(0, gen);
checker.set_input_generator(1, gen);
checker.set_input_generator(2, gen);
checker.set_input_allow_grad(0, false);
checker.set_input_allow_grad(1, false);
checker.set_input_allow_grad(2, false);
checker.set_output_allow_grad(0, false);
checker.set_output_allow_grad(1, false);
checker.set_output_allow_grad(2, false);

checker.run({TensorShape{normalized_size, normalized_size},
TensorShape{normalized_size}, TensorShape{normalized_size}},
option)
.run({TensorShape{normalized_size, normalized_size},
TensorShape{normalized_size}, TensorShape{normalized_size}},
option)
.run({TensorShape{normalized_size, normalized_size},
TensorShape{normalized_size}, TensorShape{normalized_size}},
option);
}

TEST(TestOprDNN, LayerNormForwardAffine) {
REQUIRE_GPU(1);
run_forward(true, 1);
run_forward(true, 16);
run_forward(true, 17);
}

} // anonymous namespace

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 0
src/serialization/impl/schema.fbs View File

@@ -116,6 +116,7 @@ union OperatorParam {
param.Padding = 82,
param.ShuffleRNG = 83,
param.CheckNonFinite = 84,
param.LayerNorm = 85,
}

table Operator {

