| @@ -0,0 +1,205 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <stdio.h> | |||||
| #include <stdint.h> | |||||
| #include <cuda_runtime.h> | |||||
| #include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh" | |||||
| constexpr int NUM_PER_THREAD_REDUCE = 4; | |||||
| constexpr int WARP_SIZE = 32; | |||||
| template <typename T> | |||||
| inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim, | |||||
| const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, | |||||
| T* dg, T* db) { | |||||
| int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; | |||||
| for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { | |||||
| for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { | |||||
| int row = NUM_PER_THREAD_REDUCE * i + j; | |||||
| if (row >= row_dim) { | |||||
| return; | |||||
| } | |||||
| int pos = row * col_dim + col; | |||||
| dg[0] += dy[pos] * pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); | |||||
| db[0] += dy[pos]; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) { | |||||
| for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { | |||||
| dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta); | |||||
| db[0] += __shfl_down_sync(0xffffffff, db[0], delta); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr, | |||||
| T* db_addr) { | |||||
| if (threadIdx.x >= row_dim) { | |||||
| return; | |||||
| } | |||||
| // load data to share memory | |||||
| // thread(0, 32, 64, 96, ...) keep the data | |||||
| extern __shared__ T share_mem[]; | |||||
| if (threadIdx.x % WARP_SIZE == 0) { | |||||
| int offset = threadIdx.x / WARP_SIZE * 2; | |||||
| share_mem[offset] = dg[0]; | |||||
| share_mem[offset + 1] = db[0]; | |||||
| } | |||||
| __syncthreads(); | |||||
| for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { | |||||
| if (threadIdx.x < stride) { | |||||
| int offset = (threadIdx.x + stride) * 2; | |||||
| share_mem[threadIdx.x * 2] += share_mem[offset]; | |||||
| share_mem[threadIdx.x * 2 + 1] += share_mem[offset + 1]; | |||||
| } | |||||
| } | |||||
| __syncthreads(); | |||||
| if (threadIdx.x == 0) { | |||||
| dg_addr[col] = share_mem[0]; | |||||
| db_addr[col] = share_mem[1]; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x, | |||||
| const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) { | |||||
| // row: [0:param_axis] | |||||
| // col: [param_axis:] | |||||
| // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i]) | |||||
| // dg[j] = \Sigma_{j}dg[i][j] | |||||
| for (int col = blockIdx.x; col < col_dim; col += gridDim.x) { | |||||
| T dg = 0; | |||||
| T db = 0; | |||||
| GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); | |||||
| GammaAndBetaWarpReduce(&dg, &db); | |||||
| GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, | |||||
| T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean, | |||||
| const T* var, const T* gamma) { | |||||
| int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; | |||||
| for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { | |||||
| for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { | |||||
| int col = NUM_PER_THREAD_REDUCE * i + j; | |||||
| if (col >= col_dim) { | |||||
| return; | |||||
| } | |||||
| int pos = row * col_dim + col; | |||||
| int gamma_offset = pos % param_dim; | |||||
| T v1 = dy[pos] * gamma[gamma_offset]; | |||||
| T v2 = x[pos] - mean[row]; | |||||
| sum1[0] += -0.5 * v1 * v2 * pow(var[row] + epsilon, -1.5); | |||||
| sum2[0] += v1; | |||||
| sum3[0] += -2.0 * v2; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { | |||||
| for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { | |||||
| sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta); | |||||
| sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta); | |||||
| sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) { | |||||
| if (threadIdx.x >= col_dim) { | |||||
| return; | |||||
| } | |||||
| // load data to share memory | |||||
| // thread(0, 32, 64, 96, ...) keep the data | |||||
| if (threadIdx.x % WARP_SIZE == 0) { | |||||
| int offset = threadIdx.x / WARP_SIZE * 3; | |||||
| share_mem[offset] = sum1[0]; | |||||
| share_mem[offset + 1] = sum2[0]; | |||||
| share_mem[offset + 2] = sum3[0]; | |||||
| } | |||||
| __syncthreads(); | |||||
| for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { | |||||
| if (threadIdx.x < stride) { | |||||
| int offset = (threadIdx.x + stride) * 3; | |||||
| share_mem[threadIdx.x * 3] += share_mem[offset]; | |||||
| share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1]; | |||||
| share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2]; | |||||
| } | |||||
| } | |||||
| __syncthreads(); | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, | |||||
| const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx, | |||||
| const T* share_mem) { | |||||
| for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { | |||||
| int pos = (row * col_dim + col); | |||||
| int gamma_offset = pos % param_dim; | |||||
| T v1 = dy[pos] * gamma[gamma_offset]; | |||||
| T v2 = x[pos] - mean[row]; | |||||
| T v3 = pow(var[row] + epsilon, -0.5); | |||||
| dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 + | |||||
| (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy, | |||||
| const T* x, const T* mean, const T* var, const T* gamma, T* dx) { | |||||
| for (int row = blockIdx.x; row < row_dim; row += gridDim.x) { | |||||
| T sum1 = 0; | |||||
| T sum2 = 0; | |||||
| T sum3 = 0; | |||||
| extern __shared__ T share_mem[]; | |||||
| InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma); | |||||
| InputWarpReduce(&sum1, &sum2, &sum3); | |||||
| InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem); | |||||
| InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, | |||||
| const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) { | |||||
| int share_mem = | |||||
| ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); | |||||
| InputPropKernel<<<row_dim, 256, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, | |||||
| dx); | |||||
| share_mem = | |||||
| ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T); | |||||
| GammaAndBetaPropKernel<<<col_dim, 256, share_mem, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db); | |||||
| } | |||||
| template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, | |||||
| const float* dy, const float* x, const float* mean, const float* var, const float* gamma, | |||||
| float* dx, float* dg, float* db, cudaStream_t stream); | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ | |||||
| #include "device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, | |||||
| const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ | |||||
| @@ -0,0 +1,148 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <stdio.h> | |||||
| #include <stdint.h> | |||||
| #include <cuda_runtime.h> | |||||
| #include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" | |||||
| constexpr int NUM_PER_THREAD_REDUCE = 4; | |||||
| constexpr int WARP_SIZE = 32; | |||||
| template <typename T> | |||||
| inline __device__ void MeanAndVarAccumulation(T* mean, T* var, T* num, const T& val) { | |||||
| // Welford Algorithm: | |||||
| // \mu_k = \mu_{k-1} + (x_k - \mu_{k-1})/k | |||||
| // \sigma_k^2 = \sigma_{k-1}^2 + (x_k - \mu_{k-1}) * (x_k - \mu_k) | |||||
| num[0]++; | |||||
| T mean_new = mean[0] + (val - mean[0]) / num[0]; | |||||
| var[0] = var[0] + (val - mean[0]) * (val - mean_new); | |||||
| mean[0] = mean_new; | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void MeanAndVarMerge(T* m1, T* v1, T* n1, const T& m2, const T& v2, const T& n2) { | |||||
| if (n2 == 0) { | |||||
| return; | |||||
| } | |||||
| T count = n1[0] + n2; | |||||
| v1[0] = v1[0] + v2 + (m1[0] - m2) * (m1[0] - m2) * n1[0] * n2 / count; | |||||
| m1[0] = (n1[0] * m1[0] + n2 * m2) / count; | |||||
| n1[0] = count; | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void ThreadReduce(const int& col_dim, const T* block_addr, T* mean, T* var, T* num) { | |||||
| int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; | |||||
| for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { | |||||
| for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { | |||||
| int pos = NUM_PER_THREAD_REDUCE * i + j; | |||||
| if (pos >= col_dim) { | |||||
| return; | |||||
| } | |||||
| MeanAndVarAccumulation(mean, var, num, block_addr[pos]); | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void WarpReduce(T* mean, T* var, T* num) { | |||||
| for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { | |||||
| T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta); | |||||
| T var_other = __shfl_down_sync(0xffffffff, var[0], delta); | |||||
| T num_other = __shfl_down_sync(0xffffffff, num[0], delta); | |||||
| MeanAndVarMerge(mean, var, num, mean_other, var_other, num_other); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void BlockReduce(const int& col_dim, T* mean, T* var, T* num, T* mean_addr, T* var_addr, | |||||
| T* share_mem) { | |||||
| if (threadIdx.x >= col_dim) { | |||||
| return; | |||||
| } | |||||
| // load data to share memory | |||||
| // thread(0, 32, 64, 96, ...) keep the data | |||||
| if (threadIdx.x % WARP_SIZE == 0) { | |||||
| int offset = threadIdx.x / WARP_SIZE * 3; | |||||
| share_mem[offset] = mean[0]; | |||||
| share_mem[offset + 1] = var[0]; | |||||
| share_mem[offset + 2] = num[0]; | |||||
| } | |||||
| __syncthreads(); | |||||
| for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { | |||||
| if (threadIdx.x < stride) { | |||||
| int offset = (threadIdx.x + stride) * 3; | |||||
| MeanAndVarMerge(&share_mem[threadIdx.x * 3], &share_mem[threadIdx.x * 3 + 1], &share_mem[threadIdx.x * 3 + 2], | |||||
| share_mem[offset], share_mem[offset + 1], share_mem[offset + 2]); | |||||
| } | |||||
| } | |||||
| __syncthreads(); | |||||
| if (threadIdx.x == 0) { | |||||
| mean_addr[blockIdx.x] = share_mem[0]; // todo: blockDim.x < row | |||||
| share_mem[1] /= col_dim; | |||||
| var_addr[blockIdx.x] = share_mem[1]; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| inline __device__ void LayerNorm(const int& row, const int& col_dim, const int& param_dim, const T* x, | |||||
| const T* share_mem, const T* gamma, const T* beta, const T epsilon, T* y) { | |||||
| for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { | |||||
| int pos = row * col_dim + col; | |||||
| int i = pos % param_dim; | |||||
| y[pos] = (x[pos] - share_mem[0]) / sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i]; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* x, | |||||
| const T* gamma, const T* beta, T* y, T* mean_addr, T* var_addr) { | |||||
| for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) { | |||||
| T mean = 0; | |||||
| T var = 0; | |||||
| T num = 0; | |||||
| const T* block_addr = x + row * col_dim; | |||||
| extern __shared__ T share_mem[]; | |||||
| ThreadReduce(col_dim, block_addr, &mean, &var, &num); | |||||
| WarpReduce(&mean, &var, &num); | |||||
| BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem); | |||||
| __syncthreads(); | |||||
| LayerNorm(row, col_dim, param_dim, x, share_mem, gamma, beta, epsilon, y); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* x, | |||||
| const T* gamma, const T* beta, T* y, T* mean, T* var, cudaStream_t stream) { | |||||
| const dim3 block(row_dim); | |||||
| const dim3 thread(256); | |||||
| // keep the mean/var/num after warp reduce | |||||
| int share_mem = | |||||
| ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); | |||||
| LayerNormKernel<<<block, thread, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, mean, | |||||
| var); | |||||
| } | |||||
| template void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, | |||||
| const float* x, const float* gamma, const float* beta, float* y, float* mean, float* var, | |||||
| cudaStream_t stream); | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ | |||||
| #include "device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x, const T* gamma, | |||||
| const T* beta, T* y, T* mean, T* var, cudaStream_t stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/nn/layer_norm_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(LayerNorm, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| LayerNormGpuKernel, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,103 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "kernel/gpu/gpu_kernel.h" | |||||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||||
| #include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class LayerNormGpuKernel : public GpuKernel { | |||||
| public: | |||||
| LayerNormGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} | |||||
| ~LayerNormGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||||
| auto x = GetDeviceAddress<T>(inputs, 0); | |||||
| auto gamma = GetDeviceAddress<T>(inputs, 1); | |||||
| auto beta = GetDeviceAddress<T>(inputs, 2); | |||||
| auto y = GetDeviceAddress<T>(outputs, 0); | |||||
| auto mean = GetDeviceAddress<T>(outputs, 1); | |||||
| auto variance = GetDeviceAddress<T>(outputs, 2); | |||||
| T epsilon = 10e-12; | |||||
| LayerNorm(input_row_, input_col_, param_dim_, epsilon, x, gamma, beta, y, mean, variance, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| int begin_norm_axis = GetAttr<int>(kernel_node, "begin_norm_axis"); | |||||
| int begin_params_axis = GetAttr<int>(kernel_node, "begin_params_axis"); | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| if (begin_norm_axis < 0) { | |||||
| begin_norm_axis += input_shape.size(); | |||||
| } | |||||
| if (begin_params_axis < 0) { | |||||
| begin_params_axis += input_shape.size(); | |||||
| } | |||||
| for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { | |||||
| input_row_ *= input_shape[i]; | |||||
| } | |||||
| for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { | |||||
| input_col_ *= input_shape[i]; | |||||
| } | |||||
| for (size_t i = begin_params_axis; i < input_shape.size(); i++) { | |||||
| param_dim_ *= input_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); | |||||
| input_size_list_.push_back(param_dim_ * sizeof(T)); | |||||
| input_size_list_.push_back(param_dim_ * sizeof(T)); | |||||
| output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); | |||||
| output_size_list_.push_back(input_row_ * sizeof(T)); | |||||
| output_size_list_.push_back(input_row_ * sizeof(T)); | |||||
| return; | |||||
| } | |||||
| private: | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| int input_row_; | |||||
| int input_col_; | |||||
| int param_dim_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/nn/layer_norm_grad_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(LayerNormGrad, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| LayerNormGradGpuKernel, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,107 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "kernel/gpu/gpu_kernel.h" | |||||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||||
| #include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class LayerNormGradGpuKernel : public GpuKernel { | |||||
| public: | |||||
| LayerNormGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} | |||||
| ~LayerNormGradGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||||
| auto dy = GetDeviceAddress<T>(inputs, 0); | |||||
| auto x = GetDeviceAddress<T>(inputs, 1); | |||||
| auto var = GetDeviceAddress<T>(inputs, 2); | |||||
| auto mean = GetDeviceAddress<T>(inputs, 3); | |||||
| auto gamma = GetDeviceAddress<T>(inputs, 4); | |||||
| auto dx = GetDeviceAddress<T>(outputs, 0); | |||||
| auto dg = GetDeviceAddress<T>(outputs, 1); | |||||
| auto db = GetDeviceAddress<T>(outputs, 2); | |||||
| T epsilon = 10e-12; | |||||
| LayerNormGrad(input_row_, input_col_, param_dim_, epsilon, dy, x, mean, var, gamma, dx, dg, db, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| int begin_norm_axis = GetAttr<int>(kernel_node, "begin_norm_axis"); | |||||
| int begin_params_axis = GetAttr<int>(kernel_node, "begin_params_axis"); | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| if (begin_norm_axis < 0) { | |||||
| begin_norm_axis += input_shape.size(); | |||||
| } | |||||
| if (begin_params_axis < 0) { | |||||
| begin_params_axis += input_shape.size(); | |||||
| } | |||||
| for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { | |||||
| input_row_ *= input_shape[i]; | |||||
| } | |||||
| for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { | |||||
| input_col_ *= input_shape[i]; | |||||
| } | |||||
| for (size_t i = begin_params_axis; i < input_shape.size(); i++) { | |||||
| param_dim_ *= input_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); | |||||
| input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); | |||||
| input_size_list_.push_back(input_row_ * sizeof(T)); | |||||
| input_size_list_.push_back(input_row_ * sizeof(T)); | |||||
| input_size_list_.push_back(param_dim_ * sizeof(T)); | |||||
| output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); | |||||
| output_size_list_.push_back(param_dim_ * sizeof(T)); | |||||
| output_size_list_.push_back(param_dim_ * sizeof(T)); | |||||
| return; | |||||
| } | |||||
| private: | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| int input_row_; | |||||
| int input_col_; | |||||
| int param_dim_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ | |||||
| @@ -0,0 +1,140 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import pytest | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | |||||
| from mindspore.ops import composite as C | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| class LayerNormGradNet(nn.Cell): | |||||
| def __init__(self, begin_norm_axis, begin_params_axis): | |||||
| super(LayerNormGradNet, self).__init__() | |||||
| self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis) | |||||
| def construct(self, dy, x, var, mean, gamma): | |||||
| return self.norm(dy, x, var, mean, gamma) | |||||
| def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis): | |||||
| begin_norm_axis = begin_norm_axis if begin_norm_axis >=0 else begin_norm_axis + len(x.shape) | |||||
| begin_params_axis = begin_params_axis if begin_params_axis >=0 else begin_params_axis + len(x.shape) | |||||
| norm_axis = [i for i in range(begin_norm_axis, len(x.shape))] | |||||
| param_axis = [i for i in range(0, begin_params_axis)] | |||||
| num = 1 | |||||
| for i in range(begin_norm_axis, len(x.shape)): | |||||
| num *= x.shape[i] | |||||
| mean = np.mean(x, axis=tuple(norm_axis), keepdims=True) | |||||
| var = np.var(x, axis=tuple(norm_axis), keepdims=True) | |||||
| gamma = gamma.reshape((*((1,)*begin_params_axis), *x.shape[begin_params_axis:])) | |||||
| dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True) | |||||
| db = np.sum(dy, axis=tuple(param_axis), keepdims=True) | |||||
| sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis), keepdims=True) | |||||
| sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True) | |||||
| sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True) | |||||
| dx1 = dy * gamma * np.power(var + epsilon, -0.5) | |||||
| dx2 = sum1 * 2.0 / num * (x - mean) | |||||
| dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num) | |||||
| dx = dx1 + dx2 + dx3 | |||||
| return dx, dg, db, mean, var | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_layernormgrad0(): | |||||
| begin_norm_axis = 1 | |||||
| begin_params_axis = 1 | |||||
| x_np = np.random.randn(4096, 3072).astype(np.float32) | |||||
| dy_np = np.random.randn(4096, 3072).astype(np.float32) | |||||
| gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| epsilon = 10e-12 | |||||
| dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, begin_params_axis) | |||||
| dy_ms = Tensor(dy_np) | |||||
| x_ms = Tensor(x_np) | |||||
| var_ms = Tensor(var_np) | |||||
| mean_ms = Tensor(mean_np) | |||||
| gamma_ms = Tensor(gamma_np) | |||||
| net = LayerNormGradNet(begin_norm_axis, begin_params_axis) | |||||
| dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms) | |||||
| assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) | |||||
| assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_layernormgrad1(): | |||||
| begin_norm_axis = 1 | |||||
| begin_params_axis = 1 | |||||
| x_np = np.random.randn(640, 768).astype(np.float32) | |||||
| dy_np = np.random.randn(640, 768).astype(np.float32) | |||||
| gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| epsilon = 10e-12 | |||||
| dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, begin_params_axis) | |||||
| dy_ms = Tensor(dy_np) | |||||
| x_ms = Tensor(x_np) | |||||
| var_ms = Tensor(var_np) | |||||
| mean_ms = Tensor(mean_np) | |||||
| gamma_ms = Tensor(gamma_np) | |||||
| net = LayerNormGradNet(begin_norm_axis, begin_params_axis) | |||||
| dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms) | |||||
| assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) | |||||
| assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_layernormgrad2(): | |||||
| begin_norm_axis = -1 | |||||
| begin_params_axis = -1 | |||||
| x_np = np.random.randn(32, 128, 768).astype(np.float32) | |||||
| dy_np = np.random.randn(32, 128, 768).astype(np.float32) | |||||
| gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| epsilon = 10e-12 | |||||
| dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis, begin_params_axis) | |||||
| dy_ms = Tensor(dy_np) | |||||
| x_ms = Tensor(x_np) | |||||
| var_ms = Tensor(var_np) | |||||
| mean_ms = Tensor(mean_np) | |||||
| gamma_ms = Tensor(gamma_np) | |||||
| net = LayerNormGradNet(begin_norm_axis, begin_params_axis) | |||||
| dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms) | |||||
| assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3) | |||||
| assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3) | |||||
| @@ -0,0 +1,134 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import pytest | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| class LayerNormNet(nn.Cell): | |||||
| def __init__(self, begin_norm_axis, begin_params_axis): | |||||
| super(LayerNormNet, self).__init__() | |||||
| self.norm = P.LayerNorm(begin_norm_axis, begin_params_axis) | |||||
| def construct(self, x, gamma, beta): | |||||
| return self.norm(x, gamma, beta) | |||||
| def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta): | |||||
| begin_norm_axis = begin_norm_axis if begin_norm_axis >=0 else begin_norm_axis + len(x.shape) | |||||
| begin_params_axis = begin_params_axis if begin_params_axis >=0 else begin_params_axis + len(x.shape) | |||||
| axis = [i for i in range(begin_norm_axis, len(x.shape))] | |||||
| mean = np.mean(x, axis=tuple(axis), keepdims=True) | |||||
| var = np.var(x, axis=tuple(axis), keepdims=True) | |||||
| gamma = gamma.reshape((*((1,)*begin_params_axis), *x.shape[begin_params_axis:])) | |||||
| beta = beta.reshape((*((1,)*begin_params_axis), *x.shape[begin_params_axis:])) | |||||
| y = np.subtract(x, mean) / np.sqrt(var + 1e-12) * gamma + beta | |||||
| return y, mean, var | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_layernorm0(): | |||||
| begin_norm_axis = 1 | |||||
| begin_params_axis = 1 | |||||
| x_np = np.random.randn(4096, 3072).astype(np.float32) | |||||
| gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np) | |||||
| x_ms = Tensor(x_np) | |||||
| gamma_ms = Tensor(gamma_np) | |||||
| beta_ms = Tensor(beta_np) | |||||
| net = LayerNormNet(begin_norm_axis, begin_params_axis) | |||||
| y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms) | |||||
| assert np.allclose(y_ms.asnumpy(), y_np, atol=1e-6) | |||||
| assert np.allclose(mean_ms.asnumpy(), mean_np, atol=1e-6) | |||||
| assert np.allclose(var_ms.asnumpy(), var_np, atol=1e-6) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_layernorm1(): | |||||
| begin_norm_axis = 1 | |||||
| begin_params_axis = 1 | |||||
| x_np = np.random.randn(640, 768).astype(np.float32) | |||||
| gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np) | |||||
| x_ms = Tensor(x_np) | |||||
| gamma_ms = Tensor(gamma_np) | |||||
| beta_ms = Tensor(beta_np) | |||||
| net = LayerNormNet(begin_norm_axis, begin_params_axis) | |||||
| y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms) | |||||
| assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_layernorm3d_1(): | |||||
| begin_norm_axis = -1 | |||||
| begin_params_axis = -1 | |||||
| x_np = np.random.randn(32, 128, 768).astype(np.float32) | |||||
| gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np) | |||||
| x_ms = Tensor(x_np) | |||||
| gamma_ms = Tensor(gamma_np) | |||||
| beta_ms = Tensor(beta_np) | |||||
| net = LayerNormNet(begin_norm_axis, begin_params_axis) | |||||
| y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms) | |||||
| assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_layernorm3d_2(): | |||||
| begin_norm_axis = -1 | |||||
| begin_params_axis = 1 | |||||
| x_np = np.random.randn(32, 128, 768).astype(np.float32) | |||||
| gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32) | |||||
| y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np) | |||||
| x_ms = Tensor(x_np) | |||||
| gamma_ms = Tensor(gamma_np) | |||||
| beta_ms = Tensor(beta_np) | |||||
| net = LayerNormNet(begin_norm_axis, begin_params_axis) | |||||
| y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms) | |||||
| assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6) | |||||
| assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6) | |||||