|
|
|
@@ -18,6 +18,21 @@ typedef struct { |
|
|
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants |
|
|
|
} block_q4_1; |
|
|
|
|
|
|
|
#define QK5_0 32

// 5-bit quantization block: QK5_0 weights sharing one half-precision scale d.
// The low 4 bits of each weight are packed two-per-byte in qs; the 5-th bit of
// each of the 32 weights is packed into the 32 bits of qh.
typedef struct {
    half    d;              // delta (per-block scale factor)
    uint8_t qh[4];          // 5-th bit of quants
    uint8_t qs[QK5_0 / 2];  // nibbles / quants
} block_q5_0;
|
|
|
|
|
|
|
#define QK5_1 32

// 5-bit quantization block with an explicit minimum: dequantized value is
// d * q + m. Low 4 bits of each weight packed two-per-byte in qs; the 5-th
// bit of each of the 32 weights packed into the 32 bits of qh.
typedef struct {
    half    d;              // delta (per-block scale factor)
    half    m;              // min (per-block offset)
    uint8_t qh[4];          // 5-th bit of quants
    uint8_t qs[QK5_1 / 2];  // nibbles / quants
} block_q5_1;
|
|
|
|
|
|
|
#define QK8_0 32 |
|
|
|
typedef struct { |
|
|
|
half d; // delta |
|
|
|
@@ -110,9 +125,17 @@ kernel void kernel_mul_row( |
|
|
|
} |
|
|
|
|
|
|
|
// Element-wise scalar multiply: dst[i] = src0[i] * scale.
// One thread handles one element; the dispatch is expected to cover
// every element exactly once via thread_position_in_grid.
kernel void kernel_scale(
        device const float * src0,
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float v = src0[tpig];
    dst[tpig] = v * scale;
}
|
|
|
|
|
|
|
// Vectorized element-wise scalar multiply: dst[i] = src0[i] * scale, operating
// on float4 so each thread processes 4 contiguous floats.
// NOTE(review): assumes the caller dispatches n/4 threads for n elements — confirm
// at the dispatch site.
// Fix: the original declared the `scale` parameter twice, which is a parameter
// redefinition (and a duplicate constant-buffer binding) — one copy removed.
kernel void kernel_scale_4(
        device const float4 * src0,
        device       float4 * dst,
        constant     float  & scale,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * scale;
}
|
|
|
@@ -399,8 +422,11 @@ kernel void kernel_rms_norm( |
|
|
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) |
|
|
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { |
|
|
|
float d = qb_curr->d; |
|
|
|
|
|
|
|
float2 acc = 0.f; |
|
|
|
|
|
|
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); |
|
|
|
|
|
|
|
for (int i = 0; i < 8; i+=2) { |
|
|
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) |
|
|
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00); |
|
|
|
@@ -417,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre |
|
|
|
// Computes the inner product between half of a q4_1 block and 16 floats (yl);
// sumy is SUM(yl[i]). il indicates where the q4 quants begin (0 or QK4_1/4).
// The yl values are assumed to be pre-multiplied with the scale factors that
// correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096).
// Fix: the original declared `qs` twice (before and after `acc`) — a
// redefinition error; the duplicate is removed.
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;

    float2 acc = 0.f;

    // quants start after the two half fields d and m -> +2 uint16 offset
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);

    for (int i = 0; i < 8; i+=2) {
        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                + yl[i + 1] * (qs[i / 2] & 0x0F00);
        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
                + yl[i + 9] * (qs[i / 2] & 0xF000);
    }
    return d * (acc[0] + acc[1]) + sumy * m;
}
|
|
|
|
|
|
|
// Computes the inner product between half of a q5_0 block and 16 floats (yl);
// sumy is SUM(yl[i]). il indicates where the q5 quants begin (0 or QK5_0/4).
// The yl values are assumed to be pre-multiplied with the scale factors that
// correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096).
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;

    float2 acc = 0.f;

    // quants start after d (1 uint16) and qh (4 bytes = 2 uint16) -> +3 offset
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
    // gather all 32 high bits with a single 32-bit load
    // NOTE(review): reads a uint32_t through a pointer that is only 2-byte
    // aligned within the block — confirm this access pattern is valid on the target
    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

    for (int i = 0; i < 8; i+=2) {
        // low nibbles, with the corresponding 5-th bit spliced in at bit 4 / bit 12
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        // high nibbles, with the corresponding 5-th bit spliced in at bit 8 / bit 16
        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
    }
    // sumy * -16.f folds the constant -16 bias of every q5_0 weight into the result:
    // d * SUM((q - 16) * y) == d * (SUM(q*y) - 16 * SUM(y))
    return d * (sumy * -16.f + acc[0] + acc[1]);
}
|
|
|
|
|
|
|
// Computes the inner product between half of a q5_1 block and 16 floats (yl);
// sumy is SUM(yl[i]). il indicates where the q5 quants begin (0 or QK5_1/4).
// The yl values are assumed to be pre-multiplied with the scale factors that
// correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096).
// Fix: the acc[1] terms used QK5_0/2 in this q5_1 routine; changed to QK5_1/2
// for consistency (QK5_0 == QK5_1 == 32, so the computed value is unchanged).
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;

    float2 acc = 0.f;

    // quants start after d, m (2 halves = 2 uint16) and qh (4 bytes = 2 uint16) -> +4
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
    // gather all 32 high bits with a single 32-bit load
    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

    for (int i = 0; i < 8; i+=2) {
        // low nibbles, with the corresponding 5-th bit spliced in at bit 4 / bit 12
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        // high nibbles, with the corresponding 5-th bit spliced in at bit 8 / bit 16
        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 8 ) & 0x00100))
                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 16) & 0x10000));
    }
    // q5_1 stores an explicit per-block minimum m: d * SUM(q*y) + m * SUM(y)
    return d * (acc[0] + acc[1]) + sumy * m;
}
|
|
|
|
|
|
|
// putting them in the kernel causes a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
|
|
@@ -525,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32( |
|
|
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); |
|
|
|
} |
|
|
|
|
|
|
|
// mul-mat-vec entry point for q5_0 src0 with f32 src1: a thin wrapper that
// forwards every argument to the shared mul_vec_q_n_f32 template specialized
// for block_q5_0. The [[buffer(n)]] bindings are part of the host ABI.
kernel void kernel_mul_mv_q5_0_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01[[buffer(4)]],
        constant int64_t & ne02[[buffer(5)]],
        constant int64_t & ne10[[buffer(9)]],
        constant int64_t & ne12[[buffer(11)]],
        constant int64_t & ne0[[buffer(15)]],
        constant int64_t & ne1[[buffer(16)]],
        constant uint & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
|
|
|
|
|
|
|
// mul-mat-vec entry point for q5_1 src0 with f32 src1: a thin wrapper that
// forwards every argument to the shared mul_vec_q_n_f32 template specialized
// for block_q5_1. The [[buffer(n)]] bindings are part of the host ABI.
kernel void kernel_mul_mv_q5_1_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01[[buffer(4)]],
        constant int64_t & ne02[[buffer(5)]],
        constant int64_t & ne10[[buffer(9)]],
        constant int64_t & ne12[[buffer(11)]],
        constant int64_t & ne0[[buffer(15)]],
        constant int64_t & ne1[[buffer(16)]],
        constant uint & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
|
|
|
|
|
|
|
|
|
|
|
#define NB_Q8_0 8 |
|
|
|
|
|
|
|
kernel void kernel_mul_mv_q8_0_f32( |
|
|
|
@@ -2149,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Dequantizes half (16 values) of a q5_0 block into a 4x4 register tile.
// il selects the half: 0 -> low nibbles of qs, nonzero -> high nibbles.
// Output value for quant q is d * q - 16 * d (the constant -16 bias of q5_0).
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
    // quants start after d (1 uint16) and qh (2 uint16) -> +3 offset
    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
    const float d = xb->d;
    const float md = -16.h * xb->d;  // folds the -16 bias: d * x + md == d * (x - 16)
    const ushort mask = il ? 0x00F0 : 0x000F;

    // all 32 high bits in a single 32-bit load
    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = il ? 4 : 0;    // shift to bring the selected nibble down to bits 0..3

    const int gh_mv = il ? 12 : 0;  // starting position of this half's bits within qh
    const int gh_bk = il ? 0 : 4;   // back-shift so the extracted bit lands at bit 4

    for (int i = 0; i < 8; i++) {
        // extract the 5-th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        // each uint16 yields two consecutive output values
        reg[i/2][2*(i%2)+0] = d * x0 + md;
        reg[i/2][2*(i%2)+1] = d * x1 + md;
    }
}
|
|
|
|
|
|
|
// Dequantizes half (16 values) of a q5_1 block into a 4x4 register tile.
// il selects the half: 0 -> low nibbles of qs, nonzero -> high nibbles.
// Output value for quant q is d * q + m (q5_1 carries an explicit minimum m
// instead of q5_0's constant -16 bias).
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
    // quants start after d, m (2 uint16) and qh (2 uint16) -> +4 offset
    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
    const float d = xb->d;
    const float m = xb->m;
    const ushort mask = il ? 0x00F0 : 0x000F;

    // all 32 high bits in a single 32-bit load
    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = il ? 4 : 0;    // shift to bring the selected nibble down to bits 0..3

    const int gh_mv = il ? 12 : 0;  // starting position of this half's bits within qh
    const int gh_bk = il ? 0 : 4;   // back-shift so the extracted bit lands at bit 4

    for (int i = 0; i < 8; i++) {
        // extract the 5-th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        // each uint16 yields two consecutive output values
        reg[i/2][2*(i%2)+0] = d * x0 + m;
        reg[i/2][2*(i%2)+1] = d * x1 + m;
    }
}
|
|
|
|
|
|
|
template <typename type4x4> |
|
|
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { |
|
|
|
device const int8_t * qs = ((device const int8_t *)xb->qs); |
|
|
|
@@ -2490,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows |
|
|
|
// Explicit get_rows instantiations: one host-visible pipeline per source type.
// The [[host_name(...)]] strings are the names the host code looks up.
template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
@@ -2518,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f |
|
|
|
// Explicit mul_mm instantiations: one host-visible pipeline per source type.
// The [[host_name(...)]] strings are the names the host code looks up.
template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
|
|