| @@ -476,6 +476,20 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| __syncthreads(); | __syncthreads(); | ||||
| } | } | ||||
| size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| } | |||||
| // read fuse_z | // read fuse_z | ||||
| int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point), | int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point), | ||||
| make_int2(z_zero_point, z_zero_point), | make_int2(z_zero_point, z_zero_point), | ||||
| @@ -595,18 +609,7 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| __syncthreads(); | __syncthreads(); | ||||
| /// output | /// output | ||||
| size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | if (oc < param.oc) { | ||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| mul_v4(load_bias0, load_bias0, beta); | mul_v4(load_bias0, load_bias0, beta); | ||||
| mul_v4(load_bias1, load_bias1, beta); | mul_v4(load_bias1, load_bias1, beta); | ||||
| mul_v4(load_bias2, load_bias2, beta); | mul_v4(load_bias2, load_bias2, beta); | ||||
| @@ -617,7 +620,6 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| #pragma unroll | #pragma unroll | ||||
| for (int y = 0; y < reg_m; y += 4) { | for (int y = 0; y < reg_m; y += 4) { | ||||
| I2F_4x8(reg_acc, y, 0); | |||||
| FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | ||||
| FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point); | FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point); | ||||
| PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | ||||
| @@ -657,6 +657,20 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| __syncthreads(); | __syncthreads(); | ||||
| } | } | ||||
| size_t oc = bidy * BM + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| } | |||||
| // read fuse_z | // read fuse_z | ||||
| int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point), | int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point), | ||||
| make_int2(z_zero_point, z_zero_point), | make_int2(z_zero_point, z_zero_point), | ||||
| @@ -712,6 +726,14 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| reg_flt[0][j] = make_int4(x, y, z, w); | reg_flt[0][j] = make_int4(x, y, z, w); | ||||
| } | } | ||||
| /// output | |||||
| if (oc < param.oc) { | |||||
| mul_v4(load_bias0, load_bias0, beta); | |||||
| mul_v4(load_bias1, load_bias1, beta); | |||||
| mul_v4(load_bias2, load_bias2, beta); | |||||
| mul_v4(load_bias3, load_bias3, beta); | |||||
| } | |||||
| // compute | // compute | ||||
| #pragma unroll | #pragma unroll | ||||
| for (int k_inner = 0; k_inner < BKd32; k_inner++) { | for (int k_inner = 0; k_inner < BKd32; k_inner++) { | ||||
| @@ -773,35 +795,20 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| __syncthreads(); | __syncthreads(); | ||||
| /// output | |||||
| size_t oc = bidy * BM + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| mul_v4(load_bias0, load_bias0, beta); | |||||
| mul_v4(load_bias1, load_bias1, beta); | |||||
| mul_v4(load_bias2, load_bias2, beta); | |||||
| mul_v4(load_bias3, load_bias3, beta); | |||||
| } | |||||
| int8_t* __restrict__ g_dst_ptr = dst + d_offset; | int8_t* __restrict__ g_dst_ptr = dst + d_offset; | ||||
| FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| fuse_z_1x8(reg_acc[0], 0, reg_fuse_z[0], gamma, z_zero_point); | |||||
| PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point); | |||||
| #pragma unroll | #pragma unroll | ||||
| for (int y = 0; y < reg_m; y += 4) { | |||||
| I2F_4x8(reg_acc, y, 0); | |||||
| FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point); | |||||
| PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | |||||
| STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0); | |||||
| for (int y = 1; y < reg_m; y += 1) { | |||||
| FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| fuse_z_1x8(reg_acc[y], 0, reg_fuse_z[y], gamma, z_zero_point); | |||||
| PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point); | |||||
| STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]); | |||||
| } | } | ||||
| STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]); | |||||
| #endif | #endif | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -437,7 +437,7 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| cp_async_fence(); | cp_async_fence(); | ||||
| } | } | ||||
| bool only_one_stage = (stage == 1) ? true : false; | |||||
| bool only_one_stage = (stage == 1); | |||||
| if (stage >= 2) { | if (stage >= 2) { | ||||
| cp_async_wait(stages - 2); | cp_async_wait(stages - 2); | ||||
| } else { | } else { | ||||
| @@ -844,6 +844,20 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| cp_async_wait(stages - 2); | cp_async_wait(stages - 2); | ||||
| } | } | ||||
| size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| } | |||||
| if (!only_one_stage) { | if (!only_one_stage) { | ||||
| #pragma unroll // low | #pragma unroll // low | ||||
| for (int i = 0; i < reg_nd4; ++i) { | for (int i = 0; i < reg_nd4; ++i) { | ||||
| @@ -975,6 +989,13 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| reg_flt[0][j] = make_int4(x, y, z, w); | reg_flt[0][j] = make_int4(x, y, z, w); | ||||
| } | } | ||||
| if (oc < param.oc) { | |||||
| mul_v4(load_bias0, load_bias0, beta); | |||||
| mul_v4(load_bias1, load_bias1, beta); | |||||
| mul_v4(load_bias2, load_bias2, beta); | |||||
| mul_v4(load_bias3, load_bias3, beta); | |||||
| } | |||||
| // compute | // compute | ||||
| #pragma unroll | #pragma unroll | ||||
| for (int k_inner = 0; k_inner < BKd32; k_inner++) { | for (int k_inner = 0; k_inner < BKd32; k_inner++) { | ||||
| @@ -1038,34 +1059,20 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| __syncthreads(); | __syncthreads(); | ||||
| /// output | /// output | ||||
| size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| mul_v4(load_bias0, load_bias0, beta); | |||||
| mul_v4(load_bias1, load_bias1, beta); | |||||
| mul_v4(load_bias2, load_bias2, beta); | |||||
| mul_v4(load_bias3, load_bias3, beta); | |||||
| } | |||||
| int8_t* __restrict__ g_dst_ptr = dst + d_offset; | int8_t* __restrict__ g_dst_ptr = dst + d_offset; | ||||
| FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| fuse_z_1x8(reg_acc[0], 0, reg_fuse_z[0], gamma, z_zero_point); | |||||
| PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point); | |||||
| #pragma unroll | #pragma unroll | ||||
| for (int y = 0; y < reg_m; y += 4) { | |||||
| I2F_4x8(reg_acc, y, 0); | |||||
| FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point); | |||||
| PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | |||||
| STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0); | |||||
| for (int y = 1; y < reg_m; y += 1) { | |||||
| FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| fuse_z_1x8(reg_acc[y], 0, reg_fuse_z[y], gamma, z_zero_point); | |||||
| PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point); | |||||
| STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]); | |||||
| } | } | ||||
| STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]); | |||||
| #endif | #endif | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -475,6 +475,20 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| __syncthreads(); | __syncthreads(); | ||||
| } | } | ||||
| size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| } | |||||
| guard = iter < 0; | guard = iter < 0; | ||||
| #pragma unroll | #pragma unroll | ||||
| for (int i = 0; i < reg_nd4; ++i) { | for (int i = 0; i < reg_nd4; ++i) { | ||||
| @@ -574,18 +588,8 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| size_t nhw_post3 = nhw_post0 + 24; | size_t nhw_post3 = nhw_post0 + 24; | ||||
| size_t stg_oc = bidy * BM + (warp_y << 6); | size_t stg_oc = bidy * BM + (warp_y << 6); | ||||
| size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | if (oc < param.oc) { | ||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| mul_v4(load_bias0, load_bias0, beta); | mul_v4(load_bias0, load_bias0, beta); | ||||
| mul_v4(load_bias1, load_bias1, beta); | mul_v4(load_bias1, load_bias1, beta); | ||||
| mul_v4(load_bias2, load_bias2, beta); | mul_v4(load_bias2, load_bias2, beta); | ||||
| @@ -599,7 +603,6 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| #pragma unroll | #pragma unroll | ||||
| for (int y = 0; y < reg_m; y += 4) { | for (int y = 0; y < reg_m; y += 4) { | ||||
| I2F_4x8(reg_acc, y, 0); | |||||
| FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | ||||
| PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | ||||
| STG_4x1(stg_ptr, reg_acc, y, 0); | STG_4x1(stg_ptr, reg_acc, y, 0); | ||||
| @@ -659,6 +659,20 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| __syncthreads(); | __syncthreads(); | ||||
| } | } | ||||
| size_t oc = bidy * BM + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| } | |||||
| guard = iter < 0; | guard = iter < 0; | ||||
| #pragma unroll // low | #pragma unroll // low | ||||
| for (int i = 0; i < reg_nd4; ++i) { | for (int i = 0; i < reg_nd4; ++i) { | ||||
| @@ -755,18 +769,8 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| size_t nhw_post3 = nhw_post0 + 24; | size_t nhw_post3 = nhw_post0 + 24; | ||||
| size_t stg_oc = bidy * BM; | size_t stg_oc = bidy * BM; | ||||
| size_t oc = bidy * BM + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | if (oc < param.oc) { | ||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| mul_v4(load_bias0, load_bias0, beta); | mul_v4(load_bias0, load_bias0, beta); | ||||
| mul_v4(load_bias1, load_bias1, beta); | mul_v4(load_bias1, load_bias1, beta); | ||||
| mul_v4(load_bias2, load_bias2, beta); | mul_v4(load_bias2, load_bias2, beta); | ||||
| @@ -779,7 +783,6 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| #pragma unroll | #pragma unroll | ||||
| for (int y = 0; y < reg_m; y += 4) { | for (int y = 0; y < reg_m; y += 4) { | ||||
| I2F_4x8(reg_acc, y, 0); | |||||
| FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | ||||
| PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | ||||
| STG_4x1(stg_ptr, reg_acc, y, 0); | STG_4x1(stg_ptr, reg_acc, y, 0); | ||||
| @@ -449,15 +449,15 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| bool stg_guard[8]; | bool stg_guard[8]; | ||||
| #pragma unroll | #pragma unroll | ||||
| for (int y = 0; y < reg_m; y += 4) { | for (int y = 0; y < reg_m; y += 4) { | ||||
| COMPUTE_OFFSET_4x1(reg_fuse_z, g_offset, y) | |||||
| COMPUTE_OFFSET_4x1(g_offset, y); | |||||
| nhw_post0 += 32; | |||||
| nhw_post0 += 32; | |||||
| nhw_post1 += 32; | nhw_post1 += 32; | ||||
| nhw_post2 += 32; | nhw_post2 += 32; | ||||
| nhw_post3 += 32; | nhw_post3 += 32; | ||||
| } | } | ||||
| bool only_one_stage = (stage == 1) ? true : false; | |||||
| bool only_one_stage = (stage == 1); | |||||
| if (stage >= 2) { | if (stage >= 2) { | ||||
| cp_async_wait(stages - 2); | cp_async_wait(stages - 2); | ||||
| } else { | } else { | ||||
| @@ -835,6 +835,20 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| cp_async_wait(stages - 2); | cp_async_wait(stages - 2); | ||||
| } | } | ||||
| size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| } | |||||
| if (!only_one_stage) { | if (!only_one_stage) { | ||||
| #pragma unroll // low | #pragma unroll // low | ||||
| for (int i = 0; i < reg_nd4; ++i) { | for (int i = 0; i < reg_nd4; ++i) { | ||||
| @@ -965,6 +979,13 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| reg_flt[0][j] = make_int4(x, y, z, w); | reg_flt[0][j] = make_int4(x, y, z, w); | ||||
| } | } | ||||
| if (oc < param.oc) { | |||||
| mul_v4(load_bias0, load_bias0, beta); | |||||
| mul_v4(load_bias1, load_bias1, beta); | |||||
| mul_v4(load_bias2, load_bias2, beta); | |||||
| mul_v4(load_bias3, load_bias3, beta); | |||||
| } | |||||
| // compute | // compute | ||||
| #pragma unroll | #pragma unroll | ||||
| for (int k_inner = 0; k_inner < BKd32; k_inner++) { | for (int k_inner = 0; k_inner < BKd32; k_inner++) { | ||||
| @@ -1028,38 +1049,19 @@ extern "C" __global__ void __launch_bounds__(256) | |||||
| __syncthreads(); | __syncthreads(); | ||||
| /// output | /// output | ||||
| size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; | |||||
| const float* bias_ptr = bias + oc; | |||||
| int4 load_bias0 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias1 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias2 = make_int4(0, 0, 0, 0); | |||||
| int4 load_bias3 = make_int4(0, 0, 0, 0); | |||||
| if (oc < param.oc) { | |||||
| load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr)); | |||||
| load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4)); | |||||
| load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8)); | |||||
| load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12)); | |||||
| mul_v4(load_bias0, load_bias0, beta); | |||||
| mul_v4(load_bias1, load_bias1, beta); | |||||
| mul_v4(load_bias2, load_bias2, beta); | |||||
| mul_v4(load_bias3, load_bias3, beta); | |||||
| } | |||||
| int8_t* __restrict__ g_dst_ptr = dst + d_offset; | int8_t* __restrict__ g_dst_ptr = dst + d_offset; | ||||
| #pragma unroll | |||||
| for (int y = 0; y < reg_m; y += 4) { | |||||
| I2F_4x8(reg_acc, y, 0); | |||||
| FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); | |||||
| STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0); | |||||
| FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point); | |||||
| nhw_post0 += 32; | |||||
| nhw_post1 += 32; | |||||
| nhw_post2 += 32; | |||||
| nhw_post3 += 32; | |||||
| #pragma unroll | |||||
| for (int y = 1; y < reg_m; y += 1) { | |||||
| FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); | |||||
| PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point); | |||||
| STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]); | |||||
| } | } | ||||
| STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]); | |||||
| #endif | #endif | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -23,78 +23,26 @@ __device__ __forceinline__ void mul_v4<float>( | |||||
| __device__ __forceinline__ void fma2( | __device__ __forceinline__ void fma2( | ||||
| int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha, | int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha, | ||||
| const int4 b) { | const int4 b) { | ||||
| asm("fma.rz.f32 %0, %1, %2, %3;" | |||||
| : "=f"(((float*)&c0)[0]) | |||||
| : "f"(((float*)&a0)[0]), "f"(alpha), "f"(((float*)&b)[0])); | |||||
| asm("fma.rz.f32 %0, %1, %2, %3;" | |||||
| : "=f"(((float*)&c0)[1]) | |||||
| : "f"(((float*)&a0)[1]), "f"(alpha), "f"(((float*)&b)[1])); | |||||
| asm("fma.rz.f32 %0, %1, %2, %3;" | |||||
| : "=f"(((float*)&c1)[0]) | |||||
| : "f"(((float*)&a1)[0]), "f"(alpha), "f"(((float*)&b)[2])); | |||||
| asm("fma.rz.f32 %0, %1, %2, %3;" | |||||
| : "=f"(((float*)&c1)[1]) | |||||
| : "f"(((float*)&a1)[1]), "f"(alpha), "f"(((float*)&b)[3])); | |||||
| } | |||||
| __device__ __forceinline__ void fuse_z_1x8( | |||||
| int4* a, const int& j, const int4& fuse_z, const float& gamma, | |||||
| const int32_t& zero_point) { | |||||
| const int2 z[2] = { | |||||
| *reinterpret_cast<const int2*>(&fuse_z), | |||||
| *(reinterpret_cast<const int2*>(&fuse_z) + 1)}; | |||||
| for (int k = 0; k < 4; k++) { | |||||
| int f = ((z[0].x >> (k * 8)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma; | |||||
| f = ((z[0].x >> (k * 8 + 4)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma; | |||||
| f = ((z[1].x >> (k * 8)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k]))[2] += (f - zero_point) * gamma; | |||||
| f = ((z[1].x >> (k * 8 + 4)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k]))[3] += (f - zero_point) * gamma; | |||||
| } | |||||
| for (int k = 0; k < 4; k++) { | |||||
| int f = ((z[0].y >> (k * 8)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma; | |||||
| f = ((z[0].y >> (k * 8 + 4)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma; | |||||
| f = ((z[1].y >> (k * 8)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k + 4]))[2] += (f - zero_point) * gamma; | |||||
| f = ((z[1].y >> (k * 8 + 4)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k + 4]))[3] += (f - zero_point) * gamma; | |||||
| } | |||||
| ((float*)&c0)[0] = a0.x * alpha + ((float*)&b)[0]; | |||||
| ((float*)&c0)[1] = a0.y * alpha + ((float*)&b)[1]; | |||||
| ((float*)&c1)[0] = a1.x * alpha + ((float*)&b)[2]; | |||||
| ((float*)&c1)[1] = a1.y * alpha + ((float*)&b)[3]; | |||||
| } | } | ||||
| __device__ __forceinline__ void fuse_z_1x8( | __device__ __forceinline__ void fuse_z_1x8( | ||||
| int2* a, const int& j, const int2& fuse_z, const float& gamma, | int2* a, const int& j, const int2& fuse_z, const float& gamma, | ||||
| const int32_t& zero_point) { | const int32_t& zero_point) { | ||||
| float x = zero_point * gamma; | |||||
| #pragma unroll | #pragma unroll | ||||
| for (int k = 0; k < 4; k++) { | for (int k = 0; k < 4; k++) { | ||||
| int f = ((fuse_z.x >> (k * 8)) & 15); | int f = ((fuse_z.x >> (k * 8)) & 15); | ||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma; | |||||
| ((float*)&(a[j + k]))[0] += f * gamma - x; | |||||
| f = ((fuse_z.x >> (k * 8 + 4)) & 15); | f = ((fuse_z.x >> (k * 8 + 4)) & 15); | ||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma; | |||||
| } | |||||
| #pragma unroll | |||||
| for (int k = 0; k < 4; k++) { | |||||
| int f = ((fuse_z.y >> (k * 8)) & 15); | |||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma; | |||||
| ((float*)&(a[j + k]))[1] += f * gamma - x; | |||||
| f = ((fuse_z.y >> (k * 8)) & 15); | |||||
| ((float*)&(a[j + k + 4]))[0] += f * gamma - x; | |||||
| f = ((fuse_z.y >> (k * 8 + 4)) & 15); | f = ((fuse_z.y >> (k * 8 + 4)) & 15); | ||||
| f = (f << 28) >> 28; | |||||
| ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma; | |||||
| ((float*)&(a[j + k + 4]))[1] += f * gamma - x; | |||||
| } | } | ||||
| } | } | ||||
| @@ -282,12 +230,6 @@ __device__ __forceinline__ void pack_f2i_with_relu( | |||||
| fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \ | fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \ | ||||
| fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point); | fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point); | ||||
| #define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point) \ | |||||
| fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point); \ | |||||
| fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point); \ | |||||
| fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \ | |||||
| fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point); | |||||
| // 1x8 1x(2x8 int2) to 2 int2 | // 1x8 1x(2x8 int2) to 2 int2 | ||||
| #define PACK_F2I_1x8(a, i, j) \ | #define PACK_F2I_1x8(a, i, j) \ | ||||
| pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \ | pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \ | ||||
| @@ -316,24 +258,20 @@ __device__ __forceinline__ void pack_f2i_with_relu( | |||||
| stg_guard[i + 2]) \ | stg_guard[i + 2]) \ | ||||
| LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3]) | LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3]) | ||||
| #define COMPUTE_OFFSET(d, s, idx, n_reuse, hw_reuse, g) \ | |||||
| #define COMPUTE_OFFSET(s, idx, n_reuse, hw_reuse, g) \ | |||||
| n_reuse = nhw_post##idx / param.div_ohow; \ | n_reuse = nhw_post##idx / param.div_ohow; \ | ||||
| hw_reuse = nhw_post##idx % param.div_ohow; \ | hw_reuse = nhw_post##idx % param.div_ohow; \ | ||||
| s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \ | s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \ | ||||
| g = nhw_post##idx < param.nhw; | g = nhw_post##idx < param.nhw; | ||||
| #define COMPUTE_OFFSET_4x1(d, s, i) \ | |||||
| COMPUTE_OFFSET( \ | |||||
| d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \ | |||||
| COMPUTE_OFFSET( \ | |||||
| d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \ | |||||
| stg_guard[i + 1]) \ | |||||
| COMPUTE_OFFSET( \ | |||||
| d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \ | |||||
| stg_guard[i + 2]) \ | |||||
| COMPUTE_OFFSET( \ | |||||
| d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, \ | |||||
| stg_guard[i + 3]) | |||||
| #define COMPUTE_OFFSET_4x1(s, i) \ | |||||
| COMPUTE_OFFSET(s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \ | |||||
| COMPUTE_OFFSET( \ | |||||
| s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, stg_guard[i + 1]) \ | |||||
| COMPUTE_OFFSET( \ | |||||
| s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, stg_guard[i + 2]) \ | |||||
| COMPUTE_OFFSET( \ | |||||
| s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3]) | |||||
| #define STG_AFTER_LDG(d, s, g) \ | #define STG_AFTER_LDG(d, s, g) \ | ||||
| if (stg_oc < param.oc && g) { \ | if (stg_oc < param.oc && g) { \ | ||||