diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs
index 33774b13..9e6116ce 100644
--- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs
+++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs
@@ -8,7 +8,7 @@ namespace LLama.Examples.NewVersion
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
- var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
+ var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim();
var parameters = new ModelParams(modelPath)
{
@@ -50,7 +50,7 @@ namespace LLama.Examples.NewVersion
Console.ForegroundColor = ConsoleColor.White;
ex.Context.Dispose();
- ex = new(new LLamaContext(parameters));
+ ex = new(new LLamaContext(model, parameters));
session = new ChatSession(ex);
session.LoadSession(statePath);
diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index 93d82192..2cd1806f 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -29,7 +29,6 @@ namespace LLama.Unittest
Assert.Equal(32000, _model.VocabCount);
Assert.Equal(4096, _model.ContextSize);
Assert.Equal(4096, _model.EmbeddingSize);
- Assert.Equal(Encoding.UTF8, _model.Encoding);
}
}
}
\ No newline at end of file
diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index 6a181734..2edf3a62 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -10,7 +10,10 @@ namespace LLama.Unittest
public LLamaContextTests()
{
- var @params = new ModelParams(Constants.ModelPath);
+ var @params = new ModelParams(Constants.ModelPath)
+ {
+ ContextSize = 768,
+ };
_weights = LLamaWeights.LoadFromFile(@params);
_context = _weights.CreateContext(@params);
}
@@ -24,7 +27,7 @@ namespace LLama.Unittest
[Fact]
public void CheckProperties()
{
- Assert.Equal(4096, _context.ContextSize);
+ Assert.Equal(768, _context.ContextSize);
Assert.Equal(4096, _context.EmbeddingSize);
Assert.Equal(32000, _context.VocabCount);
}
diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs
index 413bda83..000f5853 100644
--- a/LLama.Unittest/ModelsParamsTests.cs
+++ b/LLama.Unittest/ModelsParamsTests.cs
@@ -13,7 +13,6 @@ namespace LLama.Unittest
{
BatchSize = 17,
ContextSize = 42,
- LoraAdapter = "adapter",
Seed = 42,
GpuLayerCount = 111
};
@@ -31,9 +30,13 @@ namespace LLama.Unittest
{
BatchSize = 17,
ContextSize = 42,
- LoraAdapter = "adapter",
Seed = 42,
- GpuLayerCount = 111
+ GpuLayerCount = 111,
+ LoraAdapters =
+ {
+ new("abc", 1),
+ new("def", 0)
+ }
};
var settings = new Newtonsoft.Json.JsonSerializerSettings();
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index c5b5c54b..2fd8558c 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -4,7 +4,7 @@ using LLama.Abstractions;
namespace LLama.Web.Common
{
public class ModelOptions
- : IModelParams
+ : ILLamaParams
{
public string Name { get; set; }
@@ -51,16 +51,11 @@ namespace LLama.Web.Common
/// Model path (model)
///
public string ModelPath { get; set; }
+
///
- /// model alias
- ///
- public string ModelAlias { get; set; } = "unknown";
- ///
- /// lora adapter path (lora_adapter)
+ /// List of LoRAs to apply
///
- public string LoraAdapter { get; set; } = string.Empty;
-
- public float LoraAdapterScale { get; set; } = 1;
+ public AdapterCollection LoraAdapters { get; set; } = new();
///
/// base model path for the lora adapter (lora_base)
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
new file mode 100644
index 00000000..a59512d3
--- /dev/null
+++ b/LLama/Abstractions/IContextParams.cs
@@ -0,0 +1,60 @@
+using System.Text;
+
+namespace LLama.Abstractions;
+
+/// <summary>
+/// The parameters for initializing a LLama context from a model.
+/// </summary>
+public interface IContextParams
+{
+ /// <summary>
+ /// Model context size (n_ctx)
+ /// </summary>
+ uint ContextSize { get; set; }
+
+ /// <summary>
+ /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+ /// </summary>
+ uint BatchSize { get; set; }
+
+ /// <summary>
+ /// Seed for the random number generator (seed)
+ /// </summary>
+ uint Seed { get; set; }
+
+ /// <summary>
+ /// Use f16 instead of f32 for memory kv (memory_f16)
+ /// </summary>
+ bool UseFp16Memory { get; set; }
+
+ /// <summary>
+ /// Compute perplexity over the prompt (perplexity)
+ /// </summary>
+ bool Perplexity { get; set; }
+
+ /// <summary>
+ /// Whether to use embedding mode. (embedding) Note that if this is set to true,
+ /// The LLamaModel won't produce text response anymore.
+ /// </summary>
+ bool EmbeddingMode { get; set; }
+
+ /// <summary>
+ /// RoPE base frequency
+ /// </summary>
+ float RopeFrequencyBase { get; set; }
+
+ /// <summary>
+ /// RoPE frequency scaling factor
+ /// </summary>
+ float RopeFrequencyScale { get; set; }
+
+ /// <summary>
+ /// Use experimental mul_mat_q kernels
+ /// </summary>
+ bool MulMatQ { get; set; }
+
+ /// <summary>
+ /// The encoding to use for models
+ /// </summary>
+ Encoding Encoding { get; set; }
+}
\ No newline at end of file
diff --git a/LLama/Abstractions/ILLamaParams.cs b/LLama/Abstractions/ILLamaParams.cs
new file mode 100644
index 00000000..636ba199
--- /dev/null
+++ b/LLama/Abstractions/ILLamaParams.cs
@@ -0,0 +1,11 @@
+namespace LLama.Abstractions
+{
+ /// <summary>
+ /// Convenience interface for implementing both types of parameters.
+ /// </summary>
+ /// <remarks>Mostly exists for backwards compatibility reasons, when these two were not split.</remarks>
+ public interface ILLamaParams
+ : IModelParams, IContextParams
+ {
+ }
+}
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 168654c4..31304acb 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -1,4 +1,6 @@
-using System.Text;
+using System;
+using System.Collections.Generic;
+using System.Linq;
namespace LLama.Abstractions
{
@@ -7,36 +9,16 @@ namespace LLama.Abstractions
///
public interface IModelParams
{
- ///
- /// Model context size (n_ctx)
- ///
- uint ContextSize { get; set; }
-
///
/// the GPU that is used for scratch and small tensors
///
int MainGpu { get; set; }
- ///
- /// if true, reduce VRAM usage at the cost of performance
- ///
- bool LowVram { get; set; }
-
///
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
///
int GpuLayerCount { get; set; }
- ///
- /// Seed for the random number generator (seed)
- ///
- uint Seed { get; set; }
-
- ///
- /// Use f16 instead of f32 for memory kv (memory_f16)
- ///
- bool UseFp16Memory { get; set; }
-
///
/// Use mmap for faster loads (use_mmap)
///
@@ -47,72 +29,78 @@ namespace LLama.Abstractions
///
bool UseMemoryLock { get; set; }
- ///
- /// Compute perplexity over the prompt (perplexity)
- ///
- bool Perplexity { get; set; }
-
///
/// Model path (model)
///
string ModelPath { get; set; }
- ///
- /// lora adapter path (lora_adapter)
- ///
- string LoraAdapter { get; set; }
-
- float LoraAdapterScale { get; set; }
-
- ///
- /// base model path for the lora adapter (lora_base)
- ///
- string LoraBase { get; set; }
-
///
/// Number of threads (-1 = autodetect) (n_threads)
///
int Threads { get; set; }
- ///
- /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
- ///
- uint BatchSize { get; set; }
-
- ///
- /// Whether to use embedding mode. (embedding) Note that if this is set to true,
- /// The LLamaModel won't produce text response anymore.
- ///
- bool EmbeddingMode { get; set; }
-
///
/// how split tensors should be distributed across GPUs
///
float[]? TensorSplits { get; set; }
///
- /// RoPE base frequency
+ /// Load vocab only (no weights)
///
- float RopeFrequencyBase { get; set; }
+ bool VocabOnly { get; set; }
///
- /// RoPE frequency scaling factor
+ /// List of LoRA adapters to apply
///
- float RopeFrequencyScale { get; set; }
+ AdapterCollection LoraAdapters { get; }
///
- /// Use experimental mul_mat_q kernels
+ /// base model path for the lora adapter (lora_base)
///
- bool MulMatQ { get; set; }
+ string LoraBase { get; set; }
+ }
- ///
- /// The encoding to use for models
- ///
- Encoding Encoding { get; set; }
+ /// <summary>
+ /// A LoRA adapter to apply to a model
+ /// </summary>
+ /// <param name="Path">Path to the LoRA file</param>
+ /// <param name="Scale">Strength of this LoRA</param>
+ public readonly record struct LoraAdapter(string Path, float Scale);
- ///
- /// Load vocab only (no weights)
- ///
- bool VocabOnly { get; set; }
+ /// <summary>
+ /// A list of LoraAdapter objects which compares by content (same adapters, same order) instead of by reference
+ /// </summary>
+ public sealed class AdapterCollection
+     : List<LoraAdapter>, IEquatable<AdapterCollection>
+ {
+     /// <summary>Two collections are equal when they hold equal adapters in the same order</summary>
+     public bool Equals(AdapterCollection? other)
+     {
+         if (other is null)
+             return false;
+
+         return ReferenceEquals(this, other) || this.SequenceEqual(other);
+     }
+
+     /// <summary>Equal only to another AdapterCollection with the same contents; any other object (or null) is unequal</summary>
+     public override bool Equals(object? obj)
+     {
+         return Equals(obj as AdapterCollection);
+     }
+
+     /// <summary>Order-sensitive hash of the current contents; the value changes whenever the list is mutated</summary>
+     public override int GetHashCode()
+     {
+         unchecked
+         {
+             var hash = 17;
+             for (var i = 0; i < Count; i++)
+             {
+                 hash += this[i].GetHashCode();
+                 hash *= 7823;
+             }
+             return hash;
+         }
+     }
+ }
}
\ No newline at end of file
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index a0d1688a..09b5e4af 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -1,5 +1,6 @@
using LLama.Abstractions;
using System;
+using System.Collections.Generic;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -10,7 +11,7 @@ namespace LLama.Common
/// The parameters for initializing a LLama model.
///
public record ModelParams
- : IModelParams
+ : ILLamaParams
{
///
/// Model context size (n_ctx)
@@ -20,10 +21,7 @@ namespace LLama.Common
/// the GPU that is used for scratch and small tensors
///
public int MainGpu { get; set; } = 0;
- ///
- /// if true, reduce VRAM usage at the cost of performance
- ///
- public bool LowVram { get; set; } = false;
+
///
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
///
@@ -52,17 +50,17 @@ namespace LLama.Common
/// Model path (model)
///
public string ModelPath { get; set; }
+
///
- /// lora adapter path (lora_adapter)
+ /// List of LoRAs to apply
///
- public string LoraAdapter { get; set; } = string.Empty;
-
- public float LoraAdapterScale { get; set; } = 1;
+ public AdapterCollection LoraAdapters { get; set; } = new();
///
/// base model path for the lora adapter (lora_base)
///
public string LoraBase { get; set; } = string.Empty;
+
///
/// Number of threads (-1 = autodetect) (n_threads)
///
@@ -162,7 +160,6 @@ namespace LLama.Common
UseMemoryLock = useMemoryLock;
Perplexity = perplexity;
ModelPath = modelPath;
- LoraAdapter = loraAdapter;
LoraBase = loraBase;
Threads = threads == -1 ? Math.Max(Environment.ProcessorCount / 2, 1) : threads;
BatchSize = batchSize;
@@ -171,6 +168,7 @@ namespace LLama.Common
RopeFrequencyScale = ropeFrequencyScale;
MulMatQ = mulMatQ;
Encoding = Encoding.GetEncoding(encoding);
+ if (!string.IsNullOrEmpty(loraAdapter)) LoraAdapters.Add(new LoraAdapter(loraAdapter, 1)); // no legacy adapter given: keep the list empty instead of storing a no-op entry
}
}
diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index 1bf19958..9be239df 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -19,7 +19,7 @@ namespace LLama.Extensions
///
///
///
- public static void ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
+ public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
{
result = NativeApi.llama_context_default_params();
result.n_ctx = @params.ContextSize;
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 9fef6af5..e6222dcf 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -42,9 +42,9 @@ namespace LLama
public int EmbeddingSize => _ctx.EmbeddingSize;
///
- /// The model params set for this model.
+ /// The context params set for this context
///
- public IModelParams Params { get; set; }
+ public IContextParams Params { get; set; }
///
/// The native handle, which is used to be passed to the native APIs
@@ -57,24 +57,7 @@ namespace LLama
///
public Encoding Encoding => _encoding;
- ///
- ///
- ///
- /// Model params.
- /// The logger.
- [Obsolete("Use the LLamaWeights.CreateContext instead")]
- public LLamaContext(IModelParams @params, ILogger? logger = null)
- {
- Params = @params;
-
- _logger = logger;
- _encoding = @params.Encoding;
-
- _logger?.LogInformation($"[LLamaContext] Initializing LLama model with params: {this.Params}");
- _ctx = Utils.InitLLamaContextFromModelParams(Params);
- }
-
- internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILogger? logger = null)
+ internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;
@@ -90,7 +73,7 @@ namespace LLama
///
///
///
- public LLamaContext(LLamaWeights model, IModelParams @params, ILogger? logger = null)
+ public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger = null)
{
if (model.NativeHandle.IsClosed)
throw new ObjectDisposedException("Cannot create context, model weights have been disposed");
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 64c17539..54ef07b0 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -18,19 +18,22 @@ namespace LLama
///
public int EmbeddingSize => _ctx.EmbeddingSize;
- ///
- ///
- ///
- ///
- public LLamaEmbedder(IModelParams @params)
+ public LLamaEmbedder(ILLamaParams allParams)
+ : this(allParams, allParams)
{
- @params.EmbeddingMode = true;
- using var weights = LLamaWeights.LoadFromFile(@params);
- _ctx = weights.CreateContext(@params);
}
- public LLamaEmbedder(LLamaWeights weights, IModelParams @params)
+ public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
{
+ using var weights = LLamaWeights.LoadFromFile(modelParams);
+
+ contextParams.EmbeddingMode = true;
+ _ctx = weights.CreateContext(contextParams);
+ }
+
+ public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
+ {
+ @params.EmbeddingMode = true;
_ctx = weights.CreateContext(@params);
}
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 6854a1f6..ad47541e 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -20,7 +20,7 @@ namespace LLama
: ILLamaExecutor
{
private readonly LLamaWeights _weights;
- private readonly IModelParams _params;
+ private readonly IContextParams _params;
///
/// The context used by the executor when running the inference.
@@ -32,7 +32,7 @@ namespace LLama
///
///
///
- public StatelessExecutor(LLamaWeights weights, IModelParams @params)
+ public StatelessExecutor(LLamaWeights weights, IContextParams @params)
{
_weights = weights;
_params = @params;
@@ -41,20 +41,6 @@ namespace LLama
Context.Dispose();
}
- ///
- /// Create a new stateless executor which will use the model used to create the given context
- ///
- ///
- [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
- public StatelessExecutor(LLamaContext context)
- {
- _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
- _params = context.Params;
-
- Context = _weights.CreateContext(_params);
- Context.Dispose();
- }
-
///
public async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index e59f2990..bcc41afb 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -1,5 +1,4 @@
using System;
-using System.Text;
using LLama.Abstractions;
using LLama.Extensions;
using LLama.Native;
@@ -20,11 +19,6 @@ namespace LLama
/// Be careful how you use this!
public SafeLlamaModelHandle NativeHandle => _weights;
- ///
- /// Encoding to use to convert text into bytes for the model
- ///
- public Encoding Encoding { get; }
-
///
/// Total number of tokens in vocabulary of this model
///
@@ -50,10 +44,9 @@ namespace LLama
///
public int EmbeddingSize => NativeHandle.EmbeddingSize;
- internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding)
+ internal LLamaWeights(SafeLlamaModelHandle weights)
{
_weights = weights;
- Encoding = encoding;
}
///
@@ -66,10 +59,17 @@ namespace LLama
using var pin = @params.ToLlamaModelParams(out var lparams);
var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
- if (!string.IsNullOrEmpty(@params.LoraAdapter))
- weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraAdapterScale, @params.LoraBase, @params.Threads);
+ foreach (var adapter in @params.LoraAdapters)
+ {
+ if (string.IsNullOrEmpty(adapter.Path))
+ continue;
+ if (adapter.Scale <= 0)
+ continue;
+
+ weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads);
+ }
- return new LLamaWeights(weights, @params.Encoding);
+ return new LLamaWeights(weights);
}
///
@@ -83,7 +83,7 @@ namespace LLama
///
///
///
- public LLamaContext CreateContext(IModelParams @params)
+ public LLamaContext CreateContext(IContextParams @params)
{
return new LLamaContext(this, @params);
}
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
deleted file mode 100644
index d08501c0..00000000
--- a/LLama/Utils.cs
+++ /dev/null
@@ -1,108 +0,0 @@
-using LLama.Abstractions;
-using LLama.Native;
-using System;
-using System.Collections.Generic;
-using System.Runtime.InteropServices;
-using System.Text;
-using LLama.Extensions;
-
-namespace LLama
-{
- using llama_token = Int32;
-
- ///
- /// Assorted llama utilities
- ///
- public static class Utils
- {
- [Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")]
- #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
- public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
- #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
- {
- using var weights = LLamaWeights.LoadFromFile(@params);
-
- @params.ToLlamaContextParams(out var lparams);
- return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams);
- }
-
- [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
- #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
- public static IEnumerable Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
- #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
- {
- return ctx.Tokenize(text, add_bos, encoding);
- }
-
- [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
- #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
- public static Span GetLogits(SafeLLamaContextHandle ctx, int length)
- #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
- {
- if (length != ctx.VocabCount)
- throw new ArgumentException("length must be the VocabSize");
-
- return ctx.GetLogits();
- }
-
- [Obsolete("Use SafeLLamaContextHandle Eval method instead")]
- #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
- public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past)
- #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
- {
- var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
- return ctx.Eval(slice, n_past) ? 0 : 1;
- }
-
- [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]
- #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
- public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
- #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
- {
- return ctx.TokenToString(token, encoding);
- }
-
- [Obsolete("No longer used internally by LlamaSharp")]
- #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
- public static string PtrToString(IntPtr ptr, Encoding encoding)
- #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
- {
-#if NET6_0_OR_GREATER
- // ReSharper disable once PossibleUnintendedReferenceComparison
- if(encoding == Encoding.UTF8)
- {
- return Marshal.PtrToStringUTF8(ptr)!;
- }
- // ReSharper disable once PossibleUnintendedReferenceComparison
- else if(encoding == Encoding.Unicode)
- {
- return Marshal.PtrToStringUni(ptr)!;
- }
- else
- {
- return Marshal.PtrToStringAuto(ptr)!;
- }
-#else
- unsafe
- {
- byte* tp = (byte*)ptr.ToPointer();
- List bytes = new();
- while (true)
- {
- byte c = *tp++;
- if (c == '\0')
- {
- break;
- }
- else
- {
- bytes.Add(c);
- }
- }
- return encoding.GetString(bytes.ToArray());
- }
-#endif
- }
-
- }
-}