diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 3f193730..e3b182bd 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -348,6 +348,7 @@ namespace LLama.Native
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
///
///
+ ///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal
index 99b9fd7a..f4b46056 100644
--- a/LLama/runtimes/ggml-metal.metal
+++ b/LLama/runtimes/ggml-metal.metal
@@ -18,6 +18,21 @@ typedef struct {
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
+#define QK5_0 32
+typedef struct {
+ half d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+ half d; // delta
+ half m; // min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
#define QK8_0 32
typedef struct {
half d; // delta
@@ -110,9 +125,17 @@ kernel void kernel_mul_row(
}
kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
device const float4 * src0,
device float4 * dst,
- constant float & scale,
+ constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
@@ -399,8 +422,11 @@ kernel void kernel_rms_norm(
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
+
float2 acc = 0.f;
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
return d * (acc[0] + acc[1]) + sumy * m;
}
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+}
+
// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32(
mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
+kernel void kernel_mul_mv_q5_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
#define NB_Q8_0 8
kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
}
}
+template
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
+ }
+}
+
+template
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
+ }
+}
+
template
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows;
@@ -2518,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
diff --git a/LLama/runtimes/libllama-cuda11.dll b/LLama/runtimes/libllama-cuda11.dll
index e5fc7dad..29949615 100644
Binary files a/LLama/runtimes/libllama-cuda11.dll and b/LLama/runtimes/libllama-cuda11.dll differ
diff --git a/LLama/runtimes/libllama-cuda11.so b/LLama/runtimes/libllama-cuda11.so
index 3532fe99..3a546e44 100644
Binary files a/LLama/runtimes/libllama-cuda11.so and b/LLama/runtimes/libllama-cuda11.so differ
diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll
index 89f27e24..f9f6db79 100644
Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ
diff --git a/LLama/runtimes/libllama-cuda12.so b/LLama/runtimes/libllama-cuda12.so
index 81b4aa99..4488e941 100644
Binary files a/LLama/runtimes/libllama-cuda12.so and b/LLama/runtimes/libllama-cuda12.so differ
diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll
index 6f92ebdf..cb0e9f88 100644
Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ
diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib
index c2ca7ec8..c0f06f18 100755
Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ
diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so
index b9ef4c1d..9702ae6b 100644
Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ