diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml
index 4773595c..f8c769f6 100644
--- a/.github/workflows/compile.yml
+++ b/.github/workflows/compile.yml
@@ -130,10 +130,8 @@ jobs:
fail-fast: true
matrix:
include:
- - build: 'cpu'
- defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_METAL=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_OSX_ARCHITECTURES=arm64'
- build: 'metal'
- defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_METAL=ON -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_OSX_ARCHITECTURES=arm64'
+ defines: '-DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_SERVER=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_OSX_ARCHITECTURES=arm64'
runs-on: macos-latest
steps:
- uses: actions/checkout@v3
@@ -197,8 +195,6 @@ jobs:
- name: Rearrange MacOS files
if: ${{ github.event.inputs.macos }}
run: |
- mkdir deps/macos-cpu
- cp artifacts/llama-bin-macos-cpu.dylib/libllama.dylib deps/macos-cpu/libllama.dylib
mkdir deps/macos-metal
cp artifacts/llama-bin-macos-metal.dylib/libllama.dylib deps/macos-metal/libllama.dylib
cp artifacts/ggml-metal.metal/ggml-metal.metal deps/macos-metal/ggml-metal.metal
diff --git a/LLama/LLamaSharp.Runtime.targets b/LLama/LLamaSharp.Runtime.targets
index df079ba3..8910f155 100644
--- a/LLama/LLamaSharp.Runtime.targets
+++ b/LLama/LLamaSharp.Runtime.targets
@@ -31,12 +31,8 @@
PreserveNewest
libllama.dylib
-
- None
- libllama-metal.dylib
-
- None
+ PreserveNewest
ggml-metal.metal
diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal
index 82e1a0c7..7b5c21d9 100644
--- a/LLama/runtimes/ggml-metal.metal
+++ b/LLama/runtimes/ggml-metal.metal
@@ -25,9 +25,9 @@ typedef struct {
} block_q8_0;
kernel void kernel_add(
- device const float * src0,
- device const float * src1,
- device float * dst,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig];
}
@@ -35,18 +35,18 @@ kernel void kernel_add(
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_add_row(
- device const float * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] + src1[tpig % ne00];
+ dst[tpig] = src0[tpig] + src1[tpig % nb];
}
kernel void kernel_mul(
- device const float * src0,
- device const float * src1,
- device float * dst,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig];
}
@@ -54,12 +54,12 @@ kernel void kernel_mul(
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_mul_row(
- device const float * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig % ne00];
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
}
kernel void kernel_scale(
@@ -133,19 +133,24 @@ kernel void kernel_soft_max(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- // broadcast
- if (tpitg[0] == 0) {
- buf[0] = buf[0];
- }
+ //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
+ // the loop, and when that is done, buf[0] has the correct (synchronized) value
+ //if (tpitg[0] == 0) {
+ // buf[0] = buf[0];
+ //}
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
const float max = buf[0];
// parallel sum
buf[tpitg[0]] = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- buf[tpitg[0]] += exp(psrc0[i00] - max);
+ const float exp_psrc0 = exp(psrc0[i00] - max);
+ buf[tpitg[0]] += exp_psrc0;
+ // Remember the result of exp here. exp is expensive, so we really do not
+ // wish to compute it twice.
+ pdst[i00] = exp_psrc0;
}
// reduce
@@ -157,17 +162,18 @@ kernel void kernel_soft_max(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- // broadcast
- if (tpitg[0] == 0) {
- buf[0] = buf[0];
- }
+ // broadcast - not needed, see above
+ //// broadcast
+ //if (tpitg[0] == 0) {
+ // buf[0] = buf[0];
+ //}
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
const float sum = buf[0];
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- pdst[i00] = exp(psrc0[i00] - max) / sum;
+ pdst[i00] /= sum;
}
}
@@ -214,25 +220,17 @@ kernel void kernel_norm(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- // broadcast
- if (tpitg == 0) {
- sum[0] /= ne00;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- const float mean = sum[0];
+ const float mean = sum[0] / ne00;
- // recenter
+ // recenter and VARIANCE
+ threadgroup_barrier(mem_flags::mem_threadgroup);
device float * y = dst + tgpig*ne00;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- y[i00] = x[i00] - mean;
- }
-
- // VARIANCE
- // parallel sum
sum[tpitg] = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = x[i00] - mean;
sum[tpitg] += y[i00] * y[i00];
}
+
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg/2; i > 0; i /= 2) {
@@ -241,12 +239,7 @@ kernel void kernel_norm(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- // broadcast
- if (tpitg == 0) {
- sum[0] /= ne00;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- const float variance = sum[0];
+ const float variance = sum[0] / ne00;
const float scale = 1.0f/sqrt(variance + eps);
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
@@ -254,7 +247,6 @@ kernel void kernel_norm(
}
}
-
kernel void kernel_rms_norm(
device const void * src0,
device float * dst,
@@ -435,6 +427,8 @@ kernel void kernel_mul_mat_q4_1_f32(
mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
+#define NB_Q8_0 8
+
kernel void kernel_mul_mat_q8_0_f32(
device const void * src0,
device const float * src1,
@@ -463,30 +457,30 @@ kernel void kernel_mul_mat_q8_0_f32(
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- float yl[16];
+ float yl[NB_Q8_0];
float sumf[nr]={0.f};
- const int ix = tiisg/2;
- const int il = tiisg%2;
+ const int ix = tiisg/4;
+ const int il = tiisg%4;
- device const float * yb = y + ix * QK8_0 + 16*il;
+ device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
- // each thread in a SIMD group deals with half a block.
- for (int ib = ix; ib < nb; ib += nw/2) {
- for (int i = 0; i < 16; ++i) {
+ // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+ for (int ib = ix; ib < nb; ib += nw/4) {
+ for (int i = 0; i < NB_Q8_0; ++i) {
yl[i] = yb[i];
}
for (int row = 0; row < nr; row++) {
- device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+ device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
float sumq = 0.f;
- for (int iq = 0; iq < 16; ++iq) {
+ for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
sumf[row] += sumq*x[ib+row*nb].d;
}
- yb += QK8_0 * 16;
+ yb += NB_Q8_0 * nw;
}
for (int row = 0; row < nr; ++row) {
@@ -497,7 +491,7 @@ kernel void kernel_mul_mat_q8_0_f32(
}
}
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mat_f16_f32_1row(
device const char * src0,
device const char * src1,
device float * dst,
@@ -515,11 +509,8 @@ kernel void kernel_mul_mat_f16_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- threadgroup float * sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpig[[thread_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
@@ -528,23 +519,100 @@ kernel void kernel_mul_mat_f16_f32(
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
- sum[tpitg.x] = 0.0f;
-
- for (int i = tpitg.x; i < ne00; i += tptg.x) {
- sum[tpitg.x] += (float) x[i] * (float) y[i];
+ float sumf = 0;
+ if (ne00 < 128) {
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *) x;
+ device const float4 * y4 = (device const float4 *) y;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
}
- // accumulate the sum from all threads in the threadgroup
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = tptg.x/2; i > 0; i /= 2) {
- if (tpitg.x < i) {
- sum[tpitg.x] += sum[tpitg.x + i];
+}
+
+#define N_F16_F32 4
+
+kernel void kernel_mul_mat_f16_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F16_F32;
+ const int64_t im = tgpig.z;
+
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F16_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
+ } else {
+ device const half4 * x4 = (device const half4 *)x;
+ for (int row = 0; row < N_F16_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float4 * y4 = (device const float4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
- if (tpitg.x == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
}
}
@@ -614,25 +682,27 @@ kernel void kernel_rope(
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
- uint3 tpig[[thread_position_in_grid]]) {
- const int64_t i3 = tpig[2];
- const int64_t i2 = tpig[1];
- const int64_t i1 = tpig[0];
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int64_t i3 = tgpig[2];
+ const int64_t i2 = tgpig[1];
+ const int64_t i1 = tgpig[0];
const bool is_neox = mode & 2;
- const float theta_scale = pow(freq_base, -2.0f/n_dims);
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
- float theta = freq_scale * (float)p;
+ const float theta_0 = freq_scale * (float)p;
+ const float inv_ndims = -1.f/n_dims;
if (!is_neox) {
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+ const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
- theta *= theta_scale;
-
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -644,12 +714,12 @@ kernel void kernel_rope(
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
+ for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+
+ const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
- theta *= theta_scale;
-
const int64_t i0 = ib*n_dims + ic/2;
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -1053,31 +1123,40 @@ kernel void kernel_mul_mat_q3_K_f32(
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
- float yl[16];
+ float yl[32];
- const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask1 = 0x3030;
const uint16_t kmask2 = 0x0f0f;
- const int tid = tiisg/2;
- const int ix = tiisg%2;
- const int ip = tid/8; // 0 or 1
- const int il = tid/2 - 4*ip; // 0...3
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int ip = tid/4; // 0 or 1
+ const int il = 2*((tid%4)/2); // 0 or 2
const int ir = tid%2;
const int n = 8;
const int l0 = n*ir;
- const uint16_t m1 = 1 << (4*ip + il);
- const uint16_t m2 = m1 << 8;
+ // One would think that the Metal compiler would figure out that ip and il can only have
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+ // with these two tables.
+ //
+ // Possible masks for the high bit
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+ // Possible masks for the low 2 bits
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+ const ushort4 hm = mm[2*ip + il/2];
const int shift = 2*il;
- const uint16_t qm1 = 0x0003 << shift;
- const uint16_t qm2 = 0x0300 << shift;
- const int32_t v1 = 4 << shift;
- const int32_t v2 = 1024 << shift;
+ const float v1 = il == 0 ? 4.f : 64.f;
+ const float v2 = 4.f * v1;
const uint16_t s_shift1 = 4*ip;
- const uint16_t s_shift2 = s_shift1 + 2*(il/2);
- const int ik = 4 + (il%2);
+ const uint16_t s_shift2 = s_shift1 + il;
const int q_offset = 32*ip + l0;
const int y_offset = 128*ip + 32*il + l0;
@@ -1086,12 +1165,19 @@ kernel void kernel_mul_mat_q3_K_f32(
device const float * y1 = yy + ix*QK_K + y_offset;
- float sumf1[2] = {0.f}, sumf2[2] = {0.f};
- for (int i = ix; i < nb; i += 2) {
+ uint32_t scales32, aux32;
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+ float sumf1[2] = {0.f};
+ float sumf2[2] = {0.f};
+ for (int i = ix; i < nb; i += 4) {
for (int l = 0; l < 8; ++l) {
- yl[l+0] = y1[l+ 0];
- yl[l+8] = y1[l+16];
+ yl[l+ 0] = y1[l+ 0];
+ yl[l+ 8] = y1[l+16];
+ yl[l+16] = y1[l+32];
+ yl[l+24] = y1[l+48];
}
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1102,27 +1188,43 @@ kernel void kernel_mul_mat_q3_K_f32(
for (int row = 0; row < 2; ++row) {
const float d_all = (float)dh[0];
- const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
- float s1 = 0, s2 = 0;
+ scales16[0] = a[4];
+ scales16[1] = a[5];
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+ scales16[0] = a[il+0];
+ scales16[1] = a[il+1];
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
for (int l = 0; l < n; l += 2) {
- const uint16_t qs = q[l/2];
- s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
- s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+ const int32_t qs = q[l/2];
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
}
- float d = d_all * (s1 + 1.f/256.f * s2);
- sumf1[row] += d * scales[0];
- sumf2[row] += d;
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[0] - 32);
+ sumf2[row] += d2 * (scales[2] - 32);
- s1 = s2 = 0;
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
for (int l = 0; l < n; l += 2) {
- const uint16_t qs = q[l/2+8];
- s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
- s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+ const int32_t qs = q[l/2+8];
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
}
- d = d_all * (s1 + 1.f/256.f * s2);
- sumf1[row] += d * scales[1];
- sumf2[row] += d;
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[1] - 32);
+ sumf2[row] += d2 * (scales[3] - 32);
q += step;
h += step;
@@ -1131,17 +1233,20 @@ kernel void kernel_mul_mat_q3_K_f32(
}
- y1 += 2 * QK_K;
+ y1 += 4 * QK_K;
}
for (int row = 0; row < 2; ++row) {
- const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
- const float tot = simd_sum(sumf);
- if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+ sumf1[row] = simd_sum(sumf);
+ }
+ if (tiisg == 0) {
+ for (int row = 0; row < 2; ++row) {
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
}
}
+
}
#else
kernel void kernel_mul_mat_q3_K_f32(
@@ -1244,7 +1349,8 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int r2 = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = r0 * N_DST;
const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
@@ -1493,17 +1599,25 @@ kernel void kernel_mul_mat_q5_K_f32(
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
- float4 acc = {0.f, 0.f, 0.f, 0.f};
+ float4 acc1 = {0.f};
+ float4 acc2 = {0.f};
for (int l = 0; l < n; ++l) {
uint8_t h = qh[l];
- acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
- acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
- acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
- acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
}
const float dall = dh[0];
const float dmin = dh[1];
- sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
q1 += step;
diff --git a/LLama/runtimes/libllama-metal.dylib b/LLama/runtimes/libllama-metal.dylib
deleted file mode 100755
index e9c2ee28..00000000
Binary files a/LLama/runtimes/libllama-metal.dylib and /dev/null differ
diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib
index 53318c38..5bb4497d 100755
Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ