diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 3f193730..e3b182bd 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -348,6 +348,7 @@ namespace LLama.Native /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab /// /// + /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i); diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal index 99b9fd7a..f4b46056 100644 --- a/LLama/runtimes/ggml-metal.metal +++ b/LLama/runtimes/ggml-metal.metal @@ -18,6 +18,21 @@ typedef struct { uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; +#define QK5_0 32 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; + +#define QK5_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; + #define QK8_0 32 typedef struct { half d; // delta @@ -110,9 +125,17 @@ kernel void kernel_mul_row( } kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( device const float4 * src0, device float4 * dst, - constant float & scale, + constant float & scale, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * scale; } @@ -399,8 +422,11 @@ kernel void kernel_rms_norm( // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -417,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -428,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1]) + sumy * m; } +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (sumy * -16.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group @@ -525,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32( mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } +kernel void kernel_mul_mv_q5_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + + #define NB_Q8_0 8 kernel void kernel_mul_mv_q8_0_f32( @@ -2149,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg } } +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } +} + template void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); @@ -2490,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; @@ -2518,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/LLama/runtimes/libllama-cuda11.dll b/LLama/runtimes/libllama-cuda11.dll index e5fc7dad..29949615 100644 Binary files a/LLama/runtimes/libllama-cuda11.dll and b/LLama/runtimes/libllama-cuda11.dll differ diff --git a/LLama/runtimes/libllama-cuda11.so b/LLama/runtimes/libllama-cuda11.so index 3532fe99..3a546e44 100644 Binary files a/LLama/runtimes/libllama-cuda11.so and b/LLama/runtimes/libllama-cuda11.so differ diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll index 89f27e24..f9f6db79 100644 Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ diff --git a/LLama/runtimes/libllama-cuda12.so b/LLama/runtimes/libllama-cuda12.so index 81b4aa99..4488e941 100644 Binary files a/LLama/runtimes/libllama-cuda12.so and b/LLama/runtimes/libllama-cuda12.so differ diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll index 6f92ebdf..cb0e9f88 100644 Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib index c2ca7ec8..c0f06f18 100755 Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so index b9ef4c1d..9702ae6b 100644 Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ