diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index 8cbf2f09..2f9caffd 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -1,5 +1,6 @@ using System.Text; using LLama.Abstractions; +using LLama.Native; namespace LLama.Web.Common { @@ -118,6 +119,24 @@ namespace LLama.Web.Common /// public float? RopeFrequencyScale { get; set; } + /// + public float? YarnExtrapolationFactor { get; set; } + + /// + public float? YarnAttentionFactor { get; set; } + + /// + public float? YarnBetaFast { get; set; } + + /// + public float? YarnBetaSlow { get; set; } + + /// + public uint? YarnOriginalContext { get; set; } + + /// + public RopeScalingType? YarnScalingType { get; set; } + /// /// Use experimental mul_mat_q kernels /// diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs index a2ac24f1..d9811cdc 100644 --- a/LLama/Abstractions/IContextParams.cs +++ b/LLama/Abstractions/IContextParams.cs @@ -1,4 +1,5 @@ using System.Text; +using LLama.Native; namespace LLama.Abstractions; @@ -67,4 +68,34 @@ public interface IContextParams /// Number of threads to use for batch processing (null = autodetect) (n_threads) /// uint? BatchThreads { get; set; } + + /// + /// YaRN extrapolation mix factor + /// + float? YarnExtrapolationFactor { get; set; } + + /// + /// YaRN magnitude scaling factor + /// + float? YarnAttentionFactor { get; set; } + + /// + /// YaRN low correction dim + /// + float? YarnBetaFast { get; set; } + + /// + /// YaRN high correction dim + /// + float? YarnBetaSlow { get; set; } + + /// + /// YaRN original context length + /// + uint? YarnOriginalContext { get; set; } + + /// + /// YaRN scaling method to use. + /// + RopeScalingType? YarnScalingType { get; set; } } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 9561e482..a736ccbd 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -3,6 +3,7 @@ using System; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; +using LLama.Native; namespace LLama.Common { @@ -68,6 +69,26 @@ namespace LLama.Common public float? RopeFrequencyScale { get; set; } /// + public float? YarnExtrapolationFactor { get; set; } + + /// + public float? YarnAttentionFactor { get; set; } + + /// + public float? YarnBetaFast { get; set; } + + /// + public float? YarnBetaSlow { get; set; } + + /// + public uint? YarnOriginalContext { get; set; } + + /// + public RopeScalingType? YarnScalingType { get; set; } + + /// + /// Use experimental mul_mat_q kernels + /// public bool MulMatQ { get; set; } /// diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs index ed59c9df..bb029c16 100644 --- a/LLama/Extensions/IContextParamsExtensions.cs +++ b/LLama/Extensions/IContextParamsExtensions.cs @@ -29,6 +29,15 @@ namespace LLama.Extensions result.embedding = @params.EmbeddingMode; result.rope_freq_base = @params.RopeFrequencyBase ?? 0; result.rope_freq_scale = @params.RopeFrequencyScale ?? 0; + + // Default YaRN values copied from here: https://github.com/ggerganov/llama.cpp/blob/381efbf480959bb6d1e247a8b0c2328f22e350f8/common/common.h#L67 + result.yarn_ext_factor = @params.YarnExtrapolationFactor ?? -1f; + result.yarn_attn_factor = @params.YarnAttentionFactor ?? 1f; + result.yarn_beta_fast = @params.YarnBetaFast ?? 32f; + result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f; + result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0; + result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED; + result.mul_mat_q = @params.MulMatQ; result.n_threads = Threads(@params.Threads); diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index 9a0b2a8e..c0f2afa2 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -44,13 +44,13 @@ namespace LLama.Native /// /// RoPE scaling type, from `enum llama_rope_scaling_type` /// - public sbyte rope_scaling_type; + public RopeScalingType rope_scaling_type; /// /// RoPE base frequency, 0 = from model /// - public float rope_freq_base; + public float rope_freq_base; /// /// RoPE frequency scaling factor, 0 = from model /// diff --git a/LLama/Native/RopeScalingType.cs b/LLama/Native/RopeScalingType.cs new file mode 100644 index 00000000..435932e8 --- /dev/null +++ b/LLama/Native/RopeScalingType.cs @@ -0,0 +1,17 @@ +namespace LLama.Native +{ + /// + /// RoPE scaling type. C# equivalent of llama_rope_scaling_type + /// + public enum RopeScalingType + : sbyte + { + LLAMA_ROPE_SCALING_UNSPECIFIED = -1, + + LLAMA_ROPE_SCALING_NONE = 0, + + LLAMA_ROPE_SCALING_LINEAR = 1, + + LLAMA_ROPE_SCALING_YARN = 2, + } +}