diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index 8b4e8497..1ec7022f 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -34,11 +34,6 @@ namespace LLama.Abstractions /// string ModelPath { get; set; } - /// - /// Number of threads (-1 = autodetect) (n_threads) - /// - uint? Threads { get; set; } - /// /// how split tensors should be distributed across GPUs /// diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs index 97a4d6ee..d4577a47 100644 --- a/LLama/Common/FixedSizeQueue.cs +++ b/LLama/Common/FixedSizeQueue.cs @@ -12,7 +12,6 @@ namespace LLama.Common public class FixedSizeQueue : IEnumerable { - private readonly int _maxSize; private readonly List _storage; internal IReadOnlyList Items => _storage; @@ -25,7 +24,7 @@ namespace LLama.Common /// /// Maximum number of items allowed in this queue /// - public int Capacity => _maxSize; + public int Capacity { get; } /// /// Create a new queue @@ -33,7 +32,7 @@ namespace LLama.Common /// the maximum number of items to store in this queue public FixedSizeQueue(int size) { - _maxSize = size; + Capacity = size; _storage = new(); } @@ -52,11 +51,11 @@ namespace LLama.Common #endif // Size of "data" is unknown, copy it all into a list - _maxSize = size; + Capacity = size; _storage = new List(data); // Now check if that list is a valid size. - if (_storage.Count > _maxSize) + if (_storage.Count > Capacity) throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); } @@ -81,7 +80,7 @@ namespace LLama.Common public void Enqueue(T item) { _storage.Add(item); - if(_storage.Count >= _maxSize) + if(_storage.Count >= Capacity) { _storage.RemoveAt(0); } diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index ed877853..998d4ec4 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -40,11 +40,11 @@ namespace LLama.Common /// /// Use mlock to keep model in memory (use_mlock) /// - public bool UseMemoryLock { get; set; } = false; + public bool UseMemoryLock { get; set; } /// /// Compute perplexity over the prompt (perplexity) /// - public bool Perplexity { get; set; } = false; + public bool Perplexity { get; set; } /// /// Model path (model) /// @@ -79,7 +79,7 @@ namespace LLama.Common /// Whether to use embedding mode. (embedding) Note that if this is set to true, /// The LLamaModel won't produce text response anymore. /// - public bool EmbeddingMode { get; set; } = false; + public bool EmbeddingMode { get; set; } /// /// how split tensors should be distributed across GPUs diff --git a/LLama/Exceptions/GrammarFormatExceptions.cs b/LLama/Exceptions/GrammarFormatExceptions.cs index 62b75224..4337e505 100644 --- a/LLama/Exceptions/GrammarFormatExceptions.cs +++ b/LLama/Exceptions/GrammarFormatExceptions.cs @@ -58,7 +58,7 @@ public class GrammarUnexpectedEndOfInput : GrammarFormatException { internal GrammarUnexpectedEndOfInput() - : base($"Unexpected end of input") + : base("Unexpected end of input") { } } diff --git a/LLama/Exceptions/RuntimeError.cs b/LLama/Exceptions/RuntimeError.cs index 6b839ff0..a8ea0531 100644 --- a/LLama/Exceptions/RuntimeError.cs +++ b/LLama/Exceptions/RuntimeError.cs @@ -1,19 +1,20 @@ using System; -namespace LLama.Exceptions +namespace LLama.Exceptions; + +/// +/// Base class for LLamaSharp runtime errors (i.e. errors produced by llama.cpp, converted into exceptions) +/// +public class RuntimeError + : Exception { - public class RuntimeError - : Exception + /// + /// Create a new RuntimeError + /// + /// + public RuntimeError(string message) + : base(message) { - public RuntimeError() - { - - } - - public RuntimeError(string message) - : base(message) - { - } } -} +} \ No newline at end of file diff --git a/LLama/Extensions/EncodingExtensions.cs b/LLama/Extensions/EncodingExtensions.cs index c0a381c5..e88d83a7 100644 --- a/LLama/Extensions/EncodingExtensions.cs +++ b/LLama/Extensions/EncodingExtensions.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Text; namespace LLama.Extensions; diff --git a/LLama/Extensions/KeyValuePairExtensions.cs b/LLama/Extensions/KeyValuePairExtensions.cs index bf48f3e8..6e12654d 100644 --- a/LLama/Extensions/KeyValuePairExtensions.cs +++ b/LLama/Extensions/KeyValuePairExtensions.cs @@ -1,26 +1,23 @@ -using System.Collections.Generic; +namespace LLama.Extensions; -namespace LLama.Extensions +/// +/// Extensions to the KeyValuePair struct +/// +internal static class KeyValuePairExtensions { +#if NETSTANDARD2_0 /// - /// Extensions to the KeyValuePair struct + /// Deconstruct a KeyValuePair into it's constituent parts. /// - internal static class KeyValuePairExtensions + /// The KeyValuePair to deconstruct + /// First element, the Key + /// Second element, the Value + /// Type of the Key + /// Type of the Value + public static void Deconstruct(this System.Collections.Generic.KeyValuePair pair, out TKey first, out TValue second) { -#if NETSTANDARD2_0 - /// - /// Deconstruct a KeyValuePair into it's constituent parts. - /// - /// The KeyValuePair to deconstruct - /// First element, the Key - /// Second element, the Value - /// Type of the Key - /// Type of the Value - public static void Deconstruct(this KeyValuePair pair, out TKey first, out TValue second) - { - first = pair.Key; - second = pair.Value; - } -#endif + first = pair.Key; + second = pair.Value; } -} +#endif +} \ No newline at end of file diff --git a/LLama/Grammars/GBNFGrammarParser.cs b/LLama/Grammars/GBNFGrammarParser.cs index 0f36edf8..ac915cf2 100644 --- a/LLama/Grammars/GBNFGrammarParser.cs +++ b/LLama/Grammars/GBNFGrammarParser.cs @@ -17,7 +17,7 @@ namespace LLama.Grammars { // NOTE: assumes valid utf8 (but checks for overrun) // copied from llama.cpp - private uint DecodeUTF8(ref ReadOnlySpan src) + private static uint DecodeUTF8(ref ReadOnlySpan src) { int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; @@ -40,46 +40,12 @@ namespace LLama.Grammars return value; } - private uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) - { - uint nextId = (uint)state.SymbolIds.Count; - string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray()); - - if (state.SymbolIds.TryGetValue(key, out uint existingId)) - { - return existingId; - } - else - { - state.SymbolIds[key] = nextId; - return nextId; - } - } - - private uint GenerateSymbolId(ParseState state, string baseName) - { - uint nextId = (uint)state.SymbolIds.Count; - string key = $"{baseName}_{nextId}"; - state.SymbolIds[key] = nextId; - return nextId; - } - - private void AddRule(ParseState state, uint ruleId, List rule) - { - while (state.Rules.Count <= ruleId) - { - state.Rules.Add(new List()); - } - - state.Rules[(int)ruleId] = rule; - } - - private bool IsWordChar(byte c) + private static bool IsWordChar(byte c) { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); } - private uint ParseHex(ref ReadOnlySpan src, int size) + private static uint ParseHex(ref ReadOnlySpan src, int size) { int pos = 0; int end = size; @@ -115,7 +81,7 @@ namespace LLama.Grammars return value; } - private ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) + private static ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) { int pos = 0; while (pos < src.Length && @@ -137,7 +103,7 @@ namespace LLama.Grammars return src.Slice(pos); } - private ReadOnlySpan ParseName(ReadOnlySpan src) + private static ReadOnlySpan ParseName(ReadOnlySpan src) { int pos = 0; while (pos < src.Length && IsWordChar(src[pos])) @@ -151,7 +117,7 @@ namespace LLama.Grammars return src.Slice(pos); } - private uint ParseChar(ref ReadOnlySpan src) + private static uint ParseChar(ref ReadOnlySpan src) { if (src[0] == '\\') { @@ -235,7 +201,7 @@ namespace LLama.Grammars else if (IsWordChar(pos[0])) // rule reference { var nameEnd = ParseName(pos); - uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); + uint refRuleId = state.GetSymbolId(pos, nameEnd.Length); pos = ParseSpace(nameEnd, isNested); lastSymStart = outElements.Count; outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId)); @@ -244,7 +210,7 @@ namespace LLama.Grammars { // parse nested alternates into synthesized rule pos = ParseSpace(pos.Slice(1), true); - uint subRuleId = GenerateSymbolId(state, ruleName); + uint subRuleId = state.GenerateSymbolId(ruleName); pos = ParseAlternates(state, pos, ruleName, subRuleId, true); lastSymStart = outElements.Count; // output reference to synthesized rule @@ -263,7 +229,7 @@ namespace LLama.Grammars // S* --> S' ::= S S' | // S+ --> S' ::= S S' | S // S? --> S' ::= S | - uint subRuleId = GenerateSymbolId(state, ruleName); + uint subRuleId = state.GenerateSymbolId(ruleName); List subRule = new List(); @@ -287,7 +253,7 @@ namespace LLama.Grammars subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); - AddRule(state, subRuleId, subRule); + state.AddRule(subRuleId, subRule); // in original rule, replace previous symbol with reference to generated rule outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); @@ -323,7 +289,7 @@ namespace LLama.Grammars } rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0)); - AddRule(state, ruleId, rule); + state.AddRule(ruleId, rule); return pos; } @@ -333,7 +299,7 @@ namespace LLama.Grammars ReadOnlySpan nameEnd = ParseName(src); ReadOnlySpan pos = ParseSpace(nameEnd, false); int nameLen = src.Length - nameEnd.Length; - uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0); + uint ruleId = state.GetSymbolId(src.Slice(0, nameLen), 0); string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray()); if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) @@ -393,6 +359,40 @@ namespace LLama.Grammars { public SortedDictionary SymbolIds { get; } = new(); public List> Rules { get; } = new(); + + public uint GetSymbolId(ReadOnlySpan src, int len) + { + var nextId = (uint)SymbolIds.Count; + var key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray()); + + if (SymbolIds.TryGetValue(key, out uint existingId)) + { + return existingId; + } + else + { + SymbolIds[key] = nextId; + return nextId; + } + } + + public uint GenerateSymbolId(string baseName) + { + var nextId = (uint)SymbolIds.Count; + var key = $"{baseName}_{nextId}"; + SymbolIds[key] = nextId; + return nextId; + } + + public void AddRule(uint ruleId, List rule) + { + while (Rules.Count <= ruleId) + { + Rules.Add(new List()); + } + + Rules[(int)ruleId] = rule; + } } } } diff --git a/LLama/Grammars/Grammar.cs b/LLama/Grammars/Grammar.cs index dbb3658e..5135e341 100644 --- a/LLama/Grammars/Grammar.cs +++ b/LLama/Grammars/Grammar.cs @@ -112,7 +112,6 @@ namespace LLama.Grammars case LLamaGrammarElementType.CHAR_ALT: PrintGrammarChar(output, elem.Value); break; - } if (elem.IsCharElement()) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 5a9f4893..a190c075 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -23,23 +23,21 @@ namespace LLama : IDisposable { private readonly ILogger? _logger; - private readonly Encoding _encoding; - private readonly SafeLLamaContextHandle _ctx; /// /// Total number of tokens in vocabulary of this model /// - public int VocabCount => _ctx.VocabCount; + public int VocabCount => NativeHandle.VocabCount; /// /// Total number of tokens in the context /// - public int ContextSize => _ctx.ContextSize; + public int ContextSize => NativeHandle.ContextSize; /// /// Dimension of embedding vectors /// - public int EmbeddingSize => _ctx.EmbeddingSize; + public int EmbeddingSize => NativeHandle.EmbeddingSize; /// /// The context params set for this context @@ -50,20 +48,20 @@ namespace LLama /// The native handle, which is used to be passed to the native APIs /// /// Be careful how you use this! - public SafeLLamaContextHandle NativeHandle => _ctx; + public SafeLLamaContextHandle NativeHandle { get; } /// /// The encoding set for this model to deal with text input. /// - public Encoding Encoding => _encoding; + public Encoding Encoding { get; } internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null) { Params = @params; _logger = logger; - _encoding = @params.Encoding; - _ctx = nativeContext; + Encoding = @params.Encoding; + NativeHandle = nativeContext; } /// @@ -81,10 +79,10 @@ namespace LLama Params = @params; _logger = logger; - _encoding = @params.Encoding; + Encoding = @params.Encoding; @params.ToLlamaContextParams(out var lparams); - _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); + NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); } /// @@ -96,7 +94,7 @@ namespace LLama /// public llama_token[] Tokenize(string text, bool addBos = true, bool special = false) { - return _ctx.Tokenize(text, addBos, special, _encoding); + return NativeHandle.Tokenize(text, addBos, special, Encoding); } /// @@ -108,7 +106,7 @@ namespace LLama { var sb = new StringBuilder(); foreach (var token in tokens) - _ctx.TokenToString(token, _encoding, sb); + NativeHandle.TokenToString(token, Encoding, sb); return sb.ToString(); } @@ -124,7 +122,7 @@ namespace LLama File.Delete(filename); // Estimate size of state to write to disk, this is always equal to or greater than the actual size - var estimatedStateSize = (long)NativeApi.llama_get_state_size(_ctx); + var estimatedStateSize = (long)NativeApi.llama_get_state_size(NativeHandle); // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array long writtenBytes; @@ -135,7 +133,7 @@ namespace LLama { byte* ptr = null; view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); - writtenBytes = (long)NativeApi.llama_copy_state_data(_ctx, ptr); + writtenBytes = (long)NativeApi.llama_copy_state_data(NativeHandle, ptr); view.SafeMemoryMappedViewHandle.ReleasePointer(); } } @@ -151,14 +149,14 @@ namespace LLama /// public State GetState() { - var stateSize = _ctx.GetStateSize(); + var stateSize = NativeHandle.GetStateSize(); // Allocate a chunk of memory large enough to hold the entire state var memory = Marshal.AllocHGlobal((nint)stateSize); try { // Copy the state data into memory, discover the actual size required - var actualSize = _ctx.GetState(memory, stateSize); + var actualSize = NativeHandle.GetState(memory, stateSize); // Shrink to size memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); @@ -193,7 +191,7 @@ namespace LLama { byte* ptr = null; view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); - NativeApi.llama_set_state_data(_ctx, ptr); + NativeApi.llama_set_state_data(NativeHandle, ptr); view.SafeMemoryMappedViewHandle.ReleasePointer(); } } @@ -208,7 +206,7 @@ namespace LLama { unsafe { - _ctx.SetState((byte*)state.DangerousGetHandle().ToPointer()); + NativeHandle.SetState((byte*)state.DangerousGetHandle().ToPointer()); } } @@ -235,13 +233,13 @@ namespace LLama if (grammar != null) { - SamplingApi.llama_sample_grammar(_ctx, candidates, grammar); + SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar); } if (temperature <= 0) { // Greedy sampling - id = SamplingApi.llama_sample_token_greedy(_ctx, candidates); + id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates); } else { @@ -250,23 +248,23 @@ namespace LLama if (mirostat == MirostatType.Mirostat) { const int mirostat_m = 100; - SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); - id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu); + SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); + id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu); } else if (mirostat == MirostatType.Mirostat2) { - SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); - id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu); + SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); + id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu); } else { // Temperature sampling - SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1); - SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1); - SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1); - SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1); - SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); - id = SamplingApi.llama_sample_token(_ctx, candidates); + SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1); + SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1); + SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1); + SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1); + SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); + id = SamplingApi.llama_sample_token(NativeHandle, candidates); } } mirostat_mu = mu; @@ -274,7 +272,7 @@ namespace LLama if (grammar != null) { - NativeApi.llama_grammar_accept_token(_ctx, grammar, id); + NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id); } return id; @@ -295,7 +293,7 @@ namespace LLama int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, bool penalizeNL = true) { - var logits = _ctx.GetLogits(); + var logits = NativeHandle.GetLogits(); // Apply params.logit_bias map if (logitBias is not null) @@ -305,7 +303,7 @@ namespace LLama } // Save the newline logit value - var nl_token = NativeApi.llama_token_nl(_ctx); + var nl_token = NativeApi.llama_token_nl(NativeHandle); var nl_logit = logits[nl_token]; // Convert logits into token candidates @@ -316,8 +314,8 @@ namespace LLama var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray(); // Apply penalties to candidates - SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty); - SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence); + SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty); + SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence); // Restore newline token logit value if necessary if (!penalizeNL) @@ -408,9 +406,9 @@ namespace LLama n_eval = (int)Params.BatchSize; } - if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount)) + if (!NativeHandle.Eval(tokens.Slice(i, n_eval), pastTokensCount)) { - _logger?.LogError($"[LLamaContext] Failed to eval."); + _logger?.LogError("[LLamaContext] Failed to eval."); throw new RuntimeError("Failed to eval."); } @@ -443,7 +441,7 @@ namespace LLama /// public void Dispose() { - _ctx.Dispose(); + NativeHandle.Dispose(); } /// diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 54ef07b0..fde901b1 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -18,11 +18,22 @@ namespace LLama /// public int EmbeddingSize => _ctx.EmbeddingSize; + /// + /// Create a new embedder (loading temporary weights) + /// + /// + [Obsolete("Preload LLamaWeights and use the constructor which accepts them")] public LLamaEmbedder(ILLamaParams allParams) : this(allParams, allParams) { } + /// + /// Create a new embedder (loading temporary weights) + /// + /// + /// + [Obsolete("Preload LLamaWeights and use the constructor which accepts them")] public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams) { using var weights = LLamaWeights.LoadFromFile(modelParams); @@ -31,6 +42,11 @@ namespace LLama _ctx = weights.CreateContext(contextParams); } + /// + /// Create a new embedder, using the given LLamaWeights + /// + /// + /// public LLamaEmbedder(LLamaWeights weights, IContextParams @params) { @params.EmbeddingMode = true; diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index df972e47..242ae10b 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -114,7 +114,7 @@ namespace LLama } else { - _logger?.LogWarning($"[LLamaExecutor] Session file does not exist, will create"); + _logger?.LogWarning("[LLamaExecutor] Session file does not exist, will create"); } _n_matching_session_tokens = 0; diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 6faa3db2..dd84f218 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -18,10 +18,10 @@ namespace LLama /// public class InstructExecutor : StatefulExecutorBase { - bool _is_prompt_run = true; - string _instructionPrefix; - llama_token[] _inp_pfx; - llama_token[] _inp_sfx; + private bool _is_prompt_run = true; + private readonly string _instructionPrefix; + private llama_token[] _inp_pfx; + private llama_token[] _inp_sfx; /// /// diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs index f1d89586..40632724 100644 --- a/LLama/LLamaQuantizer.cs +++ b/LLama/LLamaQuantizer.cs @@ -80,6 +80,7 @@ namespace LLama return true; case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: + case LLamaFtype.LLAMA_FTYPE_GUESSED: default: return false; } diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index bcc41afb..a5e2adca 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -11,13 +11,11 @@ namespace LLama public sealed class LLamaWeights : IDisposable { - private readonly SafeLlamaModelHandle _weights; - /// /// The native handle, which is used in the native APIs /// /// Be careful how you use this! - public SafeLlamaModelHandle NativeHandle => _weights; + public SafeLlamaModelHandle NativeHandle { get; } /// /// Total number of tokens in vocabulary of this model @@ -46,7 +44,7 @@ namespace LLama internal LLamaWeights(SafeLlamaModelHandle weights) { - _weights = weights; + NativeHandle = weights; } /// @@ -66,7 +64,7 @@ namespace LLama if (adapter.Scale <= 0) continue; - weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads); + weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase); } return new LLamaWeights(weights); @@ -75,7 +73,7 @@ namespace LLama /// public void Dispose() { - _weights.Dispose(); + NativeHandle.Dispose(); } /// diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs index fd72ddcd..1f7ef2c5 100644 --- a/LLama/Native/LLamaBatchSafeHandle.cs +++ b/LLama/Native/LLamaBatchSafeHandle.cs @@ -4,11 +4,18 @@ namespace LLama.Native; using llama_token = Int32; +/// +/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences. +/// public sealed class LLamaBatchSafeHandle : SafeLLamaHandleBase { private readonly int _embd; - public LLamaNativeBatch Batch { get; private set; } + + /// + /// Get the native llama_batch struct + /// + public LLamaNativeBatch NativeBatch { get; private set; } /// /// the token ids of the input (used when embd is NULL) @@ -22,7 +29,7 @@ public sealed class LLamaBatchSafeHandle if (_embd != 0) return new Span(null, 0); else - return new Span(Batch.token, Batch.n_tokens); + return new Span(NativeBatch.token, NativeBatch.n_tokens); } } } @@ -37,10 +44,10 @@ public sealed class LLamaBatchSafeHandle unsafe { // If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float) - /// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token if (_embd != 0) - return new Span(Batch.embd, Batch.n_tokens * _embd); + return new Span(NativeBatch.embd, NativeBatch.n_tokens * _embd); else return new Span(null, 0); } @@ -56,7 +63,7 @@ public sealed class LLamaBatchSafeHandle { unsafe { - return new Span(Batch.pos, Batch.n_tokens); + return new Span(NativeBatch.pos, NativeBatch.n_tokens); } } } @@ -70,7 +77,7 @@ public sealed class LLamaBatchSafeHandle { unsafe { - return new Span(Batch.seq_id, Batch.n_tokens); + return new Span(NativeBatch.seq_id, NativeBatch.n_tokens); } } } @@ -84,22 +91,40 @@ public sealed class LLamaBatchSafeHandle { unsafe { - return new Span(Batch.logits, Batch.n_tokens); + return new Span(NativeBatch.logits, NativeBatch.n_tokens); } } } - public LLamaBatchSafeHandle(int n_tokens, int embd) + /// + /// Create a safe handle owning a `LLamaNativeBatch` + /// + /// + /// + public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd) : base((nint)1) { _embd = embd; - Batch = NativeApi.llama_batch_init(n_tokens, embd); + NativeBatch = batch; + } + + /// + /// Call `llama_batch_init` and create a new batch + /// + /// + /// + /// + public static LLamaBatchSafeHandle Create(int n_tokens, int embd) + { + var batch = NativeApi.llama_batch_init(n_tokens, embd); + return new LLamaBatchSafeHandle(batch, embd); } + /// protected override bool ReleaseHandle() { - NativeApi.llama_batch_free(Batch); - Batch = default; + NativeApi.llama_batch_free(NativeBatch); + NativeBatch = default; SetHandle(IntPtr.Zero); return true; } diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs index 688f5ccb..5a7f27bb 100644 --- a/LLama/Native/LLamaGrammarElement.cs +++ b/LLama/Native/LLamaGrammarElement.cs @@ -45,7 +45,7 @@ namespace LLama.Native /// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) /// CHAR_ALT = 6, - }; + } /// /// An element of a grammar diff --git a/LLama/Native/LLamaPos.cs b/LLama/Native/LLamaPos.cs index 18dc8294..4deae57b 100644 --- a/LLama/Native/LLamaPos.cs +++ b/LLama/Native/LLamaPos.cs @@ -1,15 +1,26 @@ namespace LLama.Native; -public record struct LLamaPos +/// +/// Indicates position in a sequence +/// +public readonly record struct LLamaPos(int Value) { - public int Value; - - public LLamaPos(int value) - { - Value = value; - } + /// + /// The raw value + /// + public readonly int Value = Value; + /// + /// Convert a LLamaPos into an integer (extract the raw value) + /// + /// + /// public static explicit operator int(LLamaPos pos) => pos.Value; + /// + /// Convert an integer into a LLamaPos + /// + /// + /// public static implicit operator LLamaPos(int value) => new(value); } \ No newline at end of file diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs index d148fe4d..4a665a2c 100644 --- a/LLama/Native/LLamaSeqId.cs +++ b/LLama/Native/LLamaSeqId.cs @@ -1,15 +1,26 @@ namespace LLama.Native; -public record struct LLamaSeqId +/// +/// ID for a sequence in a batch +/// +/// +public record struct LLamaSeqId(int Value) { - public int Value; - - public LLamaSeqId(int value) - { - Value = value; - } + /// + /// The raw value + /// + public int Value = Value; + /// + /// Convert a LLamaSeqId into an integer (extract the raw value) + /// + /// public static explicit operator int(LLamaSeqId pos) => pos.Value; + /// + /// Convert an integer into a LLamaSeqId + /// + /// + /// public static explicit operator LLamaSeqId(int value) => new(value); } \ No newline at end of file diff --git a/LLama/Native/LLamaTokenData.cs b/LLama/Native/LLamaTokenData.cs index 0d3a56fc..1ea6820d 100644 --- a/LLama/Native/LLamaTokenData.cs +++ b/LLama/Native/LLamaTokenData.cs @@ -1,28 +1,28 @@ using System.Runtime.InteropServices; -namespace LLama.Native +namespace LLama.Native; + +/// +/// A single token along with probability of this token being selected +/// +/// +/// +/// +[StructLayout(LayoutKind.Sequential)] +public record struct LLamaTokenData(int id, float logit, float p) { - [StructLayout(LayoutKind.Sequential)] - public struct LLamaTokenData - { - /// - /// token id - /// - public int id; - /// - /// log-odds of the token - /// - public float logit; - /// - /// probability of the token - /// - public float p; + /// + /// token id + /// + public int id = id; + + /// + /// log-odds of the token + /// + public float logit = logit; - public LLamaTokenData(int id, float logit, float p) - { - this.id = id; - this.logit = logit; - this.p = p; - } - } -} + /// + /// probability of the token + /// + public float p = p; +} \ No newline at end of file diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 88572254..6d5e87ce 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,7 +1,5 @@ using System; using System.Buffers; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; @@ -212,9 +210,17 @@ namespace LLama.Native } } + /// + /// + /// + /// Positive return values does not mean a fatal error, but rather a warning:
+ /// - 0: success
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ /// - < 0: error
+ ///
public int Decode(LLamaBatchSafeHandle batch) { - return NativeApi.llama_decode(this, batch.Batch); + return NativeApi.llama_decode(this, batch.NativeBatch); } #region state diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index adf6bd54..5f3900e9 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -84,14 +84,14 @@ namespace LLama.Native /// adapter. Can be NULL to use the current loaded model. /// /// - public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null) + public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null) { var err = NativeApi.llama_model_apply_lora_from_file( this, lora, scale, string.IsNullOrEmpty(modelBase) ? null : modelBase, - (int?)threads ?? -1 + threads ?? Math.Max(1, Environment.ProcessorCount / 2) ); if (err != 0)