diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 8b4e8497..1ec7022f 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -34,11 +34,6 @@ namespace LLama.Abstractions
///
string ModelPath { get; set; }
- ///
- /// Number of threads (-1 = autodetect) (n_threads)
- ///
- uint? Threads { get; set; }
-
///
/// how split tensors should be distributed across GPUs
///
diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs
index 97a4d6ee..d4577a47 100644
--- a/LLama/Common/FixedSizeQueue.cs
+++ b/LLama/Common/FixedSizeQueue.cs
@@ -12,7 +12,6 @@ namespace LLama.Common
public class FixedSizeQueue<T>
: IEnumerable<T>
{
- private readonly int _maxSize;
private readonly List<T> _storage;
internal IReadOnlyList<T> Items => _storage;
@@ -25,7 +24,7 @@ namespace LLama.Common
///
/// Maximum number of items allowed in this queue
///
- public int Capacity => _maxSize;
+ public int Capacity { get; }
///
/// Create a new queue
@@ -33,7 +32,7 @@ namespace LLama.Common
/// the maximum number of items to store in this queue
public FixedSizeQueue(int size)
{
- _maxSize = size;
+ Capacity = size;
_storage = new();
}
@@ -52,11 +51,11 @@ namespace LLama.Common
#endif
// Size of "data" is unknown, copy it all into a list
- _maxSize = size;
+ Capacity = size;
+ _storage = new List<T>(data);
// Now check if that list is a valid size.
- if (_storage.Count > _maxSize)
+ if (_storage.Count > Capacity)
throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
}
@@ -81,7 +80,7 @@ namespace LLama.Common
public void Enqueue(T item)
{
_storage.Add(item);
- if(_storage.Count >= _maxSize)
+ if(_storage.Count >= Capacity)
{
_storage.RemoveAt(0);
}
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index ed877853..998d4ec4 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -40,11 +40,11 @@ namespace LLama.Common
///
/// Use mlock to keep model in memory (use_mlock)
///
- public bool UseMemoryLock { get; set; } = false;
+ public bool UseMemoryLock { get; set; }
///
/// Compute perplexity over the prompt (perplexity)
///
- public bool Perplexity { get; set; } = false;
+ public bool Perplexity { get; set; }
///
/// Model path (model)
///
@@ -79,7 +79,7 @@ namespace LLama.Common
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// The LLamaModel won't produce text response anymore.
///
- public bool EmbeddingMode { get; set; } = false;
+ public bool EmbeddingMode { get; set; }
///
/// how split tensors should be distributed across GPUs
diff --git a/LLama/Exceptions/GrammarFormatExceptions.cs b/LLama/Exceptions/GrammarFormatExceptions.cs
index 62b75224..4337e505 100644
--- a/LLama/Exceptions/GrammarFormatExceptions.cs
+++ b/LLama/Exceptions/GrammarFormatExceptions.cs
@@ -58,7 +58,7 @@ public class GrammarUnexpectedEndOfInput
: GrammarFormatException
{
internal GrammarUnexpectedEndOfInput()
- : base($"Unexpected end of input")
+ : base("Unexpected end of input")
{
}
}
diff --git a/LLama/Exceptions/RuntimeError.cs b/LLama/Exceptions/RuntimeError.cs
index 6b839ff0..a8ea0531 100644
--- a/LLama/Exceptions/RuntimeError.cs
+++ b/LLama/Exceptions/RuntimeError.cs
@@ -1,19 +1,20 @@
using System;
-namespace LLama.Exceptions
+namespace LLama.Exceptions;
+
+///
+/// Base class for LLamaSharp runtime errors (i.e. errors produced by llama.cpp, converted into exceptions)
+///
+public class RuntimeError
+ : Exception
{
- public class RuntimeError
- : Exception
+ ///
+ /// Create a new RuntimeError
+ ///
+ ///
+ public RuntimeError(string message)
+ : base(message)
{
- public RuntimeError()
- {
-
- }
-
- public RuntimeError(string message)
- : base(message)
- {
- }
}
-}
+}
\ No newline at end of file
diff --git a/LLama/Extensions/EncodingExtensions.cs b/LLama/Extensions/EncodingExtensions.cs
index c0a381c5..e88d83a7 100644
--- a/LLama/Extensions/EncodingExtensions.cs
+++ b/LLama/Extensions/EncodingExtensions.cs
@@ -1,5 +1,4 @@
using System;
-using System.Collections.Generic;
using System.Text;
namespace LLama.Extensions;
diff --git a/LLama/Extensions/KeyValuePairExtensions.cs b/LLama/Extensions/KeyValuePairExtensions.cs
index bf48f3e8..6e12654d 100644
--- a/LLama/Extensions/KeyValuePairExtensions.cs
+++ b/LLama/Extensions/KeyValuePairExtensions.cs
@@ -1,26 +1,23 @@
-using System.Collections.Generic;
+namespace LLama.Extensions;
-namespace LLama.Extensions
+///
+/// Extensions to the KeyValuePair struct
+///
+internal static class KeyValuePairExtensions
{
+#if NETSTANDARD2_0
///
- /// Extensions to the KeyValuePair struct
+ /// Deconstruct a KeyValuePair into it's constituent parts.
///
- internal static class KeyValuePairExtensions
+ /// The KeyValuePair to deconstruct
+ /// First element, the Key
+ /// Second element, the Value
+ /// Type of the Key
+ /// Type of the Value
+ public static void Deconstruct<TKey, TValue>(this System.Collections.Generic.KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
{
-#if NETSTANDARD2_0
- ///
- /// Deconstruct a KeyValuePair into it's constituent parts.
- ///
- /// The KeyValuePair to deconstruct
- /// First element, the Key
- /// Second element, the Value
- /// Type of the Key
- /// Type of the Value
- public static void Deconstruct<TKey, TValue>(this KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
- {
- first = pair.Key;
- second = pair.Value;
- }
-#endif
+ first = pair.Key;
+ second = pair.Value;
}
-}
+#endif
+}
\ No newline at end of file
diff --git a/LLama/Grammars/GBNFGrammarParser.cs b/LLama/Grammars/GBNFGrammarParser.cs
index 0f36edf8..ac915cf2 100644
--- a/LLama/Grammars/GBNFGrammarParser.cs
+++ b/LLama/Grammars/GBNFGrammarParser.cs
@@ -17,7 +17,7 @@ namespace LLama.Grammars
{
// NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp
- private uint DecodeUTF8(ref ReadOnlySpan<byte> src)
+ private static uint DecodeUTF8(ref ReadOnlySpan<byte> src)
{
int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -40,46 +40,12 @@ namespace LLama.Grammars
return value;
}
- private uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
- {
- uint nextId = (uint)state.SymbolIds.Count;
- string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
-
- if (state.SymbolIds.TryGetValue(key, out uint existingId))
- {
- return existingId;
- }
- else
- {
- state.SymbolIds[key] = nextId;
- return nextId;
- }
- }
-
- private uint GenerateSymbolId(ParseState state, string baseName)
- {
- uint nextId = (uint)state.SymbolIds.Count;
- string key = $"{baseName}_{nextId}";
- state.SymbolIds[key] = nextId;
- return nextId;
- }
-
- private void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
- {
- while (state.Rules.Count <= ruleId)
- {
- state.Rules.Add(new List<LLamaGrammarElement>());
- }
-
- state.Rules[(int)ruleId] = rule;
- }
-
- private bool IsWordChar(byte c)
+ private static bool IsWordChar(byte c)
{
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
}
- private uint ParseHex(ref ReadOnlySpan<byte> src, int size)
+ private static uint ParseHex(ref ReadOnlySpan<byte> src, int size)
{
int pos = 0;
int end = size;
@@ -115,7 +81,7 @@ namespace LLama.Grammars
return value;
}
- private ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
+ private static ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
{
int pos = 0;
while (pos < src.Length &&
@@ -137,7 +103,7 @@ namespace LLama.Grammars
return src.Slice(pos);
}
- private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
+ private static ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
{
int pos = 0;
while (pos < src.Length && IsWordChar(src[pos]))
@@ -151,7 +117,7 @@ namespace LLama.Grammars
return src.Slice(pos);
}
- private uint ParseChar(ref ReadOnlySpan<byte> src)
+ private static uint ParseChar(ref ReadOnlySpan<byte> src)
{
if (src[0] == '\\')
{
@@ -235,7 +201,7 @@ namespace LLama.Grammars
else if (IsWordChar(pos[0])) // rule reference
{
var nameEnd = ParseName(pos);
- uint refRuleId = GetSymbolId(state, pos, nameEnd.Length);
+ uint refRuleId = state.GetSymbolId(pos, nameEnd.Length);
pos = ParseSpace(nameEnd, isNested);
lastSymStart = outElements.Count;
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
@@ -244,7 +210,7 @@ namespace LLama.Grammars
{
// parse nested alternates into synthesized rule
pos = ParseSpace(pos.Slice(1), true);
- uint subRuleId = GenerateSymbolId(state, ruleName);
+ uint subRuleId = state.GenerateSymbolId(ruleName);
pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
lastSymStart = outElements.Count;
// output reference to synthesized rule
@@ -263,7 +229,7 @@ namespace LLama.Grammars
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
- uint subRuleId = GenerateSymbolId(state, ruleName);
+ uint subRuleId = state.GenerateSymbolId(ruleName);
List<LLamaGrammarElement> subRule = new List<LLamaGrammarElement>();
@@ -287,7 +253,7 @@ namespace LLama.Grammars
subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
- AddRule(state, subRuleId, subRule);
+ state.AddRule(subRuleId, subRule);
// in original rule, replace previous symbol with reference to generated rule
outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
@@ -323,7 +289,7 @@ namespace LLama.Grammars
}
rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
- AddRule(state, ruleId, rule);
+ state.AddRule(ruleId, rule);
return pos;
}
@@ -333,7 +299,7 @@ namespace LLama.Grammars
ReadOnlySpan<byte> nameEnd = ParseName(src);
ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
int nameLen = src.Length - nameEnd.Length;
- uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0);
+ uint ruleId = state.GetSymbolId(src.Slice(0, nameLen), 0);
string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray());
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '='))
@@ -393,6 +359,40 @@ namespace LLama.Grammars
{
public SortedDictionary<string, uint> SymbolIds { get; } = new();
public List<List<LLamaGrammarElement>> Rules { get; } = new();
+
+ public uint GetSymbolId(ReadOnlySpan<byte> src, int len)
+ {
+ var nextId = (uint)SymbolIds.Count;
+ var key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
+
+ if (SymbolIds.TryGetValue(key, out uint existingId))
+ {
+ return existingId;
+ }
+ else
+ {
+ SymbolIds[key] = nextId;
+ return nextId;
+ }
+ }
+
+ public uint GenerateSymbolId(string baseName)
+ {
+ var nextId = (uint)SymbolIds.Count;
+ var key = $"{baseName}_{nextId}";
+ SymbolIds[key] = nextId;
+ return nextId;
+ }
+
+ public void AddRule(uint ruleId, List<LLamaGrammarElement> rule)
+ {
+ while (Rules.Count <= ruleId)
+ {
+ Rules.Add(new List<LLamaGrammarElement>());
+ }
+
+ Rules[(int)ruleId] = rule;
+ }
}
}
}
diff --git a/LLama/Grammars/Grammar.cs b/LLama/Grammars/Grammar.cs
index dbb3658e..5135e341 100644
--- a/LLama/Grammars/Grammar.cs
+++ b/LLama/Grammars/Grammar.cs
@@ -112,7 +112,6 @@ namespace LLama.Grammars
case LLamaGrammarElementType.CHAR_ALT:
PrintGrammarChar(output, elem.Value);
break;
-
}
if (elem.IsCharElement())
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 5a9f4893..a190c075 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -23,23 +23,21 @@ namespace LLama
: IDisposable
{
private readonly ILogger? _logger;
- private readonly Encoding _encoding;
- private readonly SafeLLamaContextHandle _ctx;
///
/// Total number of tokens in vocabulary of this model
///
- public int VocabCount => _ctx.VocabCount;
+ public int VocabCount => NativeHandle.VocabCount;
///
/// Total number of tokens in the context
///
- public int ContextSize => _ctx.ContextSize;
+ public int ContextSize => NativeHandle.ContextSize;
///
/// Dimension of embedding vectors
///
- public int EmbeddingSize => _ctx.EmbeddingSize;
+ public int EmbeddingSize => NativeHandle.EmbeddingSize;
///
/// The context params set for this context
@@ -50,20 +48,20 @@ namespace LLama
/// The native handle, which is used to be passed to the native APIs
///
/// Be careful how you use this!
- public SafeLLamaContextHandle NativeHandle => _ctx;
+ public SafeLLamaContextHandle NativeHandle { get; }
///
/// The encoding set for this model to deal with text input.
///
- public Encoding Encoding => _encoding;
+ public Encoding Encoding { get; }
internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;
_logger = logger;
- _encoding = @params.Encoding;
- _ctx = nativeContext;
+ Encoding = @params.Encoding;
+ NativeHandle = nativeContext;
}
///
@@ -81,10 +79,10 @@ namespace LLama
Params = @params;
_logger = logger;
- _encoding = @params.Encoding;
+ Encoding = @params.Encoding;
@params.ToLlamaContextParams(out var lparams);
- _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
+ NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
}
///
@@ -96,7 +94,7 @@ namespace LLama
///
public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
{
- return _ctx.Tokenize(text, addBos, special, _encoding);
+ return NativeHandle.Tokenize(text, addBos, special, Encoding);
}
///
@@ -108,7 +106,7 @@ namespace LLama
{
var sb = new StringBuilder();
foreach (var token in tokens)
- _ctx.TokenToString(token, _encoding, sb);
+ NativeHandle.TokenToString(token, Encoding, sb);
return sb.ToString();
}
@@ -124,7 +122,7 @@ namespace LLama
File.Delete(filename);
// Estimate size of state to write to disk, this is always equal to or greater than the actual size
- var estimatedStateSize = (long)NativeApi.llama_get_state_size(_ctx);
+ var estimatedStateSize = (long)NativeApi.llama_get_state_size(NativeHandle);
// Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
long writtenBytes;
@@ -135,7 +133,7 @@ namespace LLama
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
- writtenBytes = (long)NativeApi.llama_copy_state_data(_ctx, ptr);
+ writtenBytes = (long)NativeApi.llama_copy_state_data(NativeHandle, ptr);
view.SafeMemoryMappedViewHandle.ReleasePointer();
}
}
@@ -151,14 +149,14 @@ namespace LLama
///
public State GetState()
{
- var stateSize = _ctx.GetStateSize();
+ var stateSize = NativeHandle.GetStateSize();
// Allocate a chunk of memory large enough to hold the entire state
var memory = Marshal.AllocHGlobal((nint)stateSize);
try
{
// Copy the state data into memory, discover the actual size required
- var actualSize = _ctx.GetState(memory, stateSize);
+ var actualSize = NativeHandle.GetState(memory, stateSize);
// Shrink to size
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
@@ -193,7 +191,7 @@ namespace LLama
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
- NativeApi.llama_set_state_data(_ctx, ptr);
+ NativeApi.llama_set_state_data(NativeHandle, ptr);
view.SafeMemoryMappedViewHandle.ReleasePointer();
}
}
@@ -208,7 +206,7 @@ namespace LLama
{
unsafe
{
- _ctx.SetState((byte*)state.DangerousGetHandle().ToPointer());
+ NativeHandle.SetState((byte*)state.DangerousGetHandle().ToPointer());
}
}
@@ -235,13 +233,13 @@ namespace LLama
if (grammar != null)
{
- SamplingApi.llama_sample_grammar(_ctx, candidates, grammar);
+ SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
}
if (temperature <= 0)
{
// Greedy sampling
- id = SamplingApi.llama_sample_token_greedy(_ctx, candidates);
+ id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
}
else
{
@@ -250,23 +248,23 @@ namespace LLama
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+ SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+ id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
+ SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+ id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
}
else
{
// Temperature sampling
- SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
- SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
- SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
- SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
- SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
- id = SamplingApi.llama_sample_token(_ctx, candidates);
+ SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
+ SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
+ SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
+ SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
+ SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+ id = SamplingApi.llama_sample_token(NativeHandle, candidates);
}
}
mirostat_mu = mu;
@@ -274,7 +272,7 @@ namespace LLama
if (grammar != null)
{
- NativeApi.llama_grammar_accept_token(_ctx, grammar, id);
+ NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
}
return id;
@@ -295,7 +293,7 @@ namespace LLama
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
- var logits = _ctx.GetLogits();
+ var logits = NativeHandle.GetLogits();
// Apply params.logit_bias map
if (logitBias is not null)
@@ -305,7 +303,7 @@ namespace LLama
}
// Save the newline logit value
- var nl_token = NativeApi.llama_token_nl(_ctx);
+ var nl_token = NativeApi.llama_token_nl(NativeHandle);
var nl_logit = logits[nl_token];
// Convert logits into token candidates
@@ -316,8 +314,8 @@ namespace LLama
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();
// Apply penalties to candidates
- SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty);
- SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence);
+ SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
+ SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);
// Restore newline token logit value if necessary
if (!penalizeNL)
@@ -408,9 +406,9 @@ namespace LLama
n_eval = (int)Params.BatchSize;
}
- if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount))
+ if (!NativeHandle.Eval(tokens.Slice(i, n_eval), pastTokensCount))
{
- _logger?.LogError($"[LLamaContext] Failed to eval.");
+ _logger?.LogError("[LLamaContext] Failed to eval.");
throw new RuntimeError("Failed to eval.");
}
@@ -443,7 +441,7 @@ namespace LLama
///
public void Dispose()
{
- _ctx.Dispose();
+ NativeHandle.Dispose();
}
///
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 54ef07b0..fde901b1 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -18,11 +18,22 @@ namespace LLama
///
public int EmbeddingSize => _ctx.EmbeddingSize;
+ ///
+ /// Create a new embedder (loading temporary weights)
+ ///
+ ///
+ [Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
public LLamaEmbedder(ILLamaParams allParams)
: this(allParams, allParams)
{
}
+ ///
+ /// Create a new embedder (loading temporary weights)
+ ///
+ ///
+ ///
+ [Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
{
using var weights = LLamaWeights.LoadFromFile(modelParams);
@@ -31,6 +42,11 @@ namespace LLama
_ctx = weights.CreateContext(contextParams);
}
+ ///
+ /// Create a new embedder, using the given LLamaWeights
+ ///
+ ///
+ ///
public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
{
@params.EmbeddingMode = true;
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index df972e47..242ae10b 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -114,7 +114,7 @@ namespace LLama
}
else
{
- _logger?.LogWarning($"[LLamaExecutor] Session file does not exist, will create");
+ _logger?.LogWarning("[LLamaExecutor] Session file does not exist, will create");
}
_n_matching_session_tokens = 0;
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 6faa3db2..dd84f218 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -18,10 +18,10 @@ namespace LLama
///
public class InstructExecutor : StatefulExecutorBase
{
- bool _is_prompt_run = true;
- string _instructionPrefix;
- llama_token[] _inp_pfx;
- llama_token[] _inp_sfx;
+ private bool _is_prompt_run = true;
+ private readonly string _instructionPrefix;
+ private llama_token[] _inp_pfx;
+ private llama_token[] _inp_sfx;
///
///
diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs
index f1d89586..40632724 100644
--- a/LLama/LLamaQuantizer.cs
+++ b/LLama/LLamaQuantizer.cs
@@ -80,6 +80,7 @@ namespace LLama
return true;
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
+ case LLamaFtype.LLAMA_FTYPE_GUESSED:
default:
return false;
}
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index bcc41afb..a5e2adca 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -11,13 +11,11 @@ namespace LLama
public sealed class LLamaWeights
: IDisposable
{
- private readonly SafeLlamaModelHandle _weights;
-
///
/// The native handle, which is used in the native APIs
///
/// Be careful how you use this!
- public SafeLlamaModelHandle NativeHandle => _weights;
+ public SafeLlamaModelHandle NativeHandle { get; }
///
/// Total number of tokens in vocabulary of this model
@@ -46,7 +44,7 @@ namespace LLama
internal LLamaWeights(SafeLlamaModelHandle weights)
{
- _weights = weights;
+ NativeHandle = weights;
}
///
@@ -66,7 +64,7 @@ namespace LLama
if (adapter.Scale <= 0)
continue;
- weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads);
+ weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
}
return new LLamaWeights(weights);
@@ -75,7 +73,7 @@ namespace LLama
///
public void Dispose()
{
- _weights.Dispose();
+ NativeHandle.Dispose();
}
///
diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs
index fd72ddcd..1f7ef2c5 100644
--- a/LLama/Native/LLamaBatchSafeHandle.cs
+++ b/LLama/Native/LLamaBatchSafeHandle.cs
@@ -4,11 +4,18 @@ namespace LLama.Native;
using llama_token = Int32;
+///
+/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
+///
public sealed class LLamaBatchSafeHandle
: SafeLLamaHandleBase
{
private readonly int _embd;
- public LLamaNativeBatch Batch { get; private set; }
+
+ ///
+ /// Get the native llama_batch struct
+ ///
+ public LLamaNativeBatch NativeBatch { get; private set; }
///
/// the token ids of the input (used when embd is NULL)
@@ -22,7 +29,7 @@ public sealed class LLamaBatchSafeHandle
if (_embd != 0)
return new Span(null, 0);
else
- return new Span(Batch.token, Batch.n_tokens);
+ return new Span(NativeBatch.token, NativeBatch.n_tokens);
}
}
}
@@ -37,10 +44,10 @@ public sealed class LLamaBatchSafeHandle
unsafe
{
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
- /// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
+ // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
if (_embd != 0)
- return new Span(Batch.embd, Batch.n_tokens * _embd);
+ return new Span(NativeBatch.embd, NativeBatch.n_tokens * _embd);
else
return new Span(null, 0);
}
@@ -56,7 +63,7 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
- return new Span(Batch.pos, Batch.n_tokens);
+ return new Span(NativeBatch.pos, NativeBatch.n_tokens);
}
}
}
@@ -70,7 +77,7 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
- return new Span(Batch.seq_id, Batch.n_tokens);
+ return new Span(NativeBatch.seq_id, NativeBatch.n_tokens);
}
}
}
@@ -84,22 +91,40 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
- return new Span(Batch.logits, Batch.n_tokens);
+ return new Span(NativeBatch.logits, NativeBatch.n_tokens);
}
}
}
- public LLamaBatchSafeHandle(int n_tokens, int embd)
+ ///
+ /// Create a safe handle owning a `LLamaNativeBatch`
+ ///
+ ///
+ ///
+ public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
: base((nint)1)
{
_embd = embd;
- Batch = NativeApi.llama_batch_init(n_tokens, embd);
+ NativeBatch = batch;
+ }
+
+ ///
+ /// Call `llama_batch_init` and create a new batch
+ ///
+ ///
+ ///
+ ///
+ public static LLamaBatchSafeHandle Create(int n_tokens, int embd)
+ {
+ var batch = NativeApi.llama_batch_init(n_tokens, embd);
+ return new LLamaBatchSafeHandle(batch, embd);
}
+ ///
protected override bool ReleaseHandle()
{
- NativeApi.llama_batch_free(Batch);
- Batch = default;
+ NativeApi.llama_batch_free(NativeBatch);
+ NativeBatch = default;
SetHandle(IntPtr.Zero);
return true;
}
diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs
index 688f5ccb..5a7f27bb 100644
--- a/LLama/Native/LLamaGrammarElement.cs
+++ b/LLama/Native/LLamaGrammarElement.cs
@@ -45,7 +45,7 @@ namespace LLama.Native
/// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
///
CHAR_ALT = 6,
- };
+ }
///
/// An element of a grammar
diff --git a/LLama/Native/LLamaPos.cs b/LLama/Native/LLamaPos.cs
index 18dc8294..4deae57b 100644
--- a/LLama/Native/LLamaPos.cs
+++ b/LLama/Native/LLamaPos.cs
@@ -1,15 +1,26 @@
namespace LLama.Native;
-public record struct LLamaPos
+///
+/// Indicates position in a sequence
+///
+public readonly record struct LLamaPos(int Value)
{
- public int Value;
-
- public LLamaPos(int value)
- {
- Value = value;
- }
+ ///
+ /// The raw value
+ ///
+ public readonly int Value = Value;
+ ///
+ /// Convert a LLamaPos into an integer (extract the raw value)
+ ///
+ ///
+ ///
public static explicit operator int(LLamaPos pos) => pos.Value;
+ ///
+ /// Convert an integer into a LLamaPos
+ ///
+ ///
+ ///
public static implicit operator LLamaPos(int value) => new(value);
}
\ No newline at end of file
diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs
index d148fe4d..4a665a2c 100644
--- a/LLama/Native/LLamaSeqId.cs
+++ b/LLama/Native/LLamaSeqId.cs
@@ -1,15 +1,26 @@
namespace LLama.Native;
-public record struct LLamaSeqId
+///
+/// ID for a sequence in a batch
+///
+///
+public record struct LLamaSeqId(int Value)
{
- public int Value;
-
- public LLamaSeqId(int value)
- {
- Value = value;
- }
+ ///
+ /// The raw value
+ ///
+ public int Value = Value;
+ ///
+ /// Convert a LLamaSeqId into an integer (extract the raw value)
+ ///
+ ///
public static explicit operator int(LLamaSeqId pos) => pos.Value;
+ ///
+ /// Convert an integer into a LLamaSeqId
+ ///
+ ///
+ ///
public static explicit operator LLamaSeqId(int value) => new(value);
}
\ No newline at end of file
diff --git a/LLama/Native/LLamaTokenData.cs b/LLama/Native/LLamaTokenData.cs
index 0d3a56fc..1ea6820d 100644
--- a/LLama/Native/LLamaTokenData.cs
+++ b/LLama/Native/LLamaTokenData.cs
@@ -1,28 +1,28 @@
using System.Runtime.InteropServices;
-namespace LLama.Native
+namespace LLama.Native;
+
+///
+/// A single token along with probability of this token being selected
+///
+///
+///
+///
+[StructLayout(LayoutKind.Sequential)]
+public record struct LLamaTokenData(int id, float logit, float p)
{
- [StructLayout(LayoutKind.Sequential)]
- public struct LLamaTokenData
- {
- ///
- /// token id
- ///
- public int id;
- ///
- /// log-odds of the token
- ///
- public float logit;
- ///
- /// probability of the token
- ///
- public float p;
+ ///
+ /// token id
+ ///
+ public int id = id;
+
+ ///
+ /// log-odds of the token
+ ///
+ public float logit = logit;
- public LLamaTokenData(int id, float logit, float p)
- {
- this.id = id;
- this.logit = logit;
- this.p = p;
- }
- }
-}
+ ///
+ /// probability of the token
+ ///
+ public float p = p;
+}
\ No newline at end of file
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 88572254..6d5e87ce 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,7 +1,5 @@
using System;
using System.Buffers;
-using System.Runtime.CompilerServices;
-using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -212,9 +210,17 @@ namespace LLama.Native
}
}
+ ///
+ ///
+ ///
+ /// Positive return values does not mean a fatal error, but rather a warning:
+ /// - 0: success
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ /// - < 0: error
+ ///
public int Decode(LLamaBatchSafeHandle batch)
{
- return NativeApi.llama_decode(this, batch.Batch);
+ return NativeApi.llama_decode(this, batch.NativeBatch);
}
#region state
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index adf6bd54..5f3900e9 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -84,14 +84,14 @@ namespace LLama.Native
/// adapter. Can be NULL to use the current loaded model.
///
///
- public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null)
+ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
{
var err = NativeApi.llama_model_apply_lora_from_file(
this,
lora,
scale,
string.IsNullOrEmpty(modelBase) ? null : modelBase,
- (int?)threads ?? -1
+ threads ?? Math.Max(1, Environment.ProcessorCount / 2)
);
if (err != 0)