diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 50f30c0a..0a397a3d 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -42,17 +42,41 @@ namespace LLama.Native
public uint n_threads_batch;
///
- /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
- /// RoPE base frequency
+ /// RoPE scaling type, from `enum llama_rope_scaling_type`
///
- public float rope_freq_base;
+ public sbyte rope_scaling_type;
+
///
- /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
- /// RoPE frequency scaling factor
+ /// RoPE base frequency, 0 = from model
///
- public float rope_freq_scale;
-
+ public float rope_freq_base;
+ ///
+ /// RoPE frequency scaling factor, 0 = from model
+ ///
+ public float rope_freq_scale;
+ ///
+ /// YaRN extrapolation mix factor, NaN = from model
+ ///
+ public float yarn_ext_factor;
+ ///
+ /// YaRN magnitude scaling factor
+ ///
+ public float yarn_attn_factor;
+ ///
+ /// YaRN low correction dim
+ ///
+ public float yarn_beta_fast;
+ ///
+ /// YaRN high correction dim
+ ///
+ public float yarn_beta_slow;
+
+ ///
+ /// YaRN original context size
+ ///
+ public uint yarn_orig_ctx;
+
///
/// if true, use experimental mul_mat_q kernels
///
diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index e7ee32ba..365acacf 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -64,6 +64,17 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);
+ ///
+ /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
+ ///
+ ///
+ /// Pointer to LLamaTokenDataArray
+ ///
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern void llama_sample_min_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);
+
+
///
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
///
diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal
index f4b46056..7c35f23a 100644
--- a/LLama/runtimes/ggml-metal.metal
+++ b/LLama/runtimes/ggml-metal.metal
@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]);
}
- const float max = simd_max(lmax);
+
+ float max = simd_max(lmax);
+ if (tiisg == 0) {
+ buf[sgitg] = max;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max = buf[0];
// parallel sum
float lsum = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(psrc0[i00] - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
- // whish to compute it twice.
+ // wish to compute it twice.
pdst[i00] = exp_psrc0;
}
- const float sum = simd_sum(lsum);
+ float sum = simd_sum(lsum);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] += buf[tpitg + i];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[0];
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] /= sum;
}
}
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
- float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+ for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]);
}
- float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
- const float max = simd_max(lmax);
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+ float max = simd_max(lmax);
+ if (tiisg == 0) {
+ buf[sgitg] = max;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max = buf[0];
// parallel sum
float4 lsum4 = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
- float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
- const float sum = simd_sum(lsum);
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+ float sum = simd_sum(lsum);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] += buf[tpitg + i];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[0];
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] /= sum;
}
}
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
- }
+ }
}
kernel void kernel_diag_mask_inf_8(
@@ -988,6 +1061,45 @@ kernel void kernel_alibi_f32(
}
}
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ thread float * cos_theta, thread float * sin_theta
+) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+ }
+ *cos_theta = cos(theta) * mscale;
+ *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+ return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+ // start and end correction dims
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
typedef void (rope_t)(
device const void * src0,
device const int32_t * src1,
@@ -1011,8 +1123,13 @@ typedef void (rope_t)(
constant int & n_past,
constant int & n_dims,
constant int & mode,
+ constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]);
@@ -1041,8 +1158,13 @@ kernel void kernel_rope(
constant int & n_past,
constant int & n_dims,
constant int & mode,
+ constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1052,19 +1174,22 @@ kernel void kernel_rope(
const bool is_neox = mode & 2;
+ float corr_dims[2];
+ rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
device const int32_t * pos = src1;
const int64_t p = pos[i2];
- const float theta_0 = freq_scale * (float)p;
+ const float theta_0 = (float)p;
const float inv_ndims = -1.f/n_dims;
if (!is_neox) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
- const float cos_theta = cos(theta);
- const float sin_theta = sin(theta);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -1079,9 +1204,12 @@ kernel void kernel_rope(
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
- const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
- const float cos_theta = cos(theta);
- const float sin_theta = sin(theta);
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ const float cur_rot = inv_ndims*ic - ib;
+
+ const float theta = theta_0 * pow(freq_base, cur_rot);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const int64_t i0 = ib*n_dims + ic/2;
diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib
index 3f36bb35..71849df2 100755
Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ