@@ -34,11 +34,6 @@ namespace LLama.Abstractions
         /// </summary>
         string ModelPath { get; set; }

-        /// <summary>
-        /// Number of threads (-1 = autodetect) (n_threads)
-        /// </summary>
-        uint? Threads { get; set; }

         /// <summary>
         /// how split tensors should be distributed across GPUs
         /// </summary>
@@ -12,7 +12,6 @@ namespace LLama.Common
     public class FixedSizeQueue<T>
         : IEnumerable<T>
     {
-        private readonly int _maxSize;
         private readonly List<T> _storage;

         internal IReadOnlyList<T> Items => _storage;
@@ -25,7 +24,7 @@ namespace LLama.Common
         /// <summary>
         /// Maximum number of items allowed in this queue
         /// </summary>
-        public int Capacity => _maxSize;
+        public int Capacity { get; }

         /// <summary>
         /// Create a new queue
@@ -33,7 +32,7 @@ namespace LLama.Common
         /// <param name="size">the maximum number of items to store in this queue</param>
         public FixedSizeQueue(int size)
         {
-            _maxSize = size;
+            Capacity = size;
             _storage = new();
         }
@@ -52,11 +51,11 @@ namespace LLama.Common
 #endif
             // Size of "data" is unknown, copy it all into a list
-            _maxSize = size;
+            Capacity = size;
             _storage = new List<T>(data);

             // Now check if that list is a valid size.
-            if (_storage.Count > _maxSize)
+            if (_storage.Count > Capacity)
                 throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
         }
@@ -81,7 +80,7 @@ namespace LLama.Common
         public void Enqueue(T item)
         {
             _storage.Add(item);
-            if(_storage.Count >= _maxSize)
+            if(_storage.Count >= Capacity)
             {
                 _storage.RemoveAt(0);
             }
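Aside: the pattern applied throughout this file, sketched outside the diff with a hypothetical type. A get-only auto-property behaves exactly like the readonly backing field it replaces, assignable only from a constructor:

```csharp
// Minimal sketch, not part of the diff: the compiler generates the backing
// storage for a get-only auto-property; it is settable only in a constructor.
public class BoundedBuffer
{
    public int Capacity { get; }

    public BoundedBuffer(int capacity)
    {
        Capacity = capacity;
    }
}
```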
@@ -40,11 +40,11 @@ namespace LLama.Common
         /// <summary>
         /// Use mlock to keep model in memory (use_mlock)
         /// </summary>
-        public bool UseMemoryLock { get; set; } = false;
+        public bool UseMemoryLock { get; set; }

         /// <summary>
         /// Compute perplexity over the prompt (perplexity)
         /// </summary>
-        public bool Perplexity { get; set; } = false;
+        public bool Perplexity { get; set; }

         /// <summary>
         /// Model path (model)
         /// </summary>
@@ -79,7 +79,7 @@ namespace LLama.Common
         /// Whether to use embedding mode. (embedding) Note that if this is set to true,
         /// The LLamaModel won't produce text response anymore.
         /// </summary>
-        public bool EmbeddingMode { get; set; } = false;
+        public bool EmbeddingMode { get; set; }

         /// <summary>
         /// how split tensors should be distributed across GPUs
@@ -58,7 +58,7 @@ public class GrammarUnexpectedEndOfInput
     : GrammarFormatException
 {
     internal GrammarUnexpectedEndOfInput()
-        : base($"Unexpected end of input")
+        : base("Unexpected end of input")
     {
     }
 }
@@ -1,19 +1,20 @@
 using System;

-namespace LLama.Exceptions
-{
-    public class RuntimeError
-        : Exception
-    {
-        public RuntimeError()
-        {
-        }
-
-        public RuntimeError(string message)
-            : base(message)
-        {
-        }
-    }
-}
+namespace LLama.Exceptions;
+
+/// <summary>
+/// Base class for LLamaSharp runtime errors (i.e. errors produced by llama.cpp, converted into exceptions)
+/// </summary>
+public class RuntimeError
+    : Exception
+{
+    /// <summary>
+    /// Create a new RuntimeError
+    /// </summary>
+    /// <param name="message"></param>
+    public RuntimeError(string message)
+        : base(message)
+    {
+    }
+}
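For illustration (a hypothetical call site, not from this diff): with the parameterless constructor removed, every `RuntimeError` now carries a message describing the llama.cpp failure.

```csharp
try
{
    throw new RuntimeError("Failed to eval."); // message is now mandatory
}
catch (RuntimeError e)
{
    Console.WriteLine(e.Message);
}
```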
@@ -1,5 +1,4 @@
 using System;
 using System.Collections.Generic;
 using System.Text;

 namespace LLama.Extensions;
@@ -1,26 +1,23 @@
-using System.Collections.Generic;
-
-namespace LLama.Extensions
-{
-    /// <summary>
-    /// Extensions to the KeyValuePair struct
-    /// </summary>
-    internal static class KeyValuePairExtensions
-    {
-#if NETSTANDARD2_0
-        /// <summary>
-        /// Deconstruct a KeyValuePair into it's constituent parts.
-        /// </summary>
-        /// <param name="pair">The KeyValuePair to deconstruct</param>
-        /// <param name="first">First element, the Key</param>
-        /// <param name="second">Second element, the Value</param>
-        /// <typeparam name="TKey">Type of the Key</typeparam>
-        /// <typeparam name="TValue">Type of the Value</typeparam>
-        public static void Deconstruct<TKey, TValue>(this KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
-        {
-            first = pair.Key;
-            second = pair.Value;
-        }
-#endif
-    }
-}
+namespace LLama.Extensions;
+
+/// <summary>
+/// Extensions to the KeyValuePair struct
+/// </summary>
+internal static class KeyValuePairExtensions
+{
+#if NETSTANDARD2_0
+    /// <summary>
+    /// Deconstruct a KeyValuePair into it's constituent parts.
+    /// </summary>
+    /// <param name="pair">The KeyValuePair to deconstruct</param>
+    /// <param name="first">First element, the Key</param>
+    /// <param name="second">Second element, the Value</param>
+    /// <typeparam name="TKey">Type of the Key</typeparam>
+    /// <typeparam name="TValue">Type of the Value</typeparam>
+    public static void Deconstruct<TKey, TValue>(this System.Collections.Generic.KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
+    {
+        first = pair.Key;
+        second = pair.Value;
+    }
+#endif
+}
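A usage sketch, assumed rather than taken from this PR: on `netstandard2.0` the extension above enables tuple-style deconstruction of dictionary entries, which newer target frameworks ship built in.

```csharp
var counts = new Dictionary<string, int> { ["llama"] = 2, ["alpaca"] = 1 };

// Binds to the Deconstruct extension on netstandard2.0, and to the
// built-in KeyValuePair<TKey,TValue>.Deconstruct on newer TFMs.
foreach (var (word, count) in counts)
    Console.WriteLine($"{word}: {count}");
```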
@@ -17,7 +17,7 @@ namespace LLama.Grammars
     {
         // NOTE: assumes valid utf8 (but checks for overrun)
         // copied from llama.cpp
-        private uint DecodeUTF8(ref ReadOnlySpan<byte> src)
+        private static uint DecodeUTF8(ref ReadOnlySpan<byte> src)
         {
             int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -40,46 +40,12 @@ namespace LLama.Grammars
             return value;
         }

-        private uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
-        {
-            uint nextId = (uint)state.SymbolIds.Count;
-            string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
-            if (state.SymbolIds.TryGetValue(key, out uint existingId))
-            {
-                return existingId;
-            }
-            else
-            {
-                state.SymbolIds[key] = nextId;
-                return nextId;
-            }
-        }
-
-        private uint GenerateSymbolId(ParseState state, string baseName)
-        {
-            uint nextId = (uint)state.SymbolIds.Count;
-            string key = $"{baseName}_{nextId}";
-            state.SymbolIds[key] = nextId;
-            return nextId;
-        }
-
-        private void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
-        {
-            while (state.Rules.Count <= ruleId)
-            {
-                state.Rules.Add(new List<LLamaGrammarElement>());
-            }
-            state.Rules[(int)ruleId] = rule;
-        }
-
-        private bool IsWordChar(byte c)
+        private static bool IsWordChar(byte c)
         {
             return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
         }

-        private uint ParseHex(ref ReadOnlySpan<byte> src, int size)
+        private static uint ParseHex(ref ReadOnlySpan<byte> src, int size)
         {
             int pos = 0;
             int end = size;
@@ -115,7 +81,7 @@ namespace LLama.Grammars
             return value;
         }

-        private ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
+        private static ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
         {
             int pos = 0;
             while (pos < src.Length &&
@@ -137,7 +103,7 @@ namespace LLama.Grammars
             return src.Slice(pos);
         }

-        private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
+        private static ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
         {
             int pos = 0;
             while (pos < src.Length && IsWordChar(src[pos]))
@@ -151,7 +117,7 @@ namespace LLama.Grammars
             return src.Slice(pos);
         }

-        private uint ParseChar(ref ReadOnlySpan<byte> src)
+        private static uint ParseChar(ref ReadOnlySpan<byte> src)
         {
             if (src[0] == '\\')
             {
@@ -235,7 +201,7 @@ namespace LLama.Grammars
             else if (IsWordChar(pos[0])) // rule reference
             {
                 var nameEnd = ParseName(pos);
-                uint refRuleId = GetSymbolId(state, pos, nameEnd.Length);
+                uint refRuleId = state.GetSymbolId(pos, nameEnd.Length);
                 pos = ParseSpace(nameEnd, isNested);
                 lastSymStart = outElements.Count;
                 outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
@@ -244,7 +210,7 @@ namespace LLama.Grammars
             {
                 // parse nested alternates into synthesized rule
                 pos = ParseSpace(pos.Slice(1), true);
-                uint subRuleId = GenerateSymbolId(state, ruleName);
+                uint subRuleId = state.GenerateSymbolId(ruleName);
                 pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
                 lastSymStart = outElements.Count;
                 // output reference to synthesized rule
@@ -263,7 +229,7 @@ namespace LLama.Grammars
                 // S* --> S' ::= S S' |
                 // S+ --> S' ::= S S' | S
                 // S? --> S' ::= S |
-                uint subRuleId = GenerateSymbolId(state, ruleName);
+                uint subRuleId = state.GenerateSymbolId(ruleName);

                 List<LLamaGrammarElement> subRule = new List<LLamaGrammarElement>();
@@ -287,7 +253,7 @@ namespace LLama.Grammars
                 subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
-                AddRule(state, subRuleId, subRule);
+                state.AddRule(subRuleId, subRule);

                 // in original rule, replace previous symbol with reference to generated rule
                 outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
@@ -323,7 +289,7 @@ namespace LLama.Grammars
             }
             rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
-            AddRule(state, ruleId, rule);
+            state.AddRule(ruleId, rule);
             return pos;
         }
@@ -333,7 +299,7 @@ namespace LLama.Grammars
             ReadOnlySpan<byte> nameEnd = ParseName(src);
             ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
             int nameLen = src.Length - nameEnd.Length;
-            uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0);
+            uint ruleId = state.GetSymbolId(src.Slice(0, nameLen), 0);
             string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray());

             if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '='))
@@ -393,6 +359,40 @@ namespace LLama.Grammars
         {
             public SortedDictionary<string, uint> SymbolIds { get; } = new();
             public List<List<LLamaGrammarElement>> Rules { get; } = new();

+            public uint GetSymbolId(ReadOnlySpan<byte> src, int len)
+            {
+                var nextId = (uint)SymbolIds.Count;
+                var key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
+                if (SymbolIds.TryGetValue(key, out uint existingId))
+                {
+                    return existingId;
+                }
+                else
+                {
+                    SymbolIds[key] = nextId;
+                    return nextId;
+                }
+            }
+
+            public uint GenerateSymbolId(string baseName)
+            {
+                var nextId = (uint)SymbolIds.Count;
+                var key = $"{baseName}_{nextId}";
+                SymbolIds[key] = nextId;
+                return nextId;
+            }
+
+            public void AddRule(uint ruleId, List<LLamaGrammarElement> rule)
+            {
+                while (Rules.Count <= ruleId)
+                {
+                    Rules.Add(new List<LLamaGrammarElement>());
+                }
+                Rules[(int)ruleId] = rule;
+            }
         }
     }
 }
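The net effect of this file's hunks: `GetSymbolId`, `GenerateSymbolId` and `AddRule` move from parser helpers that thread a `ParseState` parameter through to instance methods on `ParseState` itself. A before/after sketch of a call site, with names taken from the diff:

```csharp
// Before: stateless helpers taking the state explicitly.
//   uint id = GenerateSymbolId(state, ruleName);
//   AddRule(state, id, rule);

// After: the state owns operations over its own SymbolIds/Rules collections.
uint id = state.GenerateSymbolId(ruleName);
state.AddRule(id, rule);
```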
@@ -112,7 +112,6 @@ namespace LLama.Grammars
                 case LLamaGrammarElementType.CHAR_ALT:
                     PrintGrammarChar(output, elem.Value);
                     break;
-            }

             if (elem.IsCharElement())
@@ -23,23 +23,21 @@ namespace LLama
         : IDisposable
     {
         private readonly ILogger? _logger;
-        private readonly Encoding _encoding;
-        private readonly SafeLLamaContextHandle _ctx;

         /// <summary>
         /// Total number of tokens in vocabulary of this model
         /// </summary>
-        public int VocabCount => _ctx.VocabCount;
+        public int VocabCount => NativeHandle.VocabCount;

         /// <summary>
         /// Total number of tokens in the context
         /// </summary>
-        public int ContextSize => _ctx.ContextSize;
+        public int ContextSize => NativeHandle.ContextSize;

         /// <summary>
         /// Dimension of embedding vectors
         /// </summary>
-        public int EmbeddingSize => _ctx.EmbeddingSize;
+        public int EmbeddingSize => NativeHandle.EmbeddingSize;

         /// <summary>
         /// The context params set for this context
@@ -50,20 +48,20 @@ namespace LLama
         /// The native handle, which is used to be passed to the native APIs
         /// </summary>
         /// <remarks>Be careful how you use this!</remarks>
-        public SafeLLamaContextHandle NativeHandle => _ctx;
+        public SafeLLamaContextHandle NativeHandle { get; }

         /// <summary>
         /// The encoding set for this model to deal with text input.
         /// </summary>
-        public Encoding Encoding => _encoding;
+        public Encoding Encoding { get; }

         internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
         {
             Params = @params;
             _logger = logger;
-            _encoding = @params.Encoding;
-            _ctx = nativeContext;
+            Encoding = @params.Encoding;
+            NativeHandle = nativeContext;
         }

         /// <summary>
@@ -81,10 +79,10 @@ namespace LLama
             Params = @params;
             _logger = logger;
-            _encoding = @params.Encoding;
+            Encoding = @params.Encoding;

             @params.ToLlamaContextParams(out var lparams);
-            _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
+            NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
         }

         /// <summary>
@@ -96,7 +94,7 @@ namespace LLama
         /// <returns></returns>
         public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
         {
-            return _ctx.Tokenize(text, addBos, special, _encoding);
+            return NativeHandle.Tokenize(text, addBos, special, Encoding);
         }

         /// <summary>
@@ -108,7 +106,7 @@ namespace LLama
         {
             var sb = new StringBuilder();
             foreach (var token in tokens)
-                _ctx.TokenToString(token, _encoding, sb);
+                NativeHandle.TokenToString(token, Encoding, sb);
             return sb.ToString();
         }

@@ -124,7 +122,7 @@ namespace LLama
             File.Delete(filename);

             // Estimate size of state to write to disk, this is always equal to or greater than the actual size
-            var estimatedStateSize = (long)NativeApi.llama_get_state_size(_ctx);
+            var estimatedStateSize = (long)NativeApi.llama_get_state_size(NativeHandle);

             // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
             long writtenBytes;
@@ -135,7 +133,7 @@ namespace LLama
             {
                 byte* ptr = null;
                 view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
-                writtenBytes = (long)NativeApi.llama_copy_state_data(_ctx, ptr);
+                writtenBytes = (long)NativeApi.llama_copy_state_data(NativeHandle, ptr);
                 view.SafeMemoryMappedViewHandle.ReleasePointer();
             }
         }
@@ -151,14 +149,14 @@ namespace LLama
         /// <returns></returns>
         public State GetState()
         {
-            var stateSize = _ctx.GetStateSize();
+            var stateSize = NativeHandle.GetStateSize();

             // Allocate a chunk of memory large enough to hold the entire state
             var memory = Marshal.AllocHGlobal((nint)stateSize);
             try
             {
                 // Copy the state data into memory, discover the actual size required
-                var actualSize = _ctx.GetState(memory, stateSize);
+                var actualSize = NativeHandle.GetState(memory, stateSize);

                 // Shrink to size
                 memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
@@ -193,7 +191,7 @@ namespace LLama
             {
                 byte* ptr = null;
                 view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
-                NativeApi.llama_set_state_data(_ctx, ptr);
+                NativeApi.llama_set_state_data(NativeHandle, ptr);
                 view.SafeMemoryMappedViewHandle.ReleasePointer();
             }
         }
@@ -208,7 +206,7 @@ namespace LLama
         {
             unsafe
             {
-                _ctx.SetState((byte*)state.DangerousGetHandle().ToPointer());
+                NativeHandle.SetState((byte*)state.DangerousGetHandle().ToPointer());
             }
         }

@@ -235,13 +233,13 @@ namespace LLama
             if (grammar != null)
             {
-                SamplingApi.llama_sample_grammar(_ctx, candidates, grammar);
+                SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
             }

             if (temperature <= 0)
             {
                 // Greedy sampling
-                id = SamplingApi.llama_sample_token_greedy(_ctx, candidates);
+                id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
             }
             else
             {
@@ -250,23 +248,23 @@ namespace LLama
                 if (mirostat == MirostatType.Mirostat)
                 {
                     const int mirostat_m = 100;
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+                    SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+                    id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
                 }
                 else if (mirostat == MirostatType.Mirostat2)
                 {
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
+                    SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+                    id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
                 }
                 else
                 {
                     // Temperature sampling
-                    SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
-                    SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
-                    SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
-                    SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
-                    SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-                    id = SamplingApi.llama_sample_token(_ctx, candidates);
+                    SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
+                    SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
+                    SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
+                    SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
+                    SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+                    id = SamplingApi.llama_sample_token(NativeHandle, candidates);
                 }
             }
             mirostat_mu = mu;
@@ -274,7 +272,7 @@ namespace LLama
             if (grammar != null)
             {
-                NativeApi.llama_grammar_accept_token(_ctx, grammar, id);
+                NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
             }

             return id;
@@ -295,7 +293,7 @@ namespace LLama
             int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
             bool penalizeNL = true)
         {
-            var logits = _ctx.GetLogits();
+            var logits = NativeHandle.GetLogits();

             // Apply params.logit_bias map
             if (logitBias is not null)
@@ -305,7 +303,7 @@ namespace LLama
             }

             // Save the newline logit value
-            var nl_token = NativeApi.llama_token_nl(_ctx);
+            var nl_token = NativeApi.llama_token_nl(NativeHandle);
             var nl_logit = logits[nl_token];

             // Convert logits into token candidates
@@ -316,8 +314,8 @@ namespace LLama
             var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();

             // Apply penalties to candidates
-            SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty);
-            SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence);
+            SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
+            SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);

             // Restore newline token logit value if necessary
             if (!penalizeNL)
@@ -408,9 +406,9 @@ namespace LLama
                     n_eval = (int)Params.BatchSize;
                 }

-                if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount))
+                if (!NativeHandle.Eval(tokens.Slice(i, n_eval), pastTokensCount))
                 {
-                    _logger?.LogError($"[LLamaContext] Failed to eval.");
+                    _logger?.LogError("[LLamaContext] Failed to eval.");
                     throw new RuntimeError("Failed to eval.");
                 }

@@ -443,7 +441,7 @@ namespace LLama
         /// <inheritdoc />
         public void Dispose()
         {
-            _ctx.Dispose();
+            NativeHandle.Dispose();
         }

         /// <summary>
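A hedged usage sketch (variable names assumed, APIs as they appear elsewhere in this diff): converting `_ctx`/`_encoding` to the get-only `NativeHandle`/`Encoding` properties leaves the public surface unchanged.

```csharp
// Assumed setup; weights.CreateContext appears later in this diff.
using var context = weights.CreateContext(contextParams);
var tokens = context.Tokenize("Hello, world!", addBos: true);
Console.WriteLine($"vocab={context.VocabCount} ctx={context.ContextSize} tokens={tokens.Length}");
```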
@@ -18,11 +18,22 @@ namespace LLama
         /// </summary>
         public int EmbeddingSize => _ctx.EmbeddingSize;

+        /// <summary>
+        /// Create a new embedder (loading temporary weights)
+        /// </summary>
+        /// <param name="allParams"></param>
+        [Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
+        public LLamaEmbedder(ILLamaParams allParams)
+            : this(allParams, allParams)
+        {
+        }
+
         /// <summary>
         /// Create a new embedder (loading temporary weights)
         /// </summary>
         /// <param name="modelParams"></param>
         /// <param name="contextParams"></param>
+        [Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
         public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
         {
             using var weights = LLamaWeights.LoadFromFile(modelParams);
@@ -31,6 +42,11 @@ namespace LLama
             _ctx = weights.CreateContext(contextParams);
         }

+        /// <summary>
+        /// Create a new embedder, using the given LLamaWeights
+        /// </summary>
+        /// <param name="weights"></param>
+        /// <param name="params"></param>
         public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
         {
             @params.EmbeddingMode = true;
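The `[Obsolete]` attributes steer callers toward sharing preloaded weights. A sketch of the migration (the `GetEmbeddings` call is an assumption about the class's existing API, not shown in this diff):

```csharp
// Old, now obsolete: each embedder loads and discards its own weights.
//   using var embedder = new LLamaEmbedder(modelParams, contextParams);

// New: load weights once, reuse them across embedders and contexts.
using var weights = LLamaWeights.LoadFromFile(modelParams);
using var embedder = new LLamaEmbedder(weights, contextParams);
var embedding = embedder.GetEmbeddings("Hello"); // method name assumed
```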
@@ -114,7 +114,7 @@ namespace LLama
             }
             else
             {
-                _logger?.LogWarning($"[LLamaExecutor] Session file does not exist, will create");
+                _logger?.LogWarning("[LLamaExecutor] Session file does not exist, will create");
             }

             _n_matching_session_tokens = 0;
@@ -18,10 +18,10 @@ namespace LLama
     /// </summary>
     public class InstructExecutor : StatefulExecutorBase
     {
-        bool _is_prompt_run = true;
-        string _instructionPrefix;
-        llama_token[] _inp_pfx;
-        llama_token[] _inp_sfx;
+        private bool _is_prompt_run = true;
+        private readonly string _instructionPrefix;
+        private llama_token[] _inp_pfx;
+        private llama_token[] _inp_sfx;

         /// <summary>
         ///
@@ -80,6 +80,7 @@ namespace LLama
                     return true;

                 case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
                 case LLamaFtype.LLAMA_FTYPE_GUESSED:
+                default:
                     return false;
             }
@@ -11,13 +11,11 @@ namespace LLama
     public sealed class LLamaWeights
         : IDisposable
     {
-        private readonly SafeLlamaModelHandle _weights;
-
         /// <summary>
         /// The native handle, which is used in the native APIs
         /// </summary>
         /// <remarks>Be careful how you use this!</remarks>
-        public SafeLlamaModelHandle NativeHandle => _weights;
+        public SafeLlamaModelHandle NativeHandle { get; }

         /// <summary>
         /// Total number of tokens in vocabulary of this model
@@ -46,7 +44,7 @@ namespace LLama
         internal LLamaWeights(SafeLlamaModelHandle weights)
         {
-            _weights = weights;
+            NativeHandle = weights;
         }

         /// <summary>
@@ -66,7 +64,7 @@ namespace LLama
                 if (adapter.Scale <= 0)
                     continue;

-                weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads);
+                weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
             }

             return new LLamaWeights(weights);
@@ -75,7 +73,7 @@ namespace LLama
         /// <inheritdoc />
         public void Dispose()
         {
-            _weights.Dispose();
+            NativeHandle.Dispose();
         }

         /// <summary>
@@ -4,11 +4,18 @@ namespace LLama.Native;
 using llama_token = Int32;

+/// <summary>
+/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
+/// </summary>
 public sealed class LLamaBatchSafeHandle
     : SafeLLamaHandleBase
 {
     private readonly int _embd;

-    public LLamaNativeBatch Batch { get; private set; }
+    /// <summary>
+    /// Get the native llama_batch struct
+    /// </summary>
+    public LLamaNativeBatch NativeBatch { get; private set; }

     /// <summary>
     /// the token ids of the input (used when embd is NULL)
@@ -22,7 +29,7 @@ public sealed class LLamaBatchSafeHandle
             if (_embd != 0)
                 return new Span<int>(null, 0);
             else
-                return new Span<int>(Batch.token, Batch.n_tokens);
+                return new Span<int>(NativeBatch.token, NativeBatch.n_tokens);
         }
     }
 }
@@ -37,10 +44,10 @@ public sealed class LLamaBatchSafeHandle
         unsafe
         {
             // If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
-            /// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
+            // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
             if (_embd != 0)
-                return new Span<llama_token>(Batch.embd, Batch.n_tokens * _embd);
+                return new Span<llama_token>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
             else
                 return new Span<llama_token>(null, 0);
         }
@@ -56,7 +63,7 @@ public sealed class LLamaBatchSafeHandle
     {
         unsafe
         {
-            return new Span<LLamaPos>(Batch.pos, Batch.n_tokens);
+            return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
         }
     }
 }
@@ -70,7 +77,7 @@ public sealed class LLamaBatchSafeHandle
     {
         unsafe
         {
-            return new Span<LLamaSeqId>(Batch.seq_id, Batch.n_tokens);
+            return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
         }
     }
 }
@@ -84,22 +91,40 @@ public sealed class LLamaBatchSafeHandle
     {
         unsafe
         {
-            return new Span<byte>(Batch.logits, Batch.n_tokens);
+            return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
         }
     }
 }

-public LLamaBatchSafeHandle(int n_tokens, int embd)
+/// <summary>
+/// Create a safe handle owning a `LLamaNativeBatch`
+/// </summary>
+/// <param name="batch"></param>
+/// <param name="embd"></param>
+public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
     : base((nint)1)
 {
     _embd = embd;
-    Batch = NativeApi.llama_batch_init(n_tokens, embd);
+    NativeBatch = batch;
 }

+/// <summary>
+/// Call `llama_batch_init` and create a new batch
+/// </summary>
+/// <param name="n_tokens"></param>
+/// <param name="embd"></param>
+/// <returns></returns>
+public static LLamaBatchSafeHandle Create(int n_tokens, int embd)
+{
+    var batch = NativeApi.llama_batch_init(n_tokens, embd);
+    return new LLamaBatchSafeHandle(batch, embd);
+}

 /// <inheritdoc />
 protected override bool ReleaseHandle()
 {
-    NativeApi.llama_batch_free(Batch);
-    Batch = default;
+    NativeApi.llama_batch_free(NativeBatch);
+    NativeBatch = default;
     SetHandle(IntPtr.Zero);
     return true;
 }
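Hypothetical usage of the split introduced above: `Create` allocates through `llama_batch_init` and hands ownership to the handle, while the constructor now merely wraps an existing `LLamaNativeBatch`.

```csharp
// Disposing the handle runs ReleaseHandle, which frees the native
// memory via llama_batch_free exactly once.
using var batch = LLamaBatchSafeHandle.Create(n_tokens: 512, embd: 0);
var rc = contextHandle.Decode(batch); // 'contextHandle' assumed in scope
```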
@@ -45,7 +45,7 @@ namespace LLama.Native
     /// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
     /// </summary>
     CHAR_ALT = 6,
-};
+}

 /// <summary>
 /// An element of a grammar
@@ -1,15 +1,26 @@
 namespace LLama.Native;

-public record struct LLamaPos
-{
-    public int Value;
-
-    public LLamaPos(int value)
-    {
-        Value = value;
-    }
+/// <summary>
+/// Indicates position in a sequence
+/// </summary>
+public readonly record struct LLamaPos(int Value)
+{
+    /// <summary>
+    /// The raw value
+    /// </summary>
+    public readonly int Value = Value;

+    /// <summary>
+    /// Convert a LLamaPos into an integer (extract the raw value)
+    /// </summary>
+    /// <param name="pos"></param>
+    /// <returns></returns>
     public static explicit operator int(LLamaPos pos) => pos.Value;

+    /// <summary>
+    /// Convert an integer into a LLamaPos
+    /// </summary>
+    /// <param name="value"></param>
+    /// <returns></returns>
     public static implicit operator LLamaPos(int value) => new(value);
 }
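The asymmetry in the operators is deliberate: any `int` is an acceptable position, so `int` to `LLamaPos` is implicit, while extracting the raw value is explicit and therefore visible at call sites:

```csharp
LLamaPos pos = 42;    // implicit widening conversion
int raw = (int)pos;   // explicit narrowing back to the raw value
```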
@@ -1,15 +1,26 @@
 namespace LLama.Native;

-public record struct LLamaSeqId
-{
-    public int Value;
-
-    public LLamaSeqId(int value)
-    {
-        Value = value;
-    }
+/// <summary>
+/// ID for a sequence in a batch
+/// </summary>
+/// <param name="Value"></param>
+public record struct LLamaSeqId(int Value)
+{
+    /// <summary>
+    /// The raw value
+    /// </summary>
+    public int Value = Value;

+    /// <summary>
+    /// Convert a LLamaSeqId into an integer (extract the raw value)
+    /// </summary>
+    /// <param name="pos"></param>
     public static explicit operator int(LLamaSeqId pos) => pos.Value;

+    /// <summary>
+    /// Convert an integer into a LLamaSeqId
+    /// </summary>
+    /// <param name="value"></param>
+    /// <returns></returns>
     public static explicit operator LLamaSeqId(int value) => new(value);
 }
@@ -1,28 +1,28 @@
 using System.Runtime.InteropServices;

-namespace LLama.Native
-{
-    [StructLayout(LayoutKind.Sequential)]
-    public struct LLamaTokenData
-    {
-        /// <summary>
-        /// token id
-        /// </summary>
-        public int id;
-
-        /// <summary>
-        /// log-odds of the token
-        /// </summary>
-        public float logit;
-
-        /// <summary>
-        /// probability of the token
-        /// </summary>
-        public float p;
-
-        public LLamaTokenData(int id, float logit, float p)
-        {
-            this.id = id;
-            this.logit = logit;
-            this.p = p;
-        }
-    }
-}
+namespace LLama.Native;
+
+/// <summary>
+/// A single token along with probability of this token being selected
+/// </summary>
+/// <param name="id"></param>
+/// <param name="logit"></param>
+/// <param name="p"></param>
+[StructLayout(LayoutKind.Sequential)]
+public record struct LLamaTokenData(int id, float logit, float p)
+{
+    /// <summary>
+    /// token id
+    /// </summary>
+    public int id = id;
+
+    /// <summary>
+    /// log-odds of the token
+    /// </summary>
+    public float logit = logit;
+
+    /// <summary>
+    /// probability of the token
+    /// </summary>
+    public float p = p;
+}
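Worth noting about this rewrite: a `record struct` primary-constructor parameter can initialize mutable public fields (`public int id = id;`), so the `[StructLayout(LayoutKind.Sequential)]` native layout and field mutability are both preserved while the handwritten constructor disappears. For instance:

```csharp
var data = new LLamaTokenData(id: 5, logit: -1.5f, p: 0.12f);
data.p = 0.25f; // still a mutable field, unlike record struct properties
```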
@@ -1,7 +1,5 @@
 using System;
 using System.Buffers;
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Exceptions;
@@ -212,9 +210,17 @@ namespace LLama.Native
             }
         }

+        /// <summary>
+        /// </summary>
+        /// <param name="batch"></param>
+        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+        /// - 0: success<br />
+        /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+        /// - &lt; 0: error<br />
+        /// </returns>
         public int Decode(LLamaBatchSafeHandle batch)
         {
-            return NativeApi.llama_decode(this, batch.Batch);
+            return NativeApi.llama_decode(this, batch.NativeBatch);
         }

         #region state
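A sketch of acting on the documented return codes (identifiers assumed for illustration):

```csharp
var rc = contextHandle.Decode(batch);
if (rc == 1)
    Console.WriteLine("No KV slot for this batch; shrink the batch or enlarge the context.");
else if (rc < 0)
    throw new RuntimeError($"llama_decode failed with code {rc}.");
```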
@@ -84,14 +84,14 @@ namespace LLama.Native
         /// adapter. Can be NULL to use the current loaded model.</param>
         /// <param name="threads"></param>
         /// <exception cref="RuntimeError"></exception>
-        public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null)
+        public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
         {
             var err = NativeApi.llama_model_apply_lora_from_file(
                 this,
                 lora,
                 scale,
                 string.IsNullOrEmpty(modelBase) ? null : modelBase,
-                (int?)threads ?? -1
+                threads ?? Math.Max(1, Environment.ProcessorCount / 2)
             );

             if (err != 0)
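The default change here is behavioural: `-1` previously flowed into native code to mean autodetect, whereas the managed side now picks an explicit count. In isolation:

```csharp
// New fallback: half the logical cores, but never fewer than one thread.
int threadCount = threads ?? Math.Max(1, Environment.ProcessorCount / 2);
```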